|
|
@@ -8,7 +8,10 @@ import numpy as np |
|
|
|
import torch
|
|
|
|
import torch.utils.data
|
|
|
|
from typing import List, \
|
|
|
|
Union
|
|
|
|
Union, \
|
|
|
|
Tuple
|
|
|
|
from .data import Data, \
|
|
|
|
EdgeType
|
|
|
|
|
|
|
|
|
|
|
|
def fixed_unigram_candidate_sampler(
|
|
|
@@ -24,7 +27,7 @@ def fixed_unigram_candidate_sampler( |
|
|
|
|
|
|
|
if len(true_classes.shape) != 2:
|
|
|
|
raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
|
|
|
|
|
|
|
|
|
|
|
|
num_samples = true_classes.shape[0]
|
|
|
|
unigrams = np.array(unigrams)
|
|
|
|
if distortion != 1.:
|
|
|
@@ -40,8 +43,74 @@ def fixed_unigram_candidate_sampler( |
|
|
|
# print('candidates:', candidates)
|
|
|
|
# print('true_classes:', true_classes[indices, :])
|
|
|
|
result[indices] = candidates.T
|
|
|
|
# print('result:', result)
|
|
|
|
mask = (candidates == true_classes[indices, :])
|
|
|
|
mask = mask.sum(1).astype(np.bool)
|
|
|
|
# print('mask:', mask)
|
|
|
|
indices = indices[mask]
|
|
|
|
# result[indices] = 0
|
|
|
|
return torch.tensor(result)
|
|
|
|
|
|
|
|
|
|
|
|
def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
        Tuple[torch.Tensor, torch.Tensor]:
    """Return the edge list and per-column degrees of an adjacency matrix.

    Args:
        adj_mat: A 2D adjacency matrix, either dense or sparse.

    Returns:
        A pair ``(edges_pos, degrees)``:

        * ``edges_pos`` holds one ``(row, column)`` index pair per row.
          For dense input these are the positions of the non-zero
          entries; for sparse input they are the stored indices after
          coalescing.
        * ``degrees`` has one entry per column. For dense input it is
          the column-wise sum of ``adj_mat``; for sparse input it is an
          int64 count of the stored entries in each column.
    """
    # Dense case: non-zero positions are the edges, column sums the degrees.
    if not adj_mat.is_sparse:
        return torch.nonzero(adj_mat, as_tuple=False), adj_mat.sum(0)

    # Sparse case: coalesce first so every (row, column) pair is stored once.
    coalesced = adj_mat.coalesce()
    coords = coalesced.indices()

    # Add a one into the slot of each stored entry's column index.
    one_per_entry = torch.ones(coords.shape[1], dtype=torch.int64,
                               device=coalesced.device)
    zero_counts = torch.zeros(coalesced.shape[1], dtype=torch.int64,
                              device=coalesced.device)
    column_counts = zero_counts.index_add(0, coords[1], one_per_entry)

    return coords.transpose(0, 1), column_counts
|
|
|
|
|
|
|
|
|
|
|
|
def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
    """Build a sparse matrix of negative edges for an adjacency matrix.

    Each edge ``(row, column)`` of ``adj_mat`` keeps its row and gets a
    replacement column drawn by ``fixed_unigram_candidate_sampler``. The
    sampler receives the true columns, the column degrees as unigram
    weights, and ``0.75`` as its third argument (presumably the
    distortion exponent; confirm against the sampler's signature).

    Args:
        adj_mat: Adjacency matrix, dense or sparse.

    Returns:
        A coalesced sparse COO tensor with the shape, dtype and device
        of ``adj_mat``, holding a one at every distinct sampled
        ``(row, column)`` position.

    Raises:
        ValueError: If ``adj_mat`` is not a ``torch.Tensor``.
    """
    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__)

    positive_edges, column_degrees = get_edges_and_degrees(adj_mat)
    source_rows = positive_edges[:, 0].view(-1, 1)
    true_columns = positive_edges[:, 1].view(-1, 1)

    sampled_columns = fixed_unigram_candidate_sampler(
        true_columns, column_degrees, 0.75)
    sampled_columns = sampled_columns.to(adj_mat.device).view(-1, 1)

    negative_edges = torch.cat([source_rows, sampled_columns], 1)

    # First pass: coalescing merges duplicate (row, column) pairs, which
    # leaves summed values greater than one at those positions.
    merged = torch.sparse_coo_tensor(
        indices=negative_edges.transpose(0, 1),
        values=torch.ones(len(negative_edges)),
        size=adj_mat.shape,
        dtype=adj_mat.dtype,
        device=adj_mat.device).coalesce()

    # Second pass: rebuild from the now-unique indices so every stored
    # value is exactly one.
    unique_indices = merged.indices()
    binary = torch.sparse_coo_tensor(
        unique_indices,
        torch.ones(unique_indices.shape[1]),
        adj_mat.shape,
        dtype=adj_mat.dtype,
        device=adj_mat.device)

    return binary.coalesce()
|
|
|
|
|
|
|
|
|
|
|
|
def negative_sample_data(data: Data) -> Data:
    """Create a copy of ``data`` whose edges are negatively sampled.

    A fresh ``Data`` object is built with the same vertex types (name
    and count). Every edge type is re-added with the same name, row and
    column vertex types and decoder factory, but each of its adjacency
    matrices is replaced by the result of ``negative_sample_adj_mat``.

    Args:
        data: Source data set; it is only read from.

    Returns:
        A new ``Data`` instance containing the negatively sampled
        adjacency matrices.
    """
    res = Data()

    for vt in data.vertex_types:
        res.add_vertex_type(vt.name, vt.count)

    # Only the edge-type objects are needed; their keys are regenerated
    # by add_edge_type on the new Data instance.
    for et in data.edge_types.values():
        adjacency_matrices_neg = [
            negative_sample_adj_mat(adj_mat)
            for adj_mat in et.adjacency_matrices
        ]
        res.add_edge_type(et.name,
                          et.vertex_type_row, et.vertex_type_column,
                          adjacency_matrices_neg, et.decoder_factory)

    return res
|