diff --git a/src/triacontagon/__init__.py b/src/triacontagon/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/triacontagon/data.py b/src/triacontagon/data.py index 4505adf..22a4c89 100644 --- a/src/triacontagon/data.py +++ b/src/triacontagon/data.py @@ -4,206 +4,68 @@ # -from collections import defaultdict -from dataclasses import dataclass, field -import torch -from typing import List, \ - Dict, \ +from dataclasses import dataclass +from typing import Callable, \ Tuple, \ - Any, \ - Type -from .decode import DEDICOMDecoder, \ - BilinearDecoder -import numpy as np - - -def _equal(x: torch.Tensor, y: torch.Tensor): - if x.is_sparse ^ y.is_sparse: - raise ValueError('Cannot mix sparse and dense tensors') - - if not x.is_sparse: - return (x == y) - - return ((x - y).coalesce().values() == 0) + List +import types +from .util import _nonzero_sum @dataclass -class NodeType(object): - name: str - count: int +class DecodingMatrices(object): + global_interaction: torch.Tensor + local_variation: torch.Tensor @dataclass -class RelationTypeBase(object): +class VertexType(object): name: str - node_type_row: int - node_type_column: int - adjacency_matrix: torch.Tensor - adjacency_matrix_backward: torch.Tensor - - -@dataclass -class RelationType(RelationTypeBase): - pass + count: int @dataclass -class RelationFamilyBase(object): - data: 'Data' +class EdgeType(object): name: str - node_type_row: int - node_type_column: int - is_symmetric: bool - decoder_class: Type - - -@dataclass -class RelationFamily(RelationFamilyBase): - relation_types: List[RelationType] = None - - def __post_init__(self) -> None: - if not self.is_symmetric and \ - self.decoder_class != DEDICOMDecoder and \ - self.decoder_class != BilinearDecoder: - raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') - - self.relation_types = [] - - def add_relation_type(self, - name: str, adjacency_matrix: torch.Tensor, - adjacency_matrix_backward: torch.Tensor = None) -> None: - - name = str(name) - node_type_row = self.node_type_row - node_type_column = self.node_type_column - - if adjacency_matrix is None and adjacency_matrix_backward is None: - raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') - - if adjacency_matrix is not None and \ - not isinstance(adjacency_matrix, torch.Tensor): - raise ValueError('adjacency_matrix must be a torch.Tensor') - - if adjacency_matrix_backward is not None \ - and not isinstance(adjacency_matrix_backward, torch.Tensor): - raise ValueError('adjacency_matrix_backward must be a torch.Tensor') - - if adjacency_matrix is not None and \ - adjacency_matrix.shape != (self.data.node_types[node_type_row].count, - self.data.node_types[node_type_column].count): - raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') - - if adjacency_matrix_backward is not None and \ - adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, - self.data.node_types[node_type_row].count): - raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)') - - if node_type_row == node_type_column and \ - adjacency_matrix_backward is not None: - raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') - - if self.is_symmetric and adjacency_matrix_backward is not None: - raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family') - - if self.is_symmetric and node_type_row == node_type_column and \ - not torch.all(_equal(adjacency_matrix, - adjacency_matrix.transpose(0, 1))): - raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') - - if not self.is_symmetric and node_type_row != node_type_column and \ - adjacency_matrix_backward is None: - raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None') - - if self.is_symmetric and node_type_row != node_type_column: - adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) - - self.relation_types.append(RelationType(name, - node_type_row, node_type_column, - adjacency_matrix, adjacency_matrix_backward)) - - def node_name(self, index): - return self.data.node_types[index].name - - def __repr__(self): - s = 'Relation family %s' % self.name - - for r in self.relation_types: - s += '\n - %s%s' % (r.name, ' (two-way)' \ - if (r.adjacency_matrix is not None \ - and r.adjacency_matrix_backward is not None) \ - or self.node_type_row == self.node_type_column \ - else '%s <- %s' % (self.node_name(self.node_type_row), - self.node_name(self.node_type_column))) - - return s - - def repr_indented(self): - s = ' - %s' % self.name - - for r in self.relation_types: - s += '\n - %s%s' % (r.name, ' (two-way)' \ - if (r.adjacency_matrix is not None \ - and r.adjacency_matrix_backward is not None) \ - or self.node_type_row == self.node_type_column \ - else '%s <- %s' % (self.node_name(self.node_type_row), - self.node_name(self.node_type_column))) - - return s + vertex_type_row: int + vertex_type_column: int + adjacency_matrices: List[torch.Tensor] + decoder_factory: Callable[[], DecodingMatrices] + total_connectivity: torch.Tensor class Data(object): - node_types: List[NodeType] - relation_families: List[RelationFamily] + vertex_types: List[VertexType] + edge_types: List[EdgeType] def __init__(self) -> None: - self.node_types = [] - self.relation_families = [] + self.vertex_types = [] + self.edge_types = {} - def add_node_type(self, name: str, count: int) -> None: + def add_vertex_type(self, name: str, count: int) -> None: name = str(name) count = int(count) if not name: - raise ValueError('You must provide a non-empty node type name') + raise ValueError('You must provide a non-empty vertex type name') if count <= 0: - raise ValueError('You must provide a positive node count') - self.node_types.append(NodeType(name, count)) + raise ValueError('You must provide a positive vertex count') + self.vertex_types.append(VertexType(name, count)) - def add_relation_family(self, name: str, node_type_row: int, - node_type_column: int, is_symmetric: bool, - decoder_class: Type = DEDICOMDecoder): + def add_edge_type(self, name: str, + vertex_type_row: int, vertex_type_column: int, + adjacency_matrices: List[torch.Tensor], + decoder_factory: Callable[[], DecodingMatrices]) -> None: name = str(name) - node_type_row = int(node_type_row) - node_type_column = int(node_type_column) - is_symmetric = bool(is_symmetric) - - if node_type_row < 0 or node_type_row >= len(self.node_types): - raise ValueError('node_type_row outside of the valid range of node types') - - if node_type_column < 0 or node_type_column >= len(self.node_types): - raise ValueError('node_type_column outside of the valid range of node types') - - fam = RelationFamily(self, name, node_type_row, node_type_column, - is_symmetric, decoder_class) - self.relation_families.append(fam) - - return fam - - def __repr__(self): - n = len(self.node_types) - if n == 0: - return 'Empty Icosagon Data' - s = '' - s += 'Icosagon Data with:\n' - s += '- ' + str(n) + ' node type(s):\n' - for nt in self.node_types: - s += ' - ' + nt.name + '\n' - if len(self.relation_families) == 0: - s += '- No relation families\n' - return s.strip() - - s += '- %d relation families:\n' % len(self.relation_families) - for fam in self.relation_families: - s += fam.repr_indented() + '\n' - - return s.strip() + vertex_type_row = int(vertex_type_row) + vertex_type_column = int(vertex_type_column) + if not isinstance(adjacency_matrices, list): + raise TypeError('adjacency_matrices must be a list of tensors') + if not isinstance(decoder_factory, types.FunctionType): + raise TypeError('decoder_factory must be a function') + if (vertex_type_row, vertex_type_column) in self.edge_types: + raise KeyError('Edge type for given combination of row and column already exists') + total_connectivity = _nonzero_sum(adjacency_matrices) + self.edges_types[vertex_type_row, vertex_type_column] = \ + VertexType(name, vertex_type_row, vertex_type_column, + adjacency_matrices, decoder_factory, total_connectivity) diff --git a/src/triacontagon/decode.py b/src/triacontagon/decode.py index 00df8b2..25ae822 100644 --- a/src/triacontagon/decode.py +++ b/src/triacontagon/decode.py @@ -7,117 +7,47 @@ import torch from .weights import init_glorot from .dropout import dropout +from typing import Tuple, \ + List -class DEDICOMDecoder(torch.nn.Module): - """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" - def __init__(self, input_dim, num_relation_types, keep_prob=1., - activation=torch.sigmoid, **kwargs): +def dedicom_decoder(input_dim: int, num_relation_types: int) -> + Tuple[torch.Tensor, List[torch.Tensor]]: - super().__init__(**kwargs) - self.input_dim = input_dim - self.num_relation_types = num_relation_types - self.keep_prob = keep_prob - self.activation = activation + global_interaction = init_glorot(input_dim, input_dim) + local_variation = [ + torch.diag(torch.flatten(init_glorot(input_dim, 1))) \ + for _ in range(num_relation_types) + ] + return (global_interaction, local_variation) - self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim)) - self.local_variation = torch.nn.ParameterList([ - torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ - for _ in range(num_relation_types) - ]) - def forward(self, inputs_row, inputs_col, relation_index): - inputs_row = dropout(inputs_row, self.keep_prob) - inputs_col = dropout(inputs_col, self.keep_prob) +def dist_mult_decoder(input_dim: int, num_relation_types: int) -> + Tuple[torch.Tensor, List[torch.Tensor]]: - relation = torch.diag(self.local_variation[relation_index]) + global_interaction = torch.eye(input_dim, input_dim) + local_variation = [ + torch.diag(torch.flatten(init_glorot(input_dim, 1)))) \ + for _ in range(num_relation_types) + ] + return (global_interaction, local_variation) - product1 = torch.mm(inputs_row, relation) - product2 = torch.mm(product1, self.global_interaction) - product3 = torch.mm(product2, relation) - rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]), - inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) - rec = torch.flatten(rec) - return self.activation(rec) +def bilinear_decoder(input_dim: int, num_relation_types: int) -> + Tuple[torch.Tensor, List[torch.Tensor]]: + global_interaction = torch.eye(input_dim, input_dim) + local_variation = [ + init_glorot(input_dim, input_dim) \ + for _ in range(num_relation_types) + ] + return (global_interaction, local_variation) -class DistMultDecoder(torch.nn.Module): - """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" - def __init__(self, input_dim, num_relation_types, keep_prob=1., - activation=torch.sigmoid, **kwargs): - super().__init__(**kwargs) - self.input_dim = input_dim - self.num_relation_types = num_relation_types - self.keep_prob = keep_prob - self.activation = activation +def inner_product_decoder(input_dim: int, num_relation_types: int) -> + Tuple[torch.Tensor, List[torch.Tensor]]: - self.relation = torch.nn.ParameterList([ - torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ - for _ in range(num_relation_types) - ]) - - def forward(self, inputs_row, inputs_col, relation_index): - inputs_row = dropout(inputs_row, self.keep_prob) - inputs_col = dropout(inputs_col, self.keep_prob) - - relation = torch.diag(self.relation[relation_index]) - - intermediate_product = torch.mm(inputs_row, relation) - rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), - inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) - rec = torch.flatten(rec) - - return self.activation(rec) - - -class BilinearDecoder(torch.nn.Module): - """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" - def __init__(self, input_dim, num_relation_types, keep_prob=1., - activation=torch.sigmoid, **kwargs): - - super().__init__(**kwargs) - self.input_dim = input_dim - self.num_relation_types = num_relation_types - self.keep_prob = keep_prob - self.activation = activation - - self.relation = torch.nn.ParameterList([ - torch.nn.Parameter(init_glorot(input_dim, input_dim)) \ - for _ in range(num_relation_types) - ]) - - def forward(self, inputs_row, inputs_col, relation_index): - inputs_row = dropout(inputs_row, self.keep_prob) - inputs_col = dropout(inputs_col, self.keep_prob) - - intermediate_product = torch.mm(inputs_row, self.relation[relation_index]) - rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), - inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) - rec = torch.flatten(rec) - - return self.activation(rec) - - -class InnerProductDecoder(torch.nn.Module): - """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" - def __init__(self, input_dim, num_relation_types, keep_prob=1., - activation=torch.sigmoid, **kwargs): - - super().__init__(**kwargs) - self.input_dim = input_dim - self.num_relation_types = num_relation_types - self.keep_prob = keep_prob - self.activation = activation - - - def forward(self, inputs_row, inputs_col, _): - inputs_row = dropout(inputs_row, self.keep_prob) - inputs_col = dropout(inputs_col, self.keep_prob) - - rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]), - inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) - rec = torch.flatten(rec) - - return self.activation(rec) + global_interaction = torch.eye(input_dim, input_dim) + local_variation = torch.eye(input_dim, input_dim) + local_variation = [ local_variation ] * num_relation_types + return (global_interaction, local_variation) diff --git a/src/triacontagon/model.py b/src/triacontagon/model.py new file mode 100644 index 0000000..1b98931 --- /dev/null +++ b/src/triacontagon/model.py @@ -0,0 +1,129 @@ +from .data import Data, \ + EdgeType +import torch +from dataclasses import dataclass +from .weights import init_glorot +import types +from typing import List, \ + Dict, \ + Callable +from .util import _sparse_coo_tensor + + +@dataclass +class TrainingBatch(object): + vertex_type_row: int + vertex_type_column: int + relation_type_index: int + edges: torch.Tensor + + +class Model(torch.nn.Module): + def __init__(self, data: Data, layer_dimensions: List[int], + keep_prob: float, + conv_activation: Callable[[torch.Tensor], torch.Tensor], + dec_activation: Callable[[torch.Tensor], torch.Tensor], + **kwargs) -> None: + super().__init__(**kwargs) + + if not isinstance(data, Data): + raise TypeError('data must be an instance of Data') + + if not isinstance(conv_activation, types.FunctionType): + raise TypeError('conv_activation must be a function') + + if not isinstance(dec_activation, types.FunctionType): + raise TypeError('dec_activation must be a function') + + self.data = data + self.layer_dimensions = list(layer_dimensions) + self.keep_prob = float(keep_prob) + self.conv_activation = conv_activation + self.dec_activation = dec_activation + + self.conv_weights = None + self.dec_weights = None + self.build() + + + def build(self) -> None: + self.conv_weights = torch.nn.ParameterDict() + for i in range(len(self.layer_dimensions) - 1): + in_dimension = self.layer_dimensions[i] + out_dimension = self.layer_dimensions[i + 1] + + for _, et in self.data.edge_types.items(): + weight = init_glorot(in_dimension, out_dimension) + self.conv_weights[et.vertex_type_row, et.vertex_type_column, i] = \ + torch.nn.Parameter(weight) + + self.dec_weights = torch.nn.ParameterDict() + for _, et in self.data.edge_types.items(): + global_interaction, local_variation = \ + et.decoder_factory(self.layer_dimensions[-1], + len(et.adjacency_matrices)) + self.dec_weights[et.vertex_type_row, et.vertex_type_column] = \ + torch.nn.ParameterList([ + torch.nn.Parameter(global_interaction), + torch.nn.Parameter(local_variation) + ]) + + def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor, + rows: torch.Tensor) -> torch.Tensor: + + adj_mat = adjacency_matrix.coalesce() + adj_mat = torch.index_select(adj_mat, 0, rows) + + adj_mat = adj_mat.coalesce() + indices = adj_mat.indices() + indices[0] = rows + + adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape) + + def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor, + batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor: + + col = batch.vertex_type_column + rows = batch.edges[:, 0] + + columns = batch.edges[:, 1].sum(dim=0).flatten() + columns = torch.nonzero(columns) + + for i in range(len(self.layer_dimensions) - 1): + + columns = + + + def temporary_adjacency_matrices(self, batch: TrainingBatch) -> + Dict[Tuple[int, int], List[List[torch.Tensor]]]: + + col = batch.vertex_type_column + batch.edges[:, 1] + + res = {} + + for _, et in self.data.edge_types.items(): + sum_nonzero = _nonzero_sum(et.adjacency_matrices) + res[et.vertex_type_row, et.vertex_type_column] = \ + [ self.temporary_adjacency_matrix(adj_mat, batch, + et.total_connectivity) \ + for adj_mat in et.adjacency_matrices ] + + return res + + def forward(self, initial_repr: List[torch.Tensor], + batch: TrainingBatch) -> torch.Tensor: + + if not isinstance(initial_repr, list): + raise TypeError('initial_repr must be a list') + + if len(initial_repr) != len(self.data.vertex_types): + raise ValueError('initial_repr must contain representations for all vertex types') + + if not isinstance(batch, TrainingBatch): + raise TypeError('batch must be an instance of TrainingBatch') + + adj_matrices = self.temporary_adjacency_matrices(batch) + + row_vertices = initial_repr[batch.vertex_type_row] + column_vertices = initial_repr[batch.vertex_type_column] diff --git a/src/triacontagon/util.py b/src/triacontagon/util.py new file mode 100644 index 0000000..2367b06 --- /dev/null +++ b/src/triacontagon/util.py @@ -0,0 +1,174 @@ +import torch +from typing import List, \ + Set +import time + + +def _equal(x: torch.Tensor, y: torch.Tensor): + if x.is_sparse ^ y.is_sparse: + raise ValueError('Cannot mix sparse and dense tensors') + + if not x.is_sparse: + return (x == y) + + return ((x - y).coalesce().values() == 0) + + +def _sparse_coo_tensor(indices, values, size): + ctor = { torch.float32: torch.sparse.FloatTensor, + torch.float32: torch.sparse.DoubleTensor, + torch.uint8: torch.sparse.ByteTensor, + torch.long: torch.sparse.LongTensor, + torch.int: torch.sparse.IntTensor, + torch.short: torch.sparse.ShortTensor, + torch.bool: torch.sparse.ByteTensor }[values.dtype] + return ctor(indices, values, size) + + +def _nonzero_sum(adjacency_matrices: List[torch.Tensor]): + if len(adjacency_matrices) == 0: + raise ValueError('adjacency_matrices must be non-empty') + + if not all([x.is_sparse for x in adjacency_matrices]): + raise ValueError('All adjacency matrices must be sparse') + + indices = [ x.indices() for x in adjacency_matrices ] + indices = torch.cat(indices, dim=1) + + values = torch.ones(indices.shape[1]) + res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape) + res = res.coalesce() + + indices = res.indices() + res = _sparse_coo_tensor(indices, + torch.ones(indices.shape[1], dtype=torch.uint8)) + + return res + + +def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor, + rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor: + + if not adjacency_matrix.is_sparse: + raise ValueError('adjacency_matrix must be sparse') + + if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types: + raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types') + + t = time.time() + rows = [ rows + row_vertex_count * i \ + for i in range(num_relation_types) ] + print('rows took:', time.time() - t) + t = time.time() + rows = torch.cat(rows) + print('cat took:', time.time() - t) + # print('rows:', rows) + rows = set(rows.tolist()) + # print('rows:', rows) + + t = time.time() + adj_mat = adjacency_matrix.coalesce() + indices = adj_mat.indices() + values = adj_mat.values() + print('indices[0]:', indices[0]) + print('indices[0][1]:', indices[0][1], indices[0][1] in rows) + selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ]) + # print('selection:', selection) + selection = torch.nonzero(selection, as_tuple=True)[0] + # print('selection:', selection) + indices = indices[:, selection] + values = values[selection] + print('"index_select()" took:', time.time() - t) + + t = time.time() + res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape) + print('_sparse_coo_tensor() took:', time.time() - t) + + return res + + # t = time.time() + # adj_mat = torch.index_select(adjacency_matrix, 0, rows) + # print('index_select took:', time.time() - t) + + t = time.time() + adj_mat = adj_mat.coalesce() + print('coalesce() took:', time.time() - t) + indices = adj_mat.indices() + + # print('indices:', indices) + + values = adj_mat.values() + t = time.time() + indices[0] = rows[indices[0]] + print('Lookup took:', time.time() - t) + + t = time.time() + adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape) + print('_sparse_coo_tensor() took:', time.time() - t) + + return adj_mat + + +def _sparse_diag_cat(matrices: List[torch.Tensor]): + if len(matrices) == 0: + raise ValueError('The list of matrices must be non-empty') + + if not all(m.is_sparse for m in matrices): + raise ValueError('All matrices must be sparse') + + if not all(len(m.shape) == 2 for m in matrices): + raise ValueError('All matrices must be 2D') + + indices = [] + values = [] + row_offset = 0 + col_offset = 0 + + for m in matrices: + ind = m._indices().clone() + ind[0] += row_offset + ind[1] += col_offset + indices.append(ind) + values.append(m._values()) + row_offset += m.shape[0] + col_offset += m.shape[1] + + indices = torch.cat(indices, dim=1) + values = torch.cat(values) + + return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) + + +def _cat(matrices: List[torch.Tensor]): + if len(matrices) == 0: + raise ValueError('Empty list passed to _cat()') + + n = sum(a.is_sparse for a in matrices) + if n != 0 and n != len(matrices): + raise ValueError('All matrices must have the same layout (dense or sparse)') + + if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): + raise ValueError('All matrices must have the same dimensions apart from dimension 0') + + if not matrices[0].is_sparse: + return torch.cat(matrices) + + total_rows = sum(a.shape[0] for a in matrices) + indices = [] + values = [] + row_offset = 0 + + for a in matrices: + ind = a._indices().clone() + val = a._values() + ind[0] += row_offset + ind = ind.transpose(0, 1) + indices.append(ind) + values.append(val) + row_offset += a.shape[0] + + indices = torch.cat(indices).transpose(0, 1) + values = torch.cat(values) + + res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) + return res diff --git a/tests/triacontagon/test_util.py b/tests/triacontagon/test_util.py new file mode 100644 index 0000000..5937535 --- /dev/null +++ b/tests/triacontagon/test_util.py @@ -0,0 +1,95 @@ +from triacontagon.util import \ + _clear_adjacency_matrix_except_rows, \ + _sparse_diag_cat, \ + _equal +import torch +import time + + +def test_clear_adjacency_matrix_except_rows_01(): + adj_mat = torch.tensor([ + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 1], + [1, 0, 1, 0, 0], + [1, 1, 0, 0, 0] + ], dtype=torch.uint8).to_sparse() + + adj_mat = _sparse_diag_cat([ adj_mat, adj_mat ]) + + res = _clear_adjacency_matrix_except_rows(adj_mat, + torch.tensor([1, 3]), 4, 2) + + res = res.to_dense() + + truth = torch.tensor([ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0] + ], dtype=torch.uint8) + + print('res:', res) + + assert torch.all(res == truth) + + +def test_clear_adjacency_matrix_except_rows_02(): + adj_mat = torch.rand(6, 10).round().to(torch.uint8) + + t = time.time() + res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130) + print('_sparse_diag_cat() took:', time.time() - t) + + t = time.time() + res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]), + 6, 130) + print('_clear_adjacency_matrix_except_rows() took:', time.time() - t) + + adj_mat[0] = adj_mat[2] = adj_mat[4] = \ + torch.zeros(10) + truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130) + + assert _equal(res, truth).all() + + +def test_clear_adjacency_matrix_except_rows_03(): + adj_mat = torch.rand(6, 10).round().to(torch.uint8) + + t = time.time() + res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) + print('_sparse_diag_cat() took:', time.time() - t) + + t = time.time() + res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]), + 6, 1300) + print('_clear_adjacency_matrix_except_rows() took:', time.time() - t) + + adj_mat[0] = adj_mat[2] = adj_mat[4] = \ + torch.zeros(10) + truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) + + assert _equal(res, truth).all() + + +def test_clear_adjacency_matrix_except_rows_04(): + adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8) + + t = time.time() + res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) + print('_sparse_diag_cat() took:', time.time() - t) + + t = time.time() + res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]), + 2000, 1300) + print('_clear_adjacency_matrix_except_rows() took:', time.time() - t) + + adj_mat[0] = adj_mat[2] = adj_mat[4] = \ + torch.zeros(2000) + adj_mat[6:] = torch.zeros(2000) + truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) + + assert _equal(res, truth).all()