From 4ed3715b832cb09d320613effd9784a9263f5177 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Fri, 7 Aug 2020 16:28:05 +0200
Subject: [PATCH] Add cumcount().

---
NOTE(review): indentation of context and removed lines is assumed to be
4 spaces and the placement of blank context lines is inferred from the
hunk counts -- confirm against the target tree before applying. The
new-side blob ids on the `index` lines have not been recomputed.

 src/triacontagon/cumcount.py            | 36 ++++++++++++++
 src/triacontagon/deprecated/fastconv.py | 64 +------------------------
 src/triacontagon/sampling.py            | 37 +++++++++++++-
 tests/triacontagon/test_sampling.py     | 45 ++++++++++++++++-
 4 files changed, 117 insertions(+), 65 deletions(-)
 create mode 100644 src/triacontagon/cumcount.py

diff --git a/src/triacontagon/cumcount.py b/src/triacontagon/cumcount.py
new file mode 100644
index 0000000..ba1f23f
--- /dev/null
+++ b/src/triacontagon/cumcount.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+
+def dfill(a):
+    """For each element of `a`, return the index where its run starts.
+
+    A run is a maximal stretch of equal consecutive values, e.g.
+    [5, 5, 7] -> [0, 0, 2]. An empty array yields an empty array.
+    """
+    n = a.size
+    # Run boundaries: 0, every position where the value changes, and n.
+    b = np.concatenate([[0], np.where(a[:-1] != a[1:])[0] + 1, [n]])
+    # b[:-1] holds the run starts and np.diff(b) the run lengths.
+    return b[:-1].repeat(np.diff(b))
+
+
+def argunsort(s):
+    """Return the inverse of the permutation `s`."""
+    n = s.size
+    u = np.empty(n, dtype=np.int64)
+    u[s] = np.arange(n)
+    return u
+
+
+def cumcount(a):
+    """Number each element of `a` by how many equal values precede it.
+
+    E.g. [3, 1, 3, 3, 1] -> [0, 0, 1, 2, 1].
+    """
+    n = a.size
+    # A stable sort keeps equal values in their original relative order.
+    s = a.argsort(kind='mergesort')
+    i = argunsort(s)
+    b = a[s]
+    # Rank within the sorted run, mapped back to the original positions.
+    return (np.arange(n) - dfill(b))[i]
diff --git a/src/triacontagon/deprecated/fastconv.py b/src/triacontagon/deprecated/fastconv.py
index 038e2fc..3acc9ef 100644
--- a/src/triacontagon/deprecated/fastconv.py
+++ b/src/triacontagon/deprecated/fastconv.py
@@ -9,71 +9,11 @@ import torch
 from .weights import init_glorot
 from .normalize import _sparse_coo_tensor
 import types
+from .util import _sparse_diag_cat, \
+    _cat
 
 
 
-def _sparse_diag_cat(matrices: List[torch.Tensor]):
-    if len(matrices) == 0:
-        raise ValueError('The list of matrices must be non-empty')
-    if not all(m.is_sparse for m in matrices):
-        raise ValueError('All matrices must be sparse')
-
-    if not all(len(m.shape) == 2 for m in matrices):
-        raise ValueError('All matrices must be 2D')
-
-    indices = []
-    values = []
-    row_offset = 0
-    col_offset = 0
-
-    for m in matrices:
-        ind = m._indices().clone()
-        ind[0] += row_offset
-        ind[1] += col_offset
-        indices.append(ind)
-        values.append(m._values())
-        row_offset += m.shape[0]
-        col_offset += m.shape[1]
-
-    indices = torch.cat(indices, dim=1)
-    values = torch.cat(values)
-
-    return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
-
-
-def _cat(matrices: List[torch.Tensor]):
-    if len(matrices) == 0:
-        raise ValueError('Empty list passed to _cat()')
-
-    n = sum(a.is_sparse for a in matrices)
-    if n != 0 and n != len(matrices):
-        raise ValueError('All matrices must have the same layout (dense or sparse)')
-
-    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
-        raise ValueError('All matrices must have the same dimensions apart from dimension 0')
-
-    if not matrices[0].is_sparse:
-        return torch.cat(matrices)
-
-    total_rows = sum(a.shape[0] for a in matrices)
-    indices = []
-    values = []
-    row_offset = 0
-
-    for a in matrices:
-        ind = a._indices().clone()
-        val = a._values()
-        ind[0] += row_offset
-        ind = ind.transpose(0, 1)
-        indices.append(ind)
-        values.append(val)
-        row_offset += a.shape[0]
-
-    indices = torch.cat(indices).transpose(0, 1)
-    values = torch.cat(values)
-
-    res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
-    return res
 
 
 class FastGraphConv(torch.nn.Module):
diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py
index 60b7647..7f79719 100644
--- a/src/triacontagon/sampling.py
+++ b/src/triacontagon/sampling.py
@@ -12,6 +12,7 @@ from typing import List, \
     Tuple
 from .data import Data, \
     EdgeType
+from .cumcount import cumcount
 
 
 def fixed_unigram_candidate_sampler(
@@ -69,14 +70,48 @@ def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
     return edges_pos, degrees
 
 
+def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor:
+    """Return the columns stored in each edge's row, one line per edge.
+
+    adj_mat must be a coalesced sparse matrix (indices() is called on
+    it). First a table with one line per matrix row is built, listing
+    the columns stored in that row and padded with -1 up to the largest
+    row count; then each line is repeated once per edge of its row.
+    """
+    indices = adj_mat.indices()
+
+    # Number of stored entries in every row of the matrix.
+    row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
+    row_count = row_count.index_add(0, indices[0],
+        torch.ones(indices.shape[1], dtype=torch.long))
+
+    max_true_classes = torch.max(row_count).item()
+    true_classes = torch.full((adj_mat.shape[0], max_true_classes),
+        -1, dtype=torch.long)
+
+    # cumcount() numbers the entries within each row, which is the slot
+    # of that row's line into which the entry's column is written.
+    cc = cumcount(indices[0].cpu().numpy())
+    cc = torch.tensor(cc)
+    true_classes[indices[0], cc] = indices[1]
+
+    # Expand from one line per matrix row to one line per edge.
+    true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
+
+    return true_classes
+
+
 def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
     if not isinstance(adj_mat, torch.Tensor):
         raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__)
 
     edges_pos, degrees = get_edges_and_degrees(adj_mat)
 
+    true_classes = get_true_classes(adj_mat)
+
     neg_neighbors = fixed_unigram_candidate_sampler(
-        edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
+        true_classes, degrees, 0.75).to(adj_mat.device)
+
     edges_neg = torch.cat([
         edges_pos[:, 0].view(-1, 1),
         neg_neighbors.view(-1, 1) ], 1)
diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py
index 6d7a155..75f056e 100644
--- a/tests/triacontagon/test_sampling.py
+++ b/tests/triacontagon/test_sampling.py
@@ -1,8 +1,49 @@
 from triacontagon.data import Data
-from triacontagon.sampling import negative_sample_adj_mat, \
+from triacontagon.sampling import get_true_classes, \
+    negative_sample_adj_mat, \
     negative_sample_data
 from triacontagon.decode import dedicom_decoder
 import torch
+import time
+
+
+def test_get_true_classes_01():
+    adj_mat = torch.tensor([
+        [0, 1, 0, 1, 0],
+        [0, 0, 0, 0, 1],
+        [1, 1, 0, 0, 0],
+        [0, 0, 1, 0, 1],
+        [0, 1, 0, 0, 0]
+    ], dtype=torch.float).to_sparse()
+
+    true_classes = get_true_classes(adj_mat)
+    print('true_classes:', true_classes)
+
+    # One line per edge: the line of a matrix row is repeated once for
+    # every edge stored in that row.
+    assert torch.equal(true_classes, torch.tensor([
+        [1, 3],
+        [1, 3],
+        [4, -1],
+        [0, 1],
+        [0, 1],
+        [2, 4],
+        [2, 4],
+        [1, -1]
+    ]))
+
+
+def test_get_true_classes_02():
+    # Kept small: the result holds (number of edges) x (largest row
+    # count) int64 values, about 2e9 for a half-full 2000 x 2000 matrix.
+    adj_mat = torch.rand(200, 200).round().to_sparse()
+
+    t = time.time()
+    true_classes = get_true_classes(adj_mat)
+    print('Elapsed:', time.time() - t)
+
+    print('true_classes.shape:', true_classes.shape)
+    assert true_classes.shape[0] == adj_mat.indices().shape[1]
 
 
 def test_negative_sample_adj_mat_01():
@@ -16,7 +57,7 @@ def test_negative_sample_adj_mat_01():
 
     print('adj_mat:', adj_mat)
 
-    adj_mat_neg = negative_sample_adj_mat(adj_mat)
+    adj_mat_neg = negative_sample_adj_mat(adj_mat.to_sparse())
 
     print('adj_mat_neg:', adj_mat_neg.to_dense())
 