diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index cc30402..13ef5cf 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -135,7 +135,10 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: print('neg_neighbors:', neg_neighbors) - edges_neg = torch.cat([ edges_pos[:, 0].view(-1, 1), + pos_vertices = torch.repeat_interleave(torch.arange(len(adj_mat)), + row_count) + + edges_neg = torch.cat([ pos_vertices.view(-1, 1), neg_neighbors.view(-1, 1) ], 1) adj_mat_neg = torch.sparse_coo_tensor(indices = edges_neg.transpose(0, 1),