@@ -4,206 +4,68 @@
-from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 import torch
-from typing import List, \
-    Dict, \
-    Tuple, \
-    Any, \
-    Type
-from .decode import DEDICOMDecoder, \
-    BilinearDecoder
-import numpy as np
+from typing import Callable, \
+    Dict, \
+    Tuple, \
+    List
+import types
+from .util import _nonzero_sum


-def _equal(x: torch.Tensor, y: torch.Tensor):
-    if x.is_sparse ^ y.is_sparse:
-        raise ValueError('Cannot mix sparse and dense tensors')
-
-    if not x.is_sparse:
-        return (x == y)
-
-    return ((x - y).coalesce().values() == 0)
 @dataclass
-class NodeType(object):
+class DecodingMatrices(object):
+    global_interaction: torch.Tensor
+    local_variation: torch.Tensor
+
+
+@dataclass
+class VertexType(object):
     name: str
     count: int


-@dataclass
-class RelationTypeBase(object):
-    name: str
-    node_type_row: int
-    node_type_column: int
-    adjacency_matrix: torch.Tensor
-    adjacency_matrix_backward: torch.Tensor
-
-
-@dataclass
-class RelationType(RelationTypeBase):
-    pass
-
-
-@dataclass
-class RelationFamilyBase(object):
-    data: 'Data'
-    name: str
-    node_type_row: int
-    node_type_column: int
-    is_symmetric: bool
-    decoder_class: Type
-@dataclass
-class RelationFamily(RelationFamilyBase):
-    relation_types: List[RelationType] = None
-
-    def __post_init__(self) -> None:
-        if not self.is_symmetric and \
-            self.decoder_class != DEDICOMDecoder and \
-            self.decoder_class != BilinearDecoder:
-            raise TypeError('Family is asymmetric but the specified decoder_class supports symmetric relations only')
-
-        self.relation_types = []
-
-    def add_relation_type(self,
-        name: str, adjacency_matrix: torch.Tensor,
-        adjacency_matrix_backward: torch.Tensor = None) -> None:
-
-        name = str(name)
-        node_type_row = self.node_type_row
-        node_type_column = self.node_type_column
-
-        if adjacency_matrix is None and adjacency_matrix_backward is None:
-            raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')
-
-        if adjacency_matrix is not None and \
-            not isinstance(adjacency_matrix, torch.Tensor):
-            raise ValueError('adjacency_matrix must be a torch.Tensor')
-
-        if adjacency_matrix_backward is not None \
-            and not isinstance(adjacency_matrix_backward, torch.Tensor):
-            raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
-
-        if adjacency_matrix is not None and \
-            adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
-                self.data.node_types[node_type_column].count):
-            raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
-
-        if adjacency_matrix_backward is not None and \
-            adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
-                self.data.node_types[node_type_row].count):
-            raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')
-
-        if node_type_row == node_type_column and \
-            adjacency_matrix_backward is not None:
-            raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
-
-        if self.is_symmetric and adjacency_matrix_backward is not None:
-            raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')
-
-        if self.is_symmetric and node_type_row == node_type_column and \
-            not torch.all(_equal(adjacency_matrix,
-                adjacency_matrix.transpose(0, 1))):
-            raise ValueError('Relation family is symmetric but adjacency_matrix is asymmetric')
-
-        if not self.is_symmetric and node_type_row != node_type_column and \
-            adjacency_matrix_backward is None:
-            raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None')
-
-        if self.is_symmetric and node_type_row != node_type_column:
-            adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
-
-        self.relation_types.append(RelationType(name,
-            node_type_row, node_type_column,
-            adjacency_matrix, adjacency_matrix_backward))
-
-    def node_name(self, index):
-        return self.data.node_types[index].name
-
-    def __repr__(self):
-        s = 'Relation family %s' % self.name
-        for r in self.relation_types:
-            s += '\n - %s%s' % (r.name, ' (two-way)' \
-                if (r.adjacency_matrix is not None \
-                    and r.adjacency_matrix_backward is not None) \
-                    or self.node_type_row == self.node_type_column \
-                else '%s <- %s' % (self.node_name(self.node_type_row),
-                    self.node_name(self.node_type_column)))
-        return s
-
-    def repr_indented(self):
-        s = ' - %s' % self.name
-        for r in self.relation_types:
-            s += '\n - %s%s' % (r.name, ' (two-way)' \
-                if (r.adjacency_matrix is not None \
-                    and r.adjacency_matrix_backward is not None) \
-                    or self.node_type_row == self.node_type_column \
-                else '%s <- %s' % (self.node_name(self.node_type_row),
-                    self.node_name(self.node_type_column)))
-        return s
+@dataclass
+class EdgeType(object):
+    name: str
+    vertex_type_row: int
+    vertex_type_column: int
+    adjacency_matrices: List[torch.Tensor]
+    decoder_factory: Callable[[int, int], DecodingMatrices]
+    total_connectivity: torch.Tensor

 class Data(object):
-    node_types: List[NodeType]
-    relation_families: List[RelationFamily]
+    vertex_types: List[VertexType]
+    edge_types: Dict[Tuple[int, int], EdgeType]

     def __init__(self) -> None:
-        self.node_types = []
-        self.relation_families = []
+        self.vertex_types = []
+        self.edge_types = {}

-    def add_node_type(self, name: str, count: int) -> None:
+    def add_vertex_type(self, name: str, count: int) -> None:
         name = str(name)
         count = int(count)
         if not name:
-            raise ValueError('You must provide a non-empty node type name')
+            raise ValueError('You must provide a non-empty vertex type name')
         if count <= 0:
-            raise ValueError('You must provide a positive node count')
-        self.node_types.append(NodeType(name, count))
+            raise ValueError('You must provide a positive vertex count')
+        self.vertex_types.append(VertexType(name, count))
-    def add_relation_family(self, name: str, node_type_row: int,
-        node_type_column: int, is_symmetric: bool,
-        decoder_class: Type = DEDICOMDecoder):
-
-        name = str(name)
-        node_type_row = int(node_type_row)
-        node_type_column = int(node_type_column)
-        is_symmetric = bool(is_symmetric)
-
-        if node_type_row < 0 or node_type_row >= len(self.node_types):
-            raise ValueError('node_type_row outside of the valid range of node types')
-
-        if node_type_column < 0 or node_type_column >= len(self.node_types):
-            raise ValueError('node_type_column outside of the valid range of node types')
-
-        fam = RelationFamily(self, name, node_type_row, node_type_column,
-            is_symmetric, decoder_class)
-        self.relation_families.append(fam)
-        return fam
-
-    def __repr__(self):
-        n = len(self.node_types)
-        if n == 0:
-            return 'Empty Icosagon Data'
-        s = ''
-        s += 'Icosagon Data with:\n'
-        s += '- ' + str(n) + ' node type(s):\n'
-        for nt in self.node_types:
-            s += ' - ' + nt.name + '\n'
-        if len(self.relation_families) == 0:
-            s += '- No relation families\n'
-            return s.strip()
-        s += '- %d relation families:\n' % len(self.relation_families)
-        for fam in self.relation_families:
-            s += fam.repr_indented() + '\n'
-        return s.strip()
+    def add_edge_type(self, name: str,
+        vertex_type_row: int, vertex_type_column: int,
+        adjacency_matrices: List[torch.Tensor],
+        decoder_factory: Callable[[int, int], DecodingMatrices]) -> None:
+
+        name = str(name)
+        vertex_type_row = int(vertex_type_row)
+        vertex_type_column = int(vertex_type_column)
+
+        if not isinstance(adjacency_matrices, list):
+            raise TypeError('adjacency_matrices must be a list of tensors')
+
+        if not isinstance(decoder_factory, types.FunctionType):
+            raise TypeError('decoder_factory must be a function')
+
+        if (vertex_type_row, vertex_type_column) in self.edge_types:
+            raise KeyError('Edge type for given combination of row and column already exists')
+
+        total_connectivity = _nonzero_sum(adjacency_matrices)
+
+        self.edge_types[vertex_type_row, vertex_type_column] = \
+            EdgeType(name, vertex_type_row, vertex_type_column,
+                adjacency_matrices, decoder_factory, total_connectivity)
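
# Usage sketch for the new Data API added above (editor's illustration, not
# part of the diff; the vertex counts and the random adjacency matrix are
# made up, and dedicom_decoder is assumed to live in the decode module):

import torch
from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder

d = Data()
d.add_vertex_type('Gene', 100)
d.add_vertex_type('Drug', 50)
d.add_edge_type('Target', 0, 1,
    [ torch.rand(100, 50).round().to_sparse() ],  # one matrix per relation
    dedicom_decoder)
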
@@ -7,117 +7,47 @@
 import torch
 from .weights import init_glorot
 from .dropout import dropout
+from typing import Tuple, \
+    List
-class DEDICOMDecoder(torch.nn.Module):
-    """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
-    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
-        activation=torch.sigmoid, **kwargs):
-
-        super().__init__(**kwargs)
-        self.input_dim = input_dim
-        self.num_relation_types = num_relation_types
-        self.keep_prob = keep_prob
-        self.activation = activation
-
-        self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
-        self.local_variation = torch.nn.ParameterList([
-            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
-                for _ in range(num_relation_types)
-        ])
-
-    def forward(self, inputs_row, inputs_col, relation_index):
-        inputs_row = dropout(inputs_row, self.keep_prob)
-        inputs_col = dropout(inputs_col, self.keep_prob)
-
-        relation = torch.diag(self.local_variation[relation_index])
-
-        product1 = torch.mm(inputs_row, relation)
-        product2 = torch.mm(product1, self.global_interaction)
-        product3 = torch.mm(product2, relation)
-
-        rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
-            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
-        rec = torch.flatten(rec)
-
-        return self.activation(rec)
+def dedicom_decoder(input_dim: int, num_relation_types: int) -> \
+    Tuple[torch.Tensor, List[torch.Tensor]]:
+
+    global_interaction = init_glorot(input_dim, input_dim)
+    local_variation = [
+        torch.diag(torch.flatten(init_glorot(input_dim, 1))) \
+            for _ in range(num_relation_types)
+    ]
+    return (global_interaction, local_variation)


-class DistMultDecoder(torch.nn.Module):
-    """DistMult Tensor Factorization Decoder model layer for link prediction."""
-    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
-        activation=torch.sigmoid, **kwargs):
-
-        super().__init__(**kwargs)
-        self.input_dim = input_dim
-        self.num_relation_types = num_relation_types
-        self.keep_prob = keep_prob
-        self.activation = activation
-
-        self.relation = torch.nn.ParameterList([
-            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
-                for _ in range(num_relation_types)
-        ])
-
-    def forward(self, inputs_row, inputs_col, relation_index):
-        inputs_row = dropout(inputs_row, self.keep_prob)
-        inputs_col = dropout(inputs_col, self.keep_prob)
-
-        relation = torch.diag(self.relation[relation_index])
-
-        intermediate_product = torch.mm(inputs_row, relation)
-        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
-            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
-        rec = torch.flatten(rec)
-
-        return self.activation(rec)
+def dist_mult_decoder(input_dim: int, num_relation_types: int) -> \
+    Tuple[torch.Tensor, List[torch.Tensor]]:
+
+    global_interaction = torch.eye(input_dim, input_dim)
+    local_variation = [
+        torch.diag(torch.flatten(init_glorot(input_dim, 1))) \
+            for _ in range(num_relation_types)
+    ]
+    return (global_interaction, local_variation)


-class BilinearDecoder(torch.nn.Module):
-    """Bilinear Decoder model layer for link prediction."""
-    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
-        activation=torch.sigmoid, **kwargs):
-
-        super().__init__(**kwargs)
-        self.input_dim = input_dim
-        self.num_relation_types = num_relation_types
-        self.keep_prob = keep_prob
-        self.activation = activation
-
-        self.relation = torch.nn.ParameterList([
-            torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
-                for _ in range(num_relation_types)
-        ])
-
-    def forward(self, inputs_row, inputs_col, relation_index):
-        inputs_row = dropout(inputs_row, self.keep_prob)
-        inputs_col = dropout(inputs_col, self.keep_prob)
-
-        intermediate_product = torch.mm(inputs_row, self.relation[relation_index])
-        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
-            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
-        rec = torch.flatten(rec)
-
-        return self.activation(rec)
+def bilinear_decoder(input_dim: int, num_relation_types: int) -> \
+    Tuple[torch.Tensor, List[torch.Tensor]]:
+
+    global_interaction = torch.eye(input_dim, input_dim)
+    local_variation = [
+        init_glorot(input_dim, input_dim) \
+            for _ in range(num_relation_types)
+    ]
+    return (global_interaction, local_variation)


-class InnerProductDecoder(torch.nn.Module):
-    """Inner Product Decoder model layer for link prediction."""
-    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
-        activation=torch.sigmoid, **kwargs):
-
-        super().__init__(**kwargs)
-        self.input_dim = input_dim
-        self.num_relation_types = num_relation_types
-        self.keep_prob = keep_prob
-        self.activation = activation
-
-    def forward(self, inputs_row, inputs_col, _):
-        inputs_row = dropout(inputs_row, self.keep_prob)
-        inputs_col = dropout(inputs_col, self.keep_prob)
-
-        rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
-            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
-        rec = torch.flatten(rec)
-
-        return self.activation(rec)
+def inner_product_decoder(input_dim: int, num_relation_types: int) -> \
+    Tuple[torch.Tensor, List[torch.Tensor]]:
+
+    global_interaction = torch.eye(input_dim, input_dim)
+    local_variation = [ torch.eye(input_dim, input_dim) ] * num_relation_types
+    return (global_interaction, local_variation)
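
# How a decoder factory's output is meant to be used for scoring (editor's
# sketch; it mirrors the removed DEDICOMDecoder.forward() above, since the
# new Model does not consume these matrices yet):

import torch

def dedicom_score(z_row, z_col, global_interaction, local_variation,
    relation_index):
    # z_row, z_col: (n, d) embeddings of the edge endpoints;
    # score_i = sigmoid(z_row_i @ D_r @ G @ D_r @ z_col_i)
    d_r = local_variation[relation_index]
    product = z_row @ d_r @ global_interaction @ d_r
    return torch.sigmoid((product * z_col).sum(dim=1))
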
@@ -0,0 +1,129 @@
from .data import Data, \
    EdgeType
import torch
from dataclasses import dataclass
from .weights import init_glorot
import types
from typing import List, \
    Dict, \
    Tuple, \
    Callable
from .util import _sparse_coo_tensor, \
    _nonzero_sum


@dataclass
class TrainingBatch(object):
    vertex_type_row: int
    vertex_type_column: int
    relation_type_index: int
    edges: torch.Tensor
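
# A TrainingBatch selects one edge type (by vertex-type pair), one relation
# within it, and a batch of (row, column) index pairs; a construction sketch
# in a comment (editor's illustration, made-up indices):
#
#     batch = TrainingBatch(
#         vertex_type_row=0,        # e.g. 'Gene'
#         vertex_type_column=1,     # e.g. 'Drug'
#         relation_type_index=0,
#         edges=torch.tensor([ [0, 3], [5, 7], [21, 42] ])
#     )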
class Model(torch.nn.Module):
    def __init__(self, data: Data, layer_dimensions: List[int],
        keep_prob: float,
        conv_activation: Callable[[torch.Tensor], torch.Tensor],
        dec_activation: Callable[[torch.Tensor], torch.Tensor],
        **kwargs) -> None:

        super().__init__(**kwargs)

        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        if not isinstance(conv_activation, types.FunctionType):
            raise TypeError('conv_activation must be a function')

        if not isinstance(dec_activation, types.FunctionType):
            raise TypeError('dec_activation must be a function')

        self.data = data
        self.layer_dimensions = list(layer_dimensions)
        self.keep_prob = float(keep_prob)
        self.conv_activation = conv_activation
        self.dec_activation = dec_activation

        self.conv_weights = None
        self.dec_weights = None

        self.build()

    def build(self) -> None:
        self.conv_weights = torch.nn.ParameterDict()
        for i in range(len(self.layer_dimensions) - 1):
            in_dimension = self.layer_dimensions[i]
            out_dimension = self.layer_dimensions[i + 1]

            for _, et in self.data.edge_types.items():
                weight = init_glorot(in_dimension, out_dimension)
                # ParameterDict keys must be strings, so the
                # (row, column, layer) triple is stringified here
                self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
                    torch.nn.Parameter(weight)

        self.dec_weights = torch.nn.ParameterDict()
        for _, et in self.data.edge_types.items():
            global_interaction, local_variation = \
                et.decoder_factory(self.layer_dimensions[-1],
                    len(et.adjacency_matrices))
            # entry 0 is the global interaction matrix; entries 1.. are the
            # per-relation local variation matrices
            self.dec_weights['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
                torch.nn.ParameterList([ torch.nn.Parameter(global_interaction) ] +
                    [ torch.nn.Parameter(lv) for lv in local_variation ])
    def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor,
        rows: torch.Tensor) -> torch.Tensor:

        adj_mat = adjacency_matrix.coalesce()
        adj_mat = torch.index_select(adj_mat, 0, rows)
        adj_mat = adj_mat.coalesce()

        indices = adj_mat.indices()
        # index_select() compacts the selected rows to 0..len(rows)-1;
        # map them back to their original row indices
        indices[0] = rows[indices[0]]

        adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape)
        return adj_mat
    def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
        batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:

        col = batch.vertex_type_column
        rows = batch.edges[:, 0]

        columns = batch.edges[:, 1].sum(dim=0).flatten()
        columns = torch.nonzero(columns)

        for i in range(len(self.layer_dimensions) - 1):
            # NOTE: the body of this loop is incomplete in this commit
            columns =

    def temporary_adjacency_matrices(self, batch: TrainingBatch) -> \
        Dict[Tuple[int, int], List[torch.Tensor]]:

        col = batch.vertex_type_column
        batch.edges[:, 1]

        res = {}

        for _, et in self.data.edge_types.items():
            sum_nonzero = _nonzero_sum(et.adjacency_matrices)
            res[et.vertex_type_row, et.vertex_type_column] = \
                [ self.temporary_adjacency_matrix(adj_mat, batch,
                    et.total_connectivity) \
                    for adj_mat in et.adjacency_matrices ]

        return res
    def forward(self, initial_repr: List[torch.Tensor],
        batch: TrainingBatch) -> torch.Tensor:

        if not isinstance(initial_repr, list):
            raise TypeError('initial_repr must be a list')

        if len(initial_repr) != len(self.data.vertex_types):
            raise ValueError('initial_repr must contain representations for all vertex types')

        if not isinstance(batch, TrainingBatch):
            raise TypeError('batch must be an instance of TrainingBatch')

        adj_matrices = self.temporary_adjacency_matrices(batch)

        row_vertices = initial_repr[batch.vertex_type_row]
        column_vertices = initial_repr[batch.vertex_type_column]
        # NOTE: forward() ends here in this commit; the convolution and
        # decoding steps are not implemented yet
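
# Build-only usage sketch for the Model above (editor's illustration;
# forward() is still a stub, so only construction is exercised here, and the
# module paths and sizes are assumptions):

import torch
from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder
from triacontagon.model import Model

def conv_relu(x):            # plain functions are used because Model
    return torch.relu(x)     # checks for types.FunctionType

def dec_sigmoid(x):
    return torch.sigmoid(x)

d = Data()
d.add_vertex_type('Gene', 100)
d.add_edge_type('Interaction', 0, 0,
    [ torch.rand(100, 100).round().to_sparse() ], dedicom_decoder)

m = Model(d, layer_dimensions=[32, 16, 16], keep_prob=0.9,
    conv_activation=conv_relu, dec_activation=dec_sigmoid)
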
@@ -0,0 +1,174 @@
import torch
from typing import List, \
    Set
import time


def _equal(x: torch.Tensor, y: torch.Tensor):
    if x.is_sparse ^ y.is_sparse:
        raise ValueError('Cannot mix sparse and dense tensors')

    if not x.is_sparse:
        return (x == y)

    return ((x - y).coalesce().values() == 0)


def _sparse_coo_tensor(indices, values, size):
    # pick the legacy sparse constructor matching the values' dtype
    # (torch.bool falls back to ByteTensor)
    ctor = { torch.float32: torch.sparse.FloatTensor,
        torch.float64: torch.sparse.DoubleTensor,
        torch.uint8: torch.sparse.ByteTensor,
        torch.long: torch.sparse.LongTensor,
        torch.int: torch.sparse.IntTensor,
        torch.short: torch.sparse.ShortTensor,
        torch.bool: torch.sparse.ByteTensor }[values.dtype]
    return ctor(indices, values, size)
def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
    if len(adjacency_matrices) == 0:
        raise ValueError('adjacency_matrices must be non-empty')

    if not all([x.is_sparse for x in adjacency_matrices]):
        raise ValueError('All adjacency matrices must be sparse')

    indices = [ x.indices() for x in adjacency_matrices ]
    indices = torch.cat(indices, dim=1)

    values = torch.ones(indices.shape[1])
    res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
    res = res.coalesce()

    indices = res.indices()
    res = _sparse_coo_tensor(indices,
        torch.ones(indices.shape[1], dtype=torch.uint8),
        adjacency_matrices[0].shape)

    return res
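
# _nonzero_sum() returns the union of the sparsity patterns of its inputs as
# a 0/1 sparse uint8 matrix (editor's illustration):
#
#     a = torch.tensor([[1., 0.], [0., 0.]]).to_sparse()
#     b = torch.tensor([[1., 0.], [0., 2.]]).to_sparse()
#     _nonzero_sum([ a, b ]).to_dense()
#     # tensor([[1, 0],
#     #         [0, 1]], dtype=torch.uint8)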
def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
    rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:

    if not adjacency_matrix.is_sparse:
        raise ValueError('adjacency_matrix must be sparse')

    if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
        raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')

    # replicate the requested row indices once per relation-type block
    t = time.time()
    rows = [ rows + row_vertex_count * i \
        for i in range(num_relation_types) ]
    print('rows took:', time.time() - t)

    t = time.time()
    rows = torch.cat(rows)
    print('cat took:', time.time() - t)

    rows = set(rows.tolist())

    # keep only the nonzero entries whose row index is in the kept set
    t = time.time()
    adj_mat = adjacency_matrix.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()
    selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
    selection = torch.nonzero(selection, as_tuple=True)[0]
    indices = indices[:, selection]
    values = values[selection]
    print('"index_select()" took:', time.time() - t)

    t = time.time()
    res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
    print('_sparse_coo_tensor() took:', time.time() - t)

    return res

    # An earlier index_select()-based variant, kept for reference:
    #
    # t = time.time()
    # adj_mat = torch.index_select(adjacency_matrix, 0, rows)
    # print('index_select took:', time.time() - t)
    #
    # t = time.time()
    # adj_mat = adj_mat.coalesce()
    # print('coalesce() took:', time.time() - t)
    # indices = adj_mat.indices()
    # values = adj_mat.values()
    #
    # t = time.time()
    # indices[0] = rows[indices[0]]
    # print('Lookup took:', time.time() - t)
    #
    # t = time.time()
    # adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
    # print('_sparse_coo_tensor() took:', time.time() - t)
    # return adj_mat
def _sparse_diag_cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('The list of matrices must be non-empty')

    if not all(m.is_sparse for m in matrices):
        raise ValueError('All matrices must be sparse')

    if not all(len(m.shape) == 2 for m in matrices):
        raise ValueError('All matrices must be 2D')

    indices = []
    values = []
    row_offset = 0
    col_offset = 0

    for m in matrices:
        ind = m._indices().clone()
        ind[0] += row_offset
        ind[1] += col_offset
        indices.append(ind)
        values.append(m._values())
        row_offset += m.shape[0]
        col_offset += m.shape[1]

    indices = torch.cat(indices, dim=1)
    values = torch.cat(values)

    return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
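
# _sparse_diag_cat() stacks 2D sparse matrices into one block-diagonal
# matrix, which is how per-relation adjacency matrices are combined in the
# tests below (editor's illustration):
#
#     a = torch.eye(2, dtype=torch.uint8).to_sparse()
#     m = _sparse_diag_cat([ a, a ])
#     m.shape   # torch.Size([4, 4]), identity blocks on the diagonal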
def _cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('Empty list passed to _cat()')

    n = sum(a.is_sparse for a in matrices)
    if n != 0 and n != len(matrices):
        raise ValueError('All matrices must have the same layout (dense or sparse)')

    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
        raise ValueError('All matrices must have the same dimensions apart from dimension 0')

    if not matrices[0].is_sparse:
        return torch.cat(matrices)

    total_rows = sum(a.shape[0] for a in matrices)

    indices = []
    values = []
    row_offset = 0

    for a in matrices:
        ind = a._indices().clone()
        val = a._values()
        ind[0] += row_offset
        ind = ind.transpose(0, 1)
        indices.append(ind)
        values.append(val)
        row_offset += a.shape[0]

    indices = torch.cat(indices).transpose(0, 1)
    values = torch.cat(values)

    res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
    return res
@@ -0,0 +1,95 @@
from triacontagon.util import \
    _clear_adjacency_matrix_except_rows, \
    _sparse_diag_cat, \
    _equal
import torch
import time


def test_clear_adjacency_matrix_except_rows_01():
    adj_mat = torch.tensor([
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1],
        [1, 0, 1, 0, 0],
        [1, 1, 0, 0, 0]
    ], dtype=torch.uint8).to_sparse()

    adj_mat = _sparse_diag_cat([ adj_mat, adj_mat ])

    res = _clear_adjacency_matrix_except_rows(adj_mat,
        torch.tensor([1, 3]), 4, 2)
    res = res.to_dense()

    truth = torch.tensor([
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 0, 0, 0]
    ], dtype=torch.uint8)
    print('res:', res)

    assert torch.all(res == truth)
def test_clear_adjacency_matrix_except_rows_02():
    adj_mat = torch.rand(6, 10).round().to(torch.uint8)

    t = time.time()
    res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
    print('_sparse_diag_cat() took:', time.time() - t)

    t = time.time()
    res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
        6, 130)
    print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)

    adj_mat[0] = adj_mat[2] = adj_mat[4] = \
        torch.zeros(10)
    truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)

    assert _equal(res, truth).all()


def test_clear_adjacency_matrix_except_rows_03():
    adj_mat = torch.rand(6, 10).round().to(torch.uint8)

    t = time.time()
    res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
    print('_sparse_diag_cat() took:', time.time() - t)

    t = time.time()
    res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
        6, 1300)
    print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)

    adj_mat[0] = adj_mat[2] = adj_mat[4] = \
        torch.zeros(10)
    truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)

    assert _equal(res, truth).all()
def test_clear_adjacency_matrix_except_rows_04():
    adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8)

    t = time.time()
    res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
    print('_sparse_diag_cat() took:', time.time() - t)

    t = time.time()
    res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
        2000, 1300)
    print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)

    adj_mat[0] = adj_mat[2] = adj_mat[4] = \
        torch.zeros(2000)
    adj_mat[6:] = torch.zeros(2000)
    truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)

    assert _equal(res, truth).all()
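
# The tests above can be run with pytest; the file path is an assumption:
#
#     pytest -s tests/triacontagon/test_util.py
#
# (-s keeps the timing print() output visible.)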