@@ -0,0 +1,22 @@ | |||||
import numpy as np | |||||
def dfill(a): | |||||
n = a.size | |||||
b = np.concatenate([[0], np.where(a[:-1] != a[1:])[0] + 1, [n]]) | |||||
return np.arange(n)[b[:-1]].repeat(np.diff(b)) | |||||
def argunsort(s): | |||||
n = s.size | |||||
u = np.empty(n, dtype=np.int64) | |||||
u[s] = np.arange(n) | |||||
return u | |||||
def cumcount(a): | |||||
n = a.size | |||||
s = a.argsort(kind='mergesort') | |||||
i = argunsort(s) | |||||
b = a[s] | |||||
return (np.arange(n) - dfill(b))[i] |
@@ -9,71 +9,11 @@ import torch | |||||
from .weights import init_glorot | from .weights import init_glorot | ||||
from .normalize import _sparse_coo_tensor | from .normalize import _sparse_coo_tensor | ||||
import types | import types | ||||
from .util import _sparse_diag_cat, | |||||
_cat | |||||
def _sparse_diag_cat(matrices: List[torch.Tensor]): | |||||
if len(matrices) == 0: | |||||
raise ValueError('The list of matrices must be non-empty') | |||||
if not all(m.is_sparse for m in matrices): | |||||
raise ValueError('All matrices must be sparse') | |||||
if not all(len(m.shape) == 2 for m in matrices): | |||||
raise ValueError('All matrices must be 2D') | |||||
indices = [] | |||||
values = [] | |||||
row_offset = 0 | |||||
col_offset = 0 | |||||
for m in matrices: | |||||
ind = m._indices().clone() | |||||
ind[0] += row_offset | |||||
ind[1] += col_offset | |||||
indices.append(ind) | |||||
values.append(m._values()) | |||||
row_offset += m.shape[0] | |||||
col_offset += m.shape[1] | |||||
indices = torch.cat(indices, dim=1) | |||||
values = torch.cat(values) | |||||
return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) | |||||
def _cat(matrices: List[torch.Tensor]): | |||||
if len(matrices) == 0: | |||||
raise ValueError('Empty list passed to _cat()') | |||||
n = sum(a.is_sparse for a in matrices) | |||||
if n != 0 and n != len(matrices): | |||||
raise ValueError('All matrices must have the same layout (dense or sparse)') | |||||
if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): | |||||
raise ValueError('All matrices must have the same dimensions apart from dimension 0') | |||||
if not matrices[0].is_sparse: | |||||
return torch.cat(matrices) | |||||
total_rows = sum(a.shape[0] for a in matrices) | |||||
indices = [] | |||||
values = [] | |||||
row_offset = 0 | |||||
for a in matrices: | |||||
ind = a._indices().clone() | |||||
val = a._values() | |||||
ind[0] += row_offset | |||||
ind = ind.transpose(0, 1) | |||||
indices.append(ind) | |||||
values.append(val) | |||||
row_offset += a.shape[0] | |||||
indices = torch.cat(indices).transpose(0, 1) | |||||
values = torch.cat(values) | |||||
res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) | |||||
return res | |||||
class FastGraphConv(torch.nn.Module): | class FastGraphConv(torch.nn.Module): | ||||
@@ -12,6 +12,8 @@ from typing import List, \ | |||||
Tuple | Tuple | ||||
from .data import Data, \ | from .data import Data, \ | ||||
EdgeType | EdgeType | ||||
from .cumcount import cumcount | |||||
import time | |||||
def fixed_unigram_candidate_sampler( | def fixed_unigram_candidate_sampler( | ||||
@@ -69,14 +71,62 @@ def get_edges_and_degrees(adj_mat: torch.Tensor) -> \ | |||||
return edges_pos, degrees | return edges_pos, degrees | ||||
def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
indices = adj_mat.indices() | |||||
row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long) | |||||
#print('indices[0]:', indices[0], count[indices[0]]) | |||||
row_count = row_count.index_add(0, indices[0], | |||||
torch.ones(indices.shape[1], dtype=torch.long)) | |||||
#print('count:', count) | |||||
max_true_classes = torch.max(row_count).item() | |||||
#print('max_true_classes:', max_true_classes) | |||||
true_classes = torch.full((adj_mat.shape[0], max_true_classes), | |||||
-1, dtype=torch.long) | |||||
# inv = torch.unique(indices[0], return_inverse=True) | |||||
# indices = indices.copy() | |||||
# true_classes[indices[0], 0] = indices[1] | |||||
t = time.time() | |||||
cc = cumcount(indices[0].cpu().numpy()) | |||||
print('cumcount() took:', time.time() - t) | |||||
cc = torch.tensor(cc) | |||||
t = time.time() | |||||
true_classes[indices[0], cc] = indices[1] | |||||
print('assignment took:', time.time() - t) | |||||
''' count = torch.zeros(adj_mat.shape[0], dtype=torch.long) | |||||
for i in range(indices.shape[1]): | |||||
# print('looping...') | |||||
row = indices[0, i] | |||||
col = indices[1, i] | |||||
#print('row:', row, 'col:', col, 'count[row]:', count[row]) | |||||
true_classes[row, count[row]] = col | |||||
count[row] += 1 ''' | |||||
t = time.time() | |||||
true_classes = torch.repeat_interleave(true_classes, row_count, dim=0) | |||||
print('repeat_interleave() took:', time.time() - t) | |||||
return true_classes | |||||
def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | ||||
if not isinstance(adj_mat, torch.Tensor): | if not isinstance(adj_mat, torch.Tensor): | ||||
raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__) | raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__) | ||||
edges_pos, degrees = get_edges_and_degrees(adj_mat) | edges_pos, degrees = get_edges_and_degrees(adj_mat) | ||||
true_classes = get_true_classes(adj_mat) | |||||
# true_classes = edges_pos[:, 1].view(-1, 1) | |||||
# print('true_classes:', true_classes) | |||||
neg_neighbors = fixed_unigram_candidate_sampler( | neg_neighbors = fixed_unigram_candidate_sampler( | ||||
edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device) | |||||
true_classes, degrees, 0.75).to(adj_mat.device) | |||||
print('neg_neighbors:', neg_neighbors) | |||||
edges_neg = torch.cat([ edges_pos[:, 0].view(-1, 1), | edges_neg = torch.cat([ edges_pos[:, 0].view(-1, 1), | ||||
neg_neighbors.view(-1, 1) ], 1) | neg_neighbors.view(-1, 1) ], 1) | ||||
@@ -1,8 +1,41 @@ | |||||
from triacontagon.data import Data | from triacontagon.data import Data | ||||
from triacontagon.sampling import negative_sample_adj_mat, \ | |||||
from triacontagon.sampling import get_true_classes, \ | |||||
negative_sample_adj_mat, \ | |||||
negative_sample_data | negative_sample_data | ||||
from triacontagon.decode import dedicom_decoder | from triacontagon.decode import dedicom_decoder | ||||
import torch | import torch | ||||
import time | |||||
def test_get_true_classes_01(): | |||||
adj_mat = torch.tensor([ | |||||
[0, 1, 0, 1, 0], | |||||
[0, 0, 0, 0, 1], | |||||
[1, 1, 0, 0, 0], | |||||
[0, 0, 1, 0, 1], | |||||
[0, 1, 0, 0, 0] | |||||
], dtype=torch.float).to_sparse() | |||||
true_classes = get_true_classes(adj_mat) | |||||
print('true_classes:', true_classes) | |||||
assert torch.all(true_classes == torch.tensor([ | |||||
[1, 3], | |||||
[4, -1], | |||||
[0, 1], | |||||
[2, 4], | |||||
[1, -1] | |||||
])) | |||||
def test_get_true_classes_02(): | |||||
adj_mat = torch.rand(2000, 2000).round().to_sparse() | |||||
t = time.time() | |||||
true_classes = get_true_classes(adj_mat) | |||||
print('Elapsed:', time.time() - t) | |||||
print('true_classes.shape:', true_classes.shape) | |||||
def test_negative_sample_adj_mat_01(): | def test_negative_sample_adj_mat_01(): | ||||
@@ -16,7 +49,7 @@ def test_negative_sample_adj_mat_01(): | |||||
print('adj_mat:', adj_mat) | print('adj_mat:', adj_mat) | ||||
adj_mat_neg = negative_sample_adj_mat(adj_mat) | |||||
adj_mat_neg = negative_sample_adj_mat(adj_mat.to_sparse()) | |||||
print('adj_mat_neg:', adj_mat_neg.to_dense()) | print('adj_mat_neg:', adj_mat_neg.to_dense()) | ||||