@@ -4,206 +4,68 @@ | |||||
# | # | ||||
from collections import defaultdict | |||||
from dataclasses import dataclass, field | |||||
import torch | |||||
from typing import List, \ | |||||
Dict, \ | |||||
from dataclasses import dataclass | |||||
from typing import Callable, \ | |||||
Tuple, \ | Tuple, \ | ||||
Any, \ | |||||
Type | |||||
from .decode import DEDICOMDecoder, \ | |||||
BilinearDecoder | |||||
import numpy as np | |||||
def _equal(x: torch.Tensor, y: torch.Tensor): | |||||
if x.is_sparse ^ y.is_sparse: | |||||
raise ValueError('Cannot mix sparse and dense tensors') | |||||
if not x.is_sparse: | |||||
return (x == y) | |||||
return ((x - y).coalesce().values() == 0) | |||||
List | |||||
import types | |||||
from .util import _nonzero_sum | |||||
@dataclass | @dataclass | ||||
class NodeType(object): | |||||
name: str | |||||
count: int | |||||
class DecodingMatrices(object): | |||||
global_interaction: torch.Tensor | |||||
local_variation: torch.Tensor | |||||
@dataclass | @dataclass | ||||
class RelationTypeBase(object): | |||||
class VertexType(object): | |||||
name: str | name: str | ||||
node_type_row: int | |||||
node_type_column: int | |||||
adjacency_matrix: torch.Tensor | |||||
adjacency_matrix_backward: torch.Tensor | |||||
@dataclass | |||||
class RelationType(RelationTypeBase): | |||||
pass | |||||
count: int | |||||
@dataclass | @dataclass | ||||
class RelationFamilyBase(object): | |||||
data: 'Data' | |||||
class EdgeType(object): | |||||
name: str | name: str | ||||
node_type_row: int | |||||
node_type_column: int | |||||
is_symmetric: bool | |||||
decoder_class: Type | |||||
@dataclass | |||||
class RelationFamily(RelationFamilyBase): | |||||
relation_types: List[RelationType] = None | |||||
def __post_init__(self) -> None: | |||||
if not self.is_symmetric and \ | |||||
self.decoder_class != DEDICOMDecoder and \ | |||||
self.decoder_class != BilinearDecoder: | |||||
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||||
self.relation_types = [] | |||||
def add_relation_type(self, | |||||
name: str, adjacency_matrix: torch.Tensor, | |||||
adjacency_matrix_backward: torch.Tensor = None) -> None: | |||||
name = str(name) | |||||
node_type_row = self.node_type_row | |||||
node_type_column = self.node_type_column | |||||
if adjacency_matrix is None and adjacency_matrix_backward is None: | |||||
raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') | |||||
if adjacency_matrix is not None and \ | |||||
not isinstance(adjacency_matrix, torch.Tensor): | |||||
raise ValueError('adjacency_matrix must be a torch.Tensor') | |||||
if adjacency_matrix_backward is not None \ | |||||
and not isinstance(adjacency_matrix_backward, torch.Tensor): | |||||
raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | |||||
if adjacency_matrix is not None and \ | |||||
adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||||
self.data.node_types[node_type_column].count): | |||||
raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') | |||||
if adjacency_matrix_backward is not None and \ | |||||
adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, | |||||
self.data.node_types[node_type_row].count): | |||||
raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)') | |||||
if node_type_row == node_type_column and \ | |||||
adjacency_matrix_backward is not None: | |||||
raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') | |||||
if self.is_symmetric and adjacency_matrix_backward is not None: | |||||
raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family') | |||||
if self.is_symmetric and node_type_row == node_type_column and \ | |||||
not torch.all(_equal(adjacency_matrix, | |||||
adjacency_matrix.transpose(0, 1))): | |||||
raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') | |||||
if not self.is_symmetric and node_type_row != node_type_column and \ | |||||
adjacency_matrix_backward is None: | |||||
raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None') | |||||
if self.is_symmetric and node_type_row != node_type_column: | |||||
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||||
self.relation_types.append(RelationType(name, | |||||
node_type_row, node_type_column, | |||||
adjacency_matrix, adjacency_matrix_backward)) | |||||
def node_name(self, index): | |||||
return self.data.node_types[index].name | |||||
def __repr__(self): | |||||
s = 'Relation family %s' % self.name | |||||
for r in self.relation_types: | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||||
if (r.adjacency_matrix is not None \ | |||||
and r.adjacency_matrix_backward is not None) \ | |||||
or self.node_type_row == self.node_type_column \ | |||||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||||
self.node_name(self.node_type_column))) | |||||
return s | |||||
def repr_indented(self): | |||||
s = ' - %s' % self.name | |||||
for r in self.relation_types: | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||||
if (r.adjacency_matrix is not None \ | |||||
and r.adjacency_matrix_backward is not None) \ | |||||
or self.node_type_row == self.node_type_column \ | |||||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||||
self.node_name(self.node_type_column))) | |||||
return s | |||||
vertex_type_row: int | |||||
vertex_type_column: int | |||||
adjacency_matrices: List[torch.Tensor] | |||||
decoder_factory: Callable[[], DecodingMatrices] | |||||
total_connectivity: torch.Tensor | |||||
class Data(object): | class Data(object): | ||||
node_types: List[NodeType] | |||||
relation_families: List[RelationFamily] | |||||
vertex_types: List[VertexType] | |||||
edge_types: List[EdgeType] | |||||
def __init__(self) -> None: | def __init__(self) -> None: | ||||
self.node_types = [] | |||||
self.relation_families = [] | |||||
self.vertex_types = [] | |||||
self.edge_types = {} | |||||
def add_node_type(self, name: str, count: int) -> None: | |||||
def add_vertex_type(self, name: str, count: int) -> None: | |||||
name = str(name) | name = str(name) | ||||
count = int(count) | count = int(count) | ||||
if not name: | if not name: | ||||
raise ValueError('You must provide a non-empty node type name') | |||||
raise ValueError('You must provide a non-empty vertex type name') | |||||
if count <= 0: | if count <= 0: | ||||
raise ValueError('You must provide a positive node count') | |||||
self.node_types.append(NodeType(name, count)) | |||||
raise ValueError('You must provide a positive vertex count') | |||||
self.vertex_types.append(VertexType(name, count)) | |||||
def add_relation_family(self, name: str, node_type_row: int, | |||||
node_type_column: int, is_symmetric: bool, | |||||
decoder_class: Type = DEDICOMDecoder): | |||||
def add_edge_type(self, name: str, | |||||
vertex_type_row: int, vertex_type_column: int, | |||||
adjacency_matrices: List[torch.Tensor], | |||||
decoder_factory: Callable[[], DecodingMatrices]) -> None: | |||||
name = str(name) | name = str(name) | ||||
node_type_row = int(node_type_row) | |||||
node_type_column = int(node_type_column) | |||||
is_symmetric = bool(is_symmetric) | |||||
if node_type_row < 0 or node_type_row >= len(self.node_types): | |||||
raise ValueError('node_type_row outside of the valid range of node types') | |||||
if node_type_column < 0 or node_type_column >= len(self.node_types): | |||||
raise ValueError('node_type_column outside of the valid range of node types') | |||||
fam = RelationFamily(self, name, node_type_row, node_type_column, | |||||
is_symmetric, decoder_class) | |||||
self.relation_families.append(fam) | |||||
return fam | |||||
def __repr__(self): | |||||
n = len(self.node_types) | |||||
if n == 0: | |||||
return 'Empty Icosagon Data' | |||||
s = '' | |||||
s += 'Icosagon Data with:\n' | |||||
s += '- ' + str(n) + ' node type(s):\n' | |||||
for nt in self.node_types: | |||||
s += ' - ' + nt.name + '\n' | |||||
if len(self.relation_families) == 0: | |||||
s += '- No relation families\n' | |||||
return s.strip() | |||||
s += '- %d relation families:\n' % len(self.relation_families) | |||||
for fam in self.relation_families: | |||||
s += fam.repr_indented() + '\n' | |||||
return s.strip() | |||||
vertex_type_row = int(vertex_type_row) | |||||
vertex_type_column = int(vertex_type_column) | |||||
if not isinstance(adjacency_matrices, list): | |||||
raise TypeError('adjacency_matrices must be a list of tensors') | |||||
if not isinstance(decoder_factory, types.FunctionType): | |||||
raise TypeError('decoder_factory must be a function') | |||||
if (vertex_type_row, vertex_type_column) in self.edge_types: | |||||
raise KeyError('Edge type for given combination of row and column already exists') | |||||
total_connectivity = _nonzero_sum(adjacency_matrices) | |||||
self.edges_types[vertex_type_row, vertex_type_column] = \ | |||||
VertexType(name, vertex_type_row, vertex_type_column, | |||||
adjacency_matrices, decoder_factory, total_connectivity) |
@@ -7,117 +7,47 @@ | |||||
import torch | import torch | ||||
from .weights import init_glorot | from .weights import init_glorot | ||||
from .dropout import dropout | from .dropout import dropout | ||||
from typing import Tuple, \ | |||||
List | |||||
class DEDICOMDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, keep_prob=1., | |||||
activation=torch.sigmoid, **kwargs): | |||||
def dedicom_decoder(input_dim: int, num_relation_types: int) -> | |||||
Tuple[torch.Tensor, List[torch.Tensor]]: | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
global_interaction = init_glorot(input_dim, input_dim) | |||||
local_variation = [ | |||||
torch.diag(torch.flatten(init_glorot(input_dim, 1))) \ | |||||
for _ in range(num_relation_types) | |||||
] | |||||
return (global_interaction, local_variation) | |||||
self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim)) | |||||
self.local_variation = torch.nn.ParameterList([ | |||||
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ | |||||
for _ in range(num_relation_types) | |||||
]) | |||||
def forward(self, inputs_row, inputs_col, relation_index): | |||||
inputs_row = dropout(inputs_row, self.keep_prob) | |||||
inputs_col = dropout(inputs_col, self.keep_prob) | |||||
def dist_mult_decoder(input_dim: int, num_relation_types: int) -> | |||||
Tuple[torch.Tensor, List[torch.Tensor]]: | |||||
relation = torch.diag(self.local_variation[relation_index]) | |||||
global_interaction = torch.eye(input_dim, input_dim) | |||||
local_variation = [ | |||||
torch.diag(torch.flatten(init_glorot(input_dim, 1)))) \ | |||||
for _ in range(num_relation_types) | |||||
] | |||||
return (global_interaction, local_variation) | |||||
product1 = torch.mm(inputs_row, relation) | |||||
product2 = torch.mm(product1, self.global_interaction) | |||||
product3 = torch.mm(product2, relation) | |||||
rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]), | |||||
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
rec = torch.flatten(rec) | |||||
return self.activation(rec) | |||||
def bilinear_decoder(input_dim: int, num_relation_types: int) -> | |||||
Tuple[torch.Tensor, List[torch.Tensor]]: | |||||
global_interaction = torch.eye(input_dim, input_dim) | |||||
local_variation = [ | |||||
init_glorot(input_dim, input_dim) \ | |||||
for _ in range(num_relation_types) | |||||
] | |||||
return (global_interaction, local_variation) | |||||
class DistMultDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, keep_prob=1., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
def inner_product_decoder(input_dim: int, num_relation_types: int) -> | |||||
Tuple[torch.Tensor, List[torch.Tensor]]: | |||||
self.relation = torch.nn.ParameterList([ | |||||
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ | |||||
for _ in range(num_relation_types) | |||||
]) | |||||
def forward(self, inputs_row, inputs_col, relation_index): | |||||
inputs_row = dropout(inputs_row, self.keep_prob) | |||||
inputs_col = dropout(inputs_col, self.keep_prob) | |||||
relation = torch.diag(self.relation[relation_index]) | |||||
intermediate_product = torch.mm(inputs_row, relation) | |||||
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), | |||||
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
rec = torch.flatten(rec) | |||||
return self.activation(rec) | |||||
class BilinearDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, keep_prob=1., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.relation = torch.nn.ParameterList([ | |||||
torch.nn.Parameter(init_glorot(input_dim, input_dim)) \ | |||||
for _ in range(num_relation_types) | |||||
]) | |||||
def forward(self, inputs_row, inputs_col, relation_index): | |||||
inputs_row = dropout(inputs_row, self.keep_prob) | |||||
inputs_col = dropout(inputs_col, self.keep_prob) | |||||
intermediate_product = torch.mm(inputs_row, self.relation[relation_index]) | |||||
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), | |||||
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
rec = torch.flatten(rec) | |||||
return self.activation(rec) | |||||
class InnerProductDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, keep_prob=1., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
def forward(self, inputs_row, inputs_col, _): | |||||
inputs_row = dropout(inputs_row, self.keep_prob) | |||||
inputs_col = dropout(inputs_col, self.keep_prob) | |||||
rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]), | |||||
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
rec = torch.flatten(rec) | |||||
return self.activation(rec) | |||||
global_interaction = torch.eye(input_dim, input_dim) | |||||
local_variation = torch.eye(input_dim, input_dim) | |||||
local_variation = [ local_variation ] * num_relation_types | |||||
return (global_interaction, local_variation) |
@@ -0,0 +1,129 @@ | |||||
from .data import Data, \ | |||||
EdgeType | |||||
import torch | |||||
from dataclasses import dataclass | |||||
from .weights import init_glorot | |||||
import types | |||||
from typing import List, \ | |||||
Dict, \ | |||||
Callable | |||||
from .util import _sparse_coo_tensor | |||||
@dataclass | |||||
class TrainingBatch(object): | |||||
vertex_type_row: int | |||||
vertex_type_column: int | |||||
relation_type_index: int | |||||
edges: torch.Tensor | |||||
class Model(torch.nn.Module): | |||||
def __init__(self, data: Data, layer_dimensions: List[int], | |||||
keep_prob: float, | |||||
conv_activation: Callable[[torch.Tensor], torch.Tensor], | |||||
dec_activation: Callable[[torch.Tensor], torch.Tensor], | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
if not isinstance(data, Data): | |||||
raise TypeError('data must be an instance of Data') | |||||
if not isinstance(conv_activation, types.FunctionType): | |||||
raise TypeError('conv_activation must be a function') | |||||
if not isinstance(dec_activation, types.FunctionType): | |||||
raise TypeError('dec_activation must be a function') | |||||
self.data = data | |||||
self.layer_dimensions = list(layer_dimensions) | |||||
self.keep_prob = float(keep_prob) | |||||
self.conv_activation = conv_activation | |||||
self.dec_activation = dec_activation | |||||
self.conv_weights = None | |||||
self.dec_weights = None | |||||
self.build() | |||||
def build(self) -> None: | |||||
self.conv_weights = torch.nn.ParameterDict() | |||||
for i in range(len(self.layer_dimensions) - 1): | |||||
in_dimension = self.layer_dimensions[i] | |||||
out_dimension = self.layer_dimensions[i + 1] | |||||
for _, et in self.data.edge_types.items(): | |||||
weight = init_glorot(in_dimension, out_dimension) | |||||
self.conv_weights[et.vertex_type_row, et.vertex_type_column, i] = \ | |||||
torch.nn.Parameter(weight) | |||||
self.dec_weights = torch.nn.ParameterDict() | |||||
for _, et in self.data.edge_types.items(): | |||||
global_interaction, local_variation = \ | |||||
et.decoder_factory(self.layer_dimensions[-1], | |||||
len(et.adjacency_matrices)) | |||||
self.dec_weights[et.vertex_type_row, et.vertex_type_column] = \ | |||||
torch.nn.ParameterList([ | |||||
torch.nn.Parameter(global_interaction), | |||||
torch.nn.Parameter(local_variation) | |||||
]) | |||||
def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor, | |||||
rows: torch.Tensor) -> torch.Tensor: | |||||
adj_mat = adjacency_matrix.coalesce() | |||||
adj_mat = torch.index_select(adj_mat, 0, rows) | |||||
adj_mat = adj_mat.coalesce() | |||||
indices = adj_mat.indices() | |||||
indices[0] = rows | |||||
adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape) | |||||
def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor, | |||||
batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor: | |||||
col = batch.vertex_type_column | |||||
rows = batch.edges[:, 0] | |||||
columns = batch.edges[:, 1].sum(dim=0).flatten() | |||||
columns = torch.nonzero(columns) | |||||
for i in range(len(self.layer_dimensions) - 1): | |||||
columns = | |||||
def temporary_adjacency_matrices(self, batch: TrainingBatch) -> | |||||
Dict[Tuple[int, int], List[List[torch.Tensor]]]: | |||||
col = batch.vertex_type_column | |||||
batch.edges[:, 1] | |||||
res = {} | |||||
for _, et in self.data.edge_types.items(): | |||||
sum_nonzero = _nonzero_sum(et.adjacency_matrices) | |||||
res[et.vertex_type_row, et.vertex_type_column] = \ | |||||
[ self.temporary_adjacency_matrix(adj_mat, batch, | |||||
et.total_connectivity) \ | |||||
for adj_mat in et.adjacency_matrices ] | |||||
return res | |||||
def forward(self, initial_repr: List[torch.Tensor], | |||||
batch: TrainingBatch) -> torch.Tensor: | |||||
if not isinstance(initial_repr, list): | |||||
raise TypeError('initial_repr must be a list') | |||||
if len(initial_repr) != len(self.data.vertex_types): | |||||
raise ValueError('initial_repr must contain representations for all vertex types') | |||||
if not isinstance(batch, TrainingBatch): | |||||
raise TypeError('batch must be an instance of TrainingBatch') | |||||
adj_matrices = self.temporary_adjacency_matrices(batch) | |||||
row_vertices = initial_repr[batch.vertex_type_row] | |||||
column_vertices = initial_repr[batch.vertex_type_column] |
@@ -0,0 +1,174 @@ | |||||
import torch | |||||
from typing import List, \ | |||||
Set | |||||
import time | |||||
def _equal(x: torch.Tensor, y: torch.Tensor): | |||||
if x.is_sparse ^ y.is_sparse: | |||||
raise ValueError('Cannot mix sparse and dense tensors') | |||||
if not x.is_sparse: | |||||
return (x == y) | |||||
return ((x - y).coalesce().values() == 0) | |||||
def _sparse_coo_tensor(indices, values, size): | |||||
ctor = { torch.float32: torch.sparse.FloatTensor, | |||||
torch.float32: torch.sparse.DoubleTensor, | |||||
torch.uint8: torch.sparse.ByteTensor, | |||||
torch.long: torch.sparse.LongTensor, | |||||
torch.int: torch.sparse.IntTensor, | |||||
torch.short: torch.sparse.ShortTensor, | |||||
torch.bool: torch.sparse.ByteTensor }[values.dtype] | |||||
return ctor(indices, values, size) | |||||
def _nonzero_sum(adjacency_matrices: List[torch.Tensor]): | |||||
if len(adjacency_matrices) == 0: | |||||
raise ValueError('adjacency_matrices must be non-empty') | |||||
if not all([x.is_sparse for x in adjacency_matrices]): | |||||
raise ValueError('All adjacency matrices must be sparse') | |||||
indices = [ x.indices() for x in adjacency_matrices ] | |||||
indices = torch.cat(indices, dim=1) | |||||
values = torch.ones(indices.shape[1]) | |||||
res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape) | |||||
res = res.coalesce() | |||||
indices = res.indices() | |||||
res = _sparse_coo_tensor(indices, | |||||
torch.ones(indices.shape[1], dtype=torch.uint8)) | |||||
return res | |||||
def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor, | |||||
rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor: | |||||
if not adjacency_matrix.is_sparse: | |||||
raise ValueError('adjacency_matrix must be sparse') | |||||
if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types: | |||||
raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types') | |||||
t = time.time() | |||||
rows = [ rows + row_vertex_count * i \ | |||||
for i in range(num_relation_types) ] | |||||
print('rows took:', time.time() - t) | |||||
t = time.time() | |||||
rows = torch.cat(rows) | |||||
print('cat took:', time.time() - t) | |||||
# print('rows:', rows) | |||||
rows = set(rows.tolist()) | |||||
# print('rows:', rows) | |||||
t = time.time() | |||||
adj_mat = adjacency_matrix.coalesce() | |||||
indices = adj_mat.indices() | |||||
values = adj_mat.values() | |||||
print('indices[0]:', indices[0]) | |||||
print('indices[0][1]:', indices[0][1], indices[0][1] in rows) | |||||
selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ]) | |||||
# print('selection:', selection) | |||||
selection = torch.nonzero(selection, as_tuple=True)[0] | |||||
# print('selection:', selection) | |||||
indices = indices[:, selection] | |||||
values = values[selection] | |||||
print('"index_select()" took:', time.time() - t) | |||||
t = time.time() | |||||
res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape) | |||||
print('_sparse_coo_tensor() took:', time.time() - t) | |||||
return res | |||||
# t = time.time() | |||||
# adj_mat = torch.index_select(adjacency_matrix, 0, rows) | |||||
# print('index_select took:', time.time() - t) | |||||
t = time.time() | |||||
adj_mat = adj_mat.coalesce() | |||||
print('coalesce() took:', time.time() - t) | |||||
indices = adj_mat.indices() | |||||
# print('indices:', indices) | |||||
values = adj_mat.values() | |||||
t = time.time() | |||||
indices[0] = rows[indices[0]] | |||||
print('Lookup took:', time.time() - t) | |||||
t = time.time() | |||||
adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape) | |||||
print('_sparse_coo_tensor() took:', time.time() - t) | |||||
return adj_mat | |||||
def _sparse_diag_cat(matrices: List[torch.Tensor]): | |||||
if len(matrices) == 0: | |||||
raise ValueError('The list of matrices must be non-empty') | |||||
if not all(m.is_sparse for m in matrices): | |||||
raise ValueError('All matrices must be sparse') | |||||
if not all(len(m.shape) == 2 for m in matrices): | |||||
raise ValueError('All matrices must be 2D') | |||||
indices = [] | |||||
values = [] | |||||
row_offset = 0 | |||||
col_offset = 0 | |||||
for m in matrices: | |||||
ind = m._indices().clone() | |||||
ind[0] += row_offset | |||||
ind[1] += col_offset | |||||
indices.append(ind) | |||||
values.append(m._values()) | |||||
row_offset += m.shape[0] | |||||
col_offset += m.shape[1] | |||||
indices = torch.cat(indices, dim=1) | |||||
values = torch.cat(values) | |||||
return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) | |||||
def _cat(matrices: List[torch.Tensor]): | |||||
if len(matrices) == 0: | |||||
raise ValueError('Empty list passed to _cat()') | |||||
n = sum(a.is_sparse for a in matrices) | |||||
if n != 0 and n != len(matrices): | |||||
raise ValueError('All matrices must have the same layout (dense or sparse)') | |||||
if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): | |||||
raise ValueError('All matrices must have the same dimensions apart from dimension 0') | |||||
if not matrices[0].is_sparse: | |||||
return torch.cat(matrices) | |||||
total_rows = sum(a.shape[0] for a in matrices) | |||||
indices = [] | |||||
values = [] | |||||
row_offset = 0 | |||||
for a in matrices: | |||||
ind = a._indices().clone() | |||||
val = a._values() | |||||
ind[0] += row_offset | |||||
ind = ind.transpose(0, 1) | |||||
indices.append(ind) | |||||
values.append(val) | |||||
row_offset += a.shape[0] | |||||
indices = torch.cat(indices).transpose(0, 1) | |||||
values = torch.cat(values) | |||||
res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) | |||||
return res |
@@ -0,0 +1,95 @@ | |||||
from triacontagon.util import \ | |||||
_clear_adjacency_matrix_except_rows, \ | |||||
_sparse_diag_cat, \ | |||||
_equal | |||||
import torch | |||||
import time | |||||
def test_clear_adjacency_matrix_except_rows_01(): | |||||
adj_mat = torch.tensor([ | |||||
[0, 0, 1, 0, 0], | |||||
[0, 0, 0, 1, 1], | |||||
[1, 0, 1, 0, 0], | |||||
[1, 1, 0, 0, 0] | |||||
], dtype=torch.uint8).to_sparse() | |||||
adj_mat = _sparse_diag_cat([ adj_mat, adj_mat ]) | |||||
res = _clear_adjacency_matrix_except_rows(adj_mat, | |||||
torch.tensor([1, 3]), 4, 2) | |||||
res = res.to_dense() | |||||
truth = torch.tensor([ | |||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
[0, 0, 0, 1, 1, 0, 0, 0, 0, 0], | |||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1], | |||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |||||
[0, 0, 0, 0, 0, 1, 1, 0, 0, 0] | |||||
], dtype=torch.uint8) | |||||
print('res:', res) | |||||
assert torch.all(res == truth) | |||||
def test_clear_adjacency_matrix_except_rows_02(): | |||||
adj_mat = torch.rand(6, 10).round().to(torch.uint8) | |||||
t = time.time() | |||||
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130) | |||||
print('_sparse_diag_cat() took:', time.time() - t) | |||||
t = time.time() | |||||
res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]), | |||||
6, 130) | |||||
print('_clear_adjacency_matrix_except_rows() took:', time.time() - t) | |||||
adj_mat[0] = adj_mat[2] = adj_mat[4] = \ | |||||
torch.zeros(10) | |||||
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130) | |||||
assert _equal(res, truth).all() | |||||
def test_clear_adjacency_matrix_except_rows_03(): | |||||
adj_mat = torch.rand(6, 10).round().to(torch.uint8) | |||||
t = time.time() | |||||
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | |||||
print('_sparse_diag_cat() took:', time.time() - t) | |||||
t = time.time() | |||||
res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]), | |||||
6, 1300) | |||||
print('_clear_adjacency_matrix_except_rows() took:', time.time() - t) | |||||
adj_mat[0] = adj_mat[2] = adj_mat[4] = \ | |||||
torch.zeros(10) | |||||
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | |||||
assert _equal(res, truth).all() | |||||
def test_clear_adjacency_matrix_except_rows_04(): | |||||
adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8) | |||||
t = time.time() | |||||
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | |||||
print('_sparse_diag_cat() took:', time.time() - t) | |||||
t = time.time() | |||||
res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]), | |||||
2000, 1300) | |||||
print('_clear_adjacency_matrix_except_rows() took:', time.time() - t) | |||||
adj_mat[0] = adj_mat[2] = adj_mat[4] = \ | |||||
torch.zeros(2000) | |||||
adj_mat[6:] = torch.zeros(2000) | |||||
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | |||||
assert _equal(res, truth).all() |