@@ -0,0 +1,22 @@ | |||
import numpy as np | |||
def dfill(a): | |||
n = a.size | |||
b = np.concatenate([[0], np.where(a[:-1] != a[1:])[0] + 1, [n]]) | |||
return np.arange(n)[b[:-1]].repeat(np.diff(b)) | |||
def argunsort(s): | |||
n = s.size | |||
u = np.empty(n, dtype=np.int64) | |||
u[s] = np.arange(n) | |||
return u | |||
def cumcount(a): | |||
n = a.size | |||
s = a.argsort(kind='mergesort') | |||
i = argunsort(s) | |||
b = a[s] | |||
return (np.arange(n) - dfill(b))[i] |
@@ -9,71 +9,11 @@ import torch | |||
from .weights import init_glorot | |||
from .normalize import _sparse_coo_tensor | |||
import types | |||
from .util import _sparse_diag_cat, | |||
_cat | |||
def _sparse_diag_cat(matrices: List[torch.Tensor]): | |||
if len(matrices) == 0: | |||
raise ValueError('The list of matrices must be non-empty') | |||
if not all(m.is_sparse for m in matrices): | |||
raise ValueError('All matrices must be sparse') | |||
if not all(len(m.shape) == 2 for m in matrices): | |||
raise ValueError('All matrices must be 2D') | |||
indices = [] | |||
values = [] | |||
row_offset = 0 | |||
col_offset = 0 | |||
for m in matrices: | |||
ind = m._indices().clone() | |||
ind[0] += row_offset | |||
ind[1] += col_offset | |||
indices.append(ind) | |||
values.append(m._values()) | |||
row_offset += m.shape[0] | |||
col_offset += m.shape[1] | |||
indices = torch.cat(indices, dim=1) | |||
values = torch.cat(values) | |||
return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) | |||
def _cat(matrices: List[torch.Tensor]): | |||
if len(matrices) == 0: | |||
raise ValueError('Empty list passed to _cat()') | |||
n = sum(a.is_sparse for a in matrices) | |||
if n != 0 and n != len(matrices): | |||
raise ValueError('All matrices must have the same layout (dense or sparse)') | |||
if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): | |||
raise ValueError('All matrices must have the same dimensions apart from dimension 0') | |||
if not matrices[0].is_sparse: | |||
return torch.cat(matrices) | |||
total_rows = sum(a.shape[0] for a in matrices) | |||
indices = [] | |||
values = [] | |||
row_offset = 0 | |||
for a in matrices: | |||
ind = a._indices().clone() | |||
val = a._values() | |||
ind[0] += row_offset | |||
ind = ind.transpose(0, 1) | |||
indices.append(ind) | |||
values.append(val) | |||
row_offset += a.shape[0] | |||
indices = torch.cat(indices).transpose(0, 1) | |||
values = torch.cat(values) | |||
res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) | |||
return res | |||
class FastGraphConv(torch.nn.Module): | |||
@@ -12,6 +12,8 @@ from typing import List, \ | |||
Tuple | |||
from .data import Data, \ | |||
EdgeType | |||
from .cumcount import cumcount | |||
import time | |||
def fixed_unigram_candidate_sampler( | |||
@@ -69,14 +71,62 @@ def get_edges_and_degrees(adj_mat: torch.Tensor) -> \ | |||
return edges_pos, degrees | |||
def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor: | |||
indices = adj_mat.indices() | |||
row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long) | |||
#print('indices[0]:', indices[0], count[indices[0]]) | |||
row_count = row_count.index_add(0, indices[0], | |||
torch.ones(indices.shape[1], dtype=torch.long)) | |||
#print('count:', count) | |||
max_true_classes = torch.max(row_count).item() | |||
#print('max_true_classes:', max_true_classes) | |||
true_classes = torch.full((adj_mat.shape[0], max_true_classes), | |||
-1, dtype=torch.long) | |||
# inv = torch.unique(indices[0], return_inverse=True) | |||
# indices = indices.copy() | |||
# true_classes[indices[0], 0] = indices[1] | |||
t = time.time() | |||
cc = cumcount(indices[0].cpu().numpy()) | |||
print('cumcount() took:', time.time() - t) | |||
cc = torch.tensor(cc) | |||
t = time.time() | |||
true_classes[indices[0], cc] = indices[1] | |||
print('assignment took:', time.time() - t) | |||
''' count = torch.zeros(adj_mat.shape[0], dtype=torch.long) | |||
for i in range(indices.shape[1]): | |||
# print('looping...') | |||
row = indices[0, i] | |||
col = indices[1, i] | |||
#print('row:', row, 'col:', col, 'count[row]:', count[row]) | |||
true_classes[row, count[row]] = col | |||
count[row] += 1 ''' | |||
t = time.time() | |||
true_classes = torch.repeat_interleave(true_classes, row_count, dim=0) | |||
print('repeat_interleave() took:', time.time() - t) | |||
return true_classes | |||
def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | |||
if not isinstance(adj_mat, torch.Tensor): | |||
raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__) | |||
edges_pos, degrees = get_edges_and_degrees(adj_mat) | |||
true_classes = get_true_classes(adj_mat) | |||
# true_classes = edges_pos[:, 1].view(-1, 1) | |||
# print('true_classes:', true_classes) | |||
neg_neighbors = fixed_unigram_candidate_sampler( | |||
edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device) | |||
true_classes, degrees, 0.75).to(adj_mat.device) | |||
print('neg_neighbors:', neg_neighbors) | |||
edges_neg = torch.cat([ edges_pos[:, 0].view(-1, 1), | |||
neg_neighbors.view(-1, 1) ], 1) | |||
@@ -1,8 +1,41 @@ | |||
from triacontagon.data import Data | |||
from triacontagon.sampling import negative_sample_adj_mat, \ | |||
from triacontagon.sampling import get_true_classes, \ | |||
negative_sample_adj_mat, \ | |||
negative_sample_data | |||
from triacontagon.decode import dedicom_decoder | |||
import torch | |||
import time | |||
def test_get_true_classes_01(): | |||
adj_mat = torch.tensor([ | |||
[0, 1, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[1, 1, 0, 0, 0], | |||
[0, 0, 1, 0, 1], | |||
[0, 1, 0, 0, 0] | |||
], dtype=torch.float).to_sparse() | |||
true_classes = get_true_classes(adj_mat) | |||
print('true_classes:', true_classes) | |||
assert torch.all(true_classes == torch.tensor([ | |||
[1, 3], | |||
[4, -1], | |||
[0, 1], | |||
[2, 4], | |||
[1, -1] | |||
])) | |||
def test_get_true_classes_02(): | |||
adj_mat = torch.rand(2000, 2000).round().to_sparse() | |||
t = time.time() | |||
true_classes = get_true_classes(adj_mat) | |||
print('Elapsed:', time.time() - t) | |||
print('true_classes.shape:', true_classes.shape) | |||
def test_negative_sample_adj_mat_01(): | |||
@@ -16,7 +49,7 @@ def test_negative_sample_adj_mat_01(): | |||
print('adj_mat:', adj_mat) | |||
adj_mat_neg = negative_sample_adj_mat(adj_mat) | |||
adj_mat_neg = negative_sample_adj_mat(adj_mat.to_sparse()) | |||
print('adj_mat_neg:', adj_mat_neg.to_dense()) | |||