IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
ソースを参照

Make negative sampling more correct and more efficient at the same time.

master
Stanislaw Adaszewski 4年前
コミット
56fdd9ffeb
2個のファイルの変更30行の追加28行の削除
  1. +25
    -25
      src/triacontagon/sampling.py
  2. +5
    -3
      tests/triacontagon/test_sampling.py

+ 25
- 25
src/triacontagon/sampling.py ファイルの表示

@@ -17,41 +17,41 @@ import time
def fixed_unigram_candidate_sampler(
true_classes: Union[np.array, torch.Tensor],
unigrams: List[Union[int, float]],
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,
distortion: float = 1.):
if isinstance(true_classes, torch.Tensor):
true_classes = true_classes.detach().cpu().numpy()
if isinstance(unigrams, torch.Tensor):
unigrams = unigrams.detach().cpu().numpy()
if len(true_classes.shape) != 2:
raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
num_samples = true_classes.shape[0]
unigrams = np.array(unigrams)
if len(num_repeats.shape) != 1:
raise ValueError('num_repeats must be 1D')
num_rows = true_classes.shape[0]
# unigrams = np.array(unigrams)
if distortion != 1.:
unigrams = unigrams.astype(np.float64) ** distortion
unigrams = unigrams.to(torch.float64) ** distortion
# print('unigrams:', unigrams)
indices = np.arange(num_samples)
result = np.zeros(num_samples, dtype=np.int64)
indices = torch.arange(num_rows)
indices = torch.repeat_interleave(indices, num_repeats)
num_samples = len(indices)
result = torch.zeros(num_samples, dtype=torch.long)
while len(indices) > 0:
# print('len(indices):', len(indices))
sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
candidates = np.array(list(sampler))
candidates = np.reshape(candidates, (len(indices), 1))
candidates = torch.tensor(list(sampler))
candidates = candidates.view(len(indices), 1)
# print('candidates:', candidates)
# print('true_classes:', true_classes[indices, :])
result[indices] = candidates.T
result[indices] = candidates.transpose(0, 1)
# print('result:', result)
mask = (candidates == true_classes[indices, :])
mask = mask.sum(1).astype(np.bool)
mask = mask.sum(1).to(torch.bool)
# print('mask:', mask)
indices = indices[mask]
# result[indices] = 0
return torch.tensor(result)
return result
def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
@@ -71,7 +71,7 @@ def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
return edges_pos, degrees
def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor:
def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
indices = adj_mat.indices()
row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
#print('indices[0]:', indices[0], count[indices[0]])
@@ -105,11 +105,11 @@ def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor:
true_classes[row, count[row]] = col
count[row] += 1 '''
t = time.time()
true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
print('repeat_interleave() took:', time.time() - t)
# t = time.time()
# true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
# print('repeat_interleave() took:', time.time() - t)
return true_classes
return true_classes, row_count
def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
@@ -118,12 +118,12 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
edges_pos, degrees = get_edges_and_degrees(adj_mat)
true_classes = get_true_classes(adj_mat)
true_classes, row_count = get_true_classes(adj_mat)
# true_classes = edges_pos[:, 1].view(-1, 1)
# print('true_classes:', true_classes)
neg_neighbors = fixed_unigram_candidate_sampler(
true_classes, degrees, 0.75).to(adj_mat.device)
true_classes, row_count, degrees, 0.75).to(adj_mat.device)
print('neg_neighbors:', neg_neighbors)


+ 5
- 3
tests/triacontagon/test_sampling.py ファイルの表示

@@ -16,9 +16,11 @@ def test_get_true_classes_01():
[0, 1, 0, 0, 0]
], dtype=torch.float).to_sparse()
true_classes = get_true_classes(adj_mat)
true_classes, row_count = get_true_classes(adj_mat)
print('true_classes:', true_classes)
true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
assert torch.all(true_classes == torch.tensor([
[1, 3],
[1, 3],
@@ -32,10 +34,10 @@ def test_get_true_classes_01():
def test_get_true_classes_02():
adj_mat = (torch.rand(2000, 2000) < 0.1).to_sparse()
adj_mat = torch.rand(2000, 2000).round().to_sparse()
t = time.time()
true_classes = get_true_classes(adj_mat)
true_classes, row_count = get_true_classes(adj_mat)
print('Elapsed:', time.time() - t)
print('true_classes.shape:', true_classes.shape)


読み込み中…
キャンセル
保存