IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Selaa lähdekoodia

Make negative sampling more correct and more efficient at the same time.

master
Stanislaw Adaszewski 4 vuotta sitten
vanhempi
commit
56fdd9ffeb
2 muutettua tiedostoa jossa 30 lisäystä ja 28 poistoa
  1. +25
    -25
      src/triacontagon/sampling.py
  2. +5
    -3
      tests/triacontagon/test_sampling.py

+ 25
- 25
src/triacontagon/sampling.py Näytä tiedosto

@@ -17,41 +17,41 @@ import time
def fixed_unigram_candidate_sampler(
true_classes: Union[np.array, torch.Tensor],
unigrams: List[Union[int, float]],
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,
distortion: float = 1.):
if isinstance(true_classes, torch.Tensor):
true_classes = true_classes.detach().cpu().numpy()
if isinstance(unigrams, torch.Tensor):
unigrams = unigrams.detach().cpu().numpy()
if len(true_classes.shape) != 2:
raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
num_samples = true_classes.shape[0]
unigrams = np.array(unigrams)
if len(num_repeats.shape) != 1:
raise ValueError('num_repeats must be 1D')
num_rows = true_classes.shape[0]
# unigrams = np.array(unigrams)
if distortion != 1.:
unigrams = unigrams.astype(np.float64) ** distortion
unigrams = unigrams.to(torch.float64) ** distortion
# print('unigrams:', unigrams)
indices = np.arange(num_samples)
result = np.zeros(num_samples, dtype=np.int64)
indices = torch.arange(num_rows)
indices = torch.repeat_interleave(indices, num_repeats)
num_samples = len(indices)
result = torch.zeros(num_samples, dtype=torch.long)
while len(indices) > 0:
# print('len(indices):', len(indices))
sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
candidates = np.array(list(sampler))
candidates = np.reshape(candidates, (len(indices), 1))
candidates = torch.tensor(list(sampler))
candidates = candidates.view(len(indices), 1)
# print('candidates:', candidates)
# print('true_classes:', true_classes[indices, :])
result[indices] = candidates.T
result[indices] = candidates.transpose(0, 1)
# print('result:', result)
mask = (candidates == true_classes[indices, :])
mask = mask.sum(1).astype(np.bool)
mask = mask.sum(1).to(torch.bool)
# print('mask:', mask)
indices = indices[mask]
# result[indices] = 0
return torch.tensor(result)
return result
def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
@@ -71,7 +71,7 @@ def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
return edges_pos, degrees
def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor:
def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
indices = adj_mat.indices()
row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
#print('indices[0]:', indices[0], count[indices[0]])
@@ -105,11 +105,11 @@ def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor:
true_classes[row, count[row]] = col
count[row] += 1 '''
t = time.time()
true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
print('repeat_interleave() took:', time.time() - t)
# t = time.time()
# true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
# print('repeat_interleave() took:', time.time() - t)
return true_classes
return true_classes, row_count
def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
@@ -118,12 +118,12 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
edges_pos, degrees = get_edges_and_degrees(adj_mat)
true_classes = get_true_classes(adj_mat)
true_classes, row_count = get_true_classes(adj_mat)
# true_classes = edges_pos[:, 1].view(-1, 1)
# print('true_classes:', true_classes)
neg_neighbors = fixed_unigram_candidate_sampler(
true_classes, degrees, 0.75).to(adj_mat.device)
true_classes, row_count, degrees, 0.75).to(adj_mat.device)
print('neg_neighbors:', neg_neighbors)


+ 5
- 3
tests/triacontagon/test_sampling.py Näytä tiedosto

@@ -16,9 +16,11 @@ def test_get_true_classes_01():
[0, 1, 0, 0, 0]
], dtype=torch.float).to_sparse()
true_classes = get_true_classes(adj_mat)
true_classes, row_count = get_true_classes(adj_mat)
print('true_classes:', true_classes)
true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
assert torch.all(true_classes == torch.tensor([
[1, 3],
[1, 3],
@@ -32,10 +34,10 @@ def test_get_true_classes_01():
def test_get_true_classes_02():
adj_mat = (torch.rand(2000, 2000) < 0.1).to_sparse()
adj_mat = torch.rand(2000, 2000).round().to_sparse()
t = time.time()
true_classes = get_true_classes(adj_mat)
true_classes, row_count = get_true_classes(adj_mat)
print('Elapsed:', time.time() - t)
print('true_classes.shape:', true_classes.shape)


Loading…
Peruuta
Tallenna