| @@ -0,0 +1,22 @@ | |||
| import numpy as np | |||
| def dfill(a): | |||
| n = a.size | |||
| b = np.concatenate([[0], np.where(a[:-1] != a[1:])[0] + 1, [n]]) | |||
| return np.arange(n)[b[:-1]].repeat(np.diff(b)) | |||
| def argunsort(s): | |||
| n = s.size | |||
| u = np.empty(n, dtype=np.int64) | |||
| u[s] = np.arange(n) | |||
| return u | |||
| def cumcount(a): | |||
| n = a.size | |||
| s = a.argsort(kind='mergesort') | |||
| i = argunsort(s) | |||
| b = a[s] | |||
| return (np.arange(n) - dfill(b))[i] | |||
| @@ -9,71 +9,11 @@ import torch | |||
| from .weights import init_glorot | |||
| from .normalize import _sparse_coo_tensor | |||
| import types | |||
| from .util import _sparse_diag_cat, | |||
| _cat | |||
| def _sparse_diag_cat(matrices: List[torch.Tensor]): | |||
| if len(matrices) == 0: | |||
| raise ValueError('The list of matrices must be non-empty') | |||
| if not all(m.is_sparse for m in matrices): | |||
| raise ValueError('All matrices must be sparse') | |||
| if not all(len(m.shape) == 2 for m in matrices): | |||
| raise ValueError('All matrices must be 2D') | |||
| indices = [] | |||
| values = [] | |||
| row_offset = 0 | |||
| col_offset = 0 | |||
| for m in matrices: | |||
| ind = m._indices().clone() | |||
| ind[0] += row_offset | |||
| ind[1] += col_offset | |||
| indices.append(ind) | |||
| values.append(m._values()) | |||
| row_offset += m.shape[0] | |||
| col_offset += m.shape[1] | |||
| indices = torch.cat(indices, dim=1) | |||
| values = torch.cat(values) | |||
| return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) | |||
| def _cat(matrices: List[torch.Tensor]): | |||
| if len(matrices) == 0: | |||
| raise ValueError('Empty list passed to _cat()') | |||
| n = sum(a.is_sparse for a in matrices) | |||
| if n != 0 and n != len(matrices): | |||
| raise ValueError('All matrices must have the same layout (dense or sparse)') | |||
| if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): | |||
| raise ValueError('All matrices must have the same dimensions apart from dimension 0') | |||
| if not matrices[0].is_sparse: | |||
| return torch.cat(matrices) | |||
| total_rows = sum(a.shape[0] for a in matrices) | |||
| indices = [] | |||
| values = [] | |||
| row_offset = 0 | |||
| for a in matrices: | |||
| ind = a._indices().clone() | |||
| val = a._values() | |||
| ind[0] += row_offset | |||
| ind = ind.transpose(0, 1) | |||
| indices.append(ind) | |||
| values.append(val) | |||
| row_offset += a.shape[0] | |||
| indices = torch.cat(indices).transpose(0, 1) | |||
| values = torch.cat(values) | |||
| res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) | |||
| return res | |||
| class FastGraphConv(torch.nn.Module): | |||
| @@ -12,6 +12,8 @@ from typing import List, \ | |||
| Tuple | |||
| from .data import Data, \ | |||
| EdgeType | |||
| from .cumcount import cumcount | |||
| import time | |||
| def fixed_unigram_candidate_sampler( | |||
| @@ -69,14 +71,62 @@ def get_edges_and_degrees(adj_mat: torch.Tensor) -> \ | |||
| return edges_pos, degrees | |||
| def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor: | |||
| indices = adj_mat.indices() | |||
| row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long) | |||
| #print('indices[0]:', indices[0], count[indices[0]]) | |||
| row_count = row_count.index_add(0, indices[0], | |||
| torch.ones(indices.shape[1], dtype=torch.long)) | |||
| #print('count:', count) | |||
| max_true_classes = torch.max(row_count).item() | |||
| #print('max_true_classes:', max_true_classes) | |||
| true_classes = torch.full((adj_mat.shape[0], max_true_classes), | |||
| -1, dtype=torch.long) | |||
| # inv = torch.unique(indices[0], return_inverse=True) | |||
| # indices = indices.copy() | |||
| # true_classes[indices[0], 0] = indices[1] | |||
| t = time.time() | |||
| cc = cumcount(indices[0].cpu().numpy()) | |||
| print('cumcount() took:', time.time() - t) | |||
| cc = torch.tensor(cc) | |||
| t = time.time() | |||
| true_classes[indices[0], cc] = indices[1] | |||
| print('assignment took:', time.time() - t) | |||
| ''' count = torch.zeros(adj_mat.shape[0], dtype=torch.long) | |||
| for i in range(indices.shape[1]): | |||
| # print('looping...') | |||
| row = indices[0, i] | |||
| col = indices[1, i] | |||
| #print('row:', row, 'col:', col, 'count[row]:', count[row]) | |||
| true_classes[row, count[row]] = col | |||
| count[row] += 1 ''' | |||
| t = time.time() | |||
| true_classes = torch.repeat_interleave(true_classes, row_count, dim=0) | |||
| print('repeat_interleave() took:', time.time() - t) | |||
| return true_classes | |||
| def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | |||
| if not isinstance(adj_mat, torch.Tensor): | |||
| raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__) | |||
| edges_pos, degrees = get_edges_and_degrees(adj_mat) | |||
| true_classes = get_true_classes(adj_mat) | |||
| # true_classes = edges_pos[:, 1].view(-1, 1) | |||
| # print('true_classes:', true_classes) | |||
| neg_neighbors = fixed_unigram_candidate_sampler( | |||
| edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device) | |||
| true_classes, degrees, 0.75).to(adj_mat.device) | |||
| print('neg_neighbors:', neg_neighbors) | |||
| edges_neg = torch.cat([ edges_pos[:, 0].view(-1, 1), | |||
| neg_neighbors.view(-1, 1) ], 1) | |||
| @@ -1,8 +1,41 @@ | |||
| from triacontagon.data import Data | |||
| from triacontagon.sampling import negative_sample_adj_mat, \ | |||
| from triacontagon.sampling import get_true_classes, \ | |||
| negative_sample_adj_mat, \ | |||
| negative_sample_data | |||
| from triacontagon.decode import dedicom_decoder | |||
| import torch | |||
| import time | |||
| def test_get_true_classes_01(): | |||
| adj_mat = torch.tensor([ | |||
| [0, 1, 0, 1, 0], | |||
| [0, 0, 0, 0, 1], | |||
| [1, 1, 0, 0, 0], | |||
| [0, 0, 1, 0, 1], | |||
| [0, 1, 0, 0, 0] | |||
| ], dtype=torch.float).to_sparse() | |||
| true_classes = get_true_classes(adj_mat) | |||
| print('true_classes:', true_classes) | |||
| assert torch.all(true_classes == torch.tensor([ | |||
| [1, 3], | |||
| [4, -1], | |||
| [0, 1], | |||
| [2, 4], | |||
| [1, -1] | |||
| ])) | |||
| def test_get_true_classes_02(): | |||
| adj_mat = torch.rand(2000, 2000).round().to_sparse() | |||
| t = time.time() | |||
| true_classes = get_true_classes(adj_mat) | |||
| print('Elapsed:', time.time() - t) | |||
| print('true_classes.shape:', true_classes.shape) | |||
| def test_negative_sample_adj_mat_01(): | |||
| @@ -16,7 +49,7 @@ def test_negative_sample_adj_mat_01(): | |||
| print('adj_mat:', adj_mat) | |||
| adj_mat_neg = negative_sample_adj_mat(adj_mat) | |||
| adj_mat_neg = negative_sample_adj_mat(adj_mat.to_sparse()) | |||
| print('adj_mat_neg:', adj_mat_neg.to_dense()) | |||