|
|
@@ -17,41 +17,41 @@ import time |
|
|
|
|
|
|
|
|
|
|
|
def fixed_unigram_candidate_sampler(
|
|
|
|
true_classes: Union[np.array, torch.Tensor],
|
|
|
|
unigrams: List[Union[int, float]],
|
|
|
|
true_classes: torch.Tensor,
|
|
|
|
num_repeats: torch.Tensor,
|
|
|
|
unigrams: torch.Tensor,
|
|
|
|
distortion: float = 1.):
|
|
|
|
|
|
|
|
if isinstance(true_classes, torch.Tensor):
|
|
|
|
true_classes = true_classes.detach().cpu().numpy()
|
|
|
|
|
|
|
|
if isinstance(unigrams, torch.Tensor):
|
|
|
|
unigrams = unigrams.detach().cpu().numpy()
|
|
|
|
|
|
|
|
if len(true_classes.shape) != 2:
|
|
|
|
raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
|
|
|
|
|
|
|
|
num_samples = true_classes.shape[0]
|
|
|
|
unigrams = np.array(unigrams)
|
|
|
|
if len(num_repeats.shape) != 1:
|
|
|
|
raise ValueError('num_repeats must be 1D')
|
|
|
|
|
|
|
|
num_rows = true_classes.shape[0]
|
|
|
|
# unigrams = np.array(unigrams)
|
|
|
|
if distortion != 1.:
|
|
|
|
unigrams = unigrams.astype(np.float64) ** distortion
|
|
|
|
unigrams = unigrams.to(torch.float64) ** distortion
|
|
|
|
# print('unigrams:', unigrams)
|
|
|
|
indices = np.arange(num_samples)
|
|
|
|
result = np.zeros(num_samples, dtype=np.int64)
|
|
|
|
indices = torch.arange(num_rows)
|
|
|
|
indices = torch.repeat_interleave(indices, num_repeats)
|
|
|
|
num_samples = len(indices)
|
|
|
|
result = torch.zeros(num_samples, dtype=torch.long)
|
|
|
|
while len(indices) > 0:
|
|
|
|
# print('len(indices):', len(indices))
|
|
|
|
sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
|
|
|
|
candidates = np.array(list(sampler))
|
|
|
|
candidates = np.reshape(candidates, (len(indices), 1))
|
|
|
|
candidates = torch.tensor(list(sampler))
|
|
|
|
candidates = candidates.view(len(indices), 1)
|
|
|
|
# print('candidates:', candidates)
|
|
|
|
# print('true_classes:', true_classes[indices, :])
|
|
|
|
result[indices] = candidates.T
|
|
|
|
result[indices] = candidates.transpose(0, 1)
|
|
|
|
# print('result:', result)
|
|
|
|
mask = (candidates == true_classes[indices, :])
|
|
|
|
mask = mask.sum(1).astype(np.bool)
|
|
|
|
mask = mask.sum(1).to(torch.bool)
|
|
|
|
# print('mask:', mask)
|
|
|
|
indices = indices[mask]
|
|
|
|
# result[indices] = 0
|
|
|
|
return torch.tensor(result)
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
|
|
|
@@ -71,7 +71,7 @@ def get_edges_and_degrees(adj_mat: torch.Tensor) -> \ |
|
|
|
return edges_pos, degrees
|
|
|
|
|
|
|
|
|
|
|
|
def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor:
|
|
|
|
def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
indices = adj_mat.indices()
|
|
|
|
row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
|
|
|
|
#print('indices[0]:', indices[0], count[indices[0]])
|
|
|
@@ -105,11 +105,11 @@ def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor: |
|
|
|
true_classes[row, count[row]] = col
|
|
|
|
count[row] += 1 '''
|
|
|
|
|
|
|
|
t = time.time()
|
|
|
|
true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
|
|
|
|
print('repeat_interleave() took:', time.time() - t)
|
|
|
|
# t = time.time()
|
|
|
|
# true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
|
|
|
|
# print('repeat_interleave() took:', time.time() - t)
|
|
|
|
|
|
|
|
return true_classes
|
|
|
|
return true_classes, row_count
|
|
|
|
|
|
|
|
|
|
|
|
def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
|
|
|
@@ -118,12 +118,12 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
|
|
edges_pos, degrees = get_edges_and_degrees(adj_mat)
|
|
|
|
|
|
|
|
true_classes = get_true_classes(adj_mat)
|
|
|
|
true_classes, row_count = get_true_classes(adj_mat)
|
|
|
|
# true_classes = edges_pos[:, 1].view(-1, 1)
|
|
|
|
# print('true_classes:', true_classes)
|
|
|
|
|
|
|
|
neg_neighbors = fixed_unigram_candidate_sampler(
|
|
|
|
true_classes, degrees, 0.75).to(adj_mat.device)
|
|
|
|
true_classes, row_count, degrees, 0.75).to(adj_mat.device)
|
|
|
|
|
|
|
|
print('neg_neighbors:', neg_neighbors)
|
|
|
|
|
|
|
|