IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Преглед изворни кода

Work on triacontagon.

master
Stanislaw Adaszewski пре 4 година
родитељ
комит
4f953fd203
6 измењених фајлова са 470 додато и 280 уклоњено
  1. +0
    -0
      src/triacontagon/__init__.py
  2. +40
    -178
      src/triacontagon/data.py
  3. +32
    -102
      src/triacontagon/decode.py
  4. +129
    -0
      src/triacontagon/model.py
  5. +174
    -0
      src/triacontagon/util.py
  6. +95
    -0
      tests/triacontagon/test_util.py

+ 0
- 0
src/triacontagon/__init__.py Прегледај датотеку


+ 40
- 178
src/triacontagon/data.py Прегледај датотеку

@@ -4,206 +4,68 @@
#
from collections import defaultdict
from dataclasses import dataclass, field
import torch
from typing import List, \
Dict, \
from dataclasses import dataclass
from typing import Callable, \
Tuple, \
Any, \
Type
from .decode import DEDICOMDecoder, \
BilinearDecoder
import numpy as np
def _equal(x: torch.Tensor, y: torch.Tensor):
if x.is_sparse ^ y.is_sparse:
raise ValueError('Cannot mix sparse and dense tensors')
if not x.is_sparse:
return (x == y)
return ((x - y).coalesce().values() == 0)
List
import types
from .util import _nonzero_sum
@dataclass
class NodeType(object):
name: str
count: int
class DecodingMatrices(object):
global_interaction: torch.Tensor
local_variation: torch.Tensor
@dataclass
class RelationTypeBase(object):
class VertexType(object):
name: str
node_type_row: int
node_type_column: int
adjacency_matrix: torch.Tensor
adjacency_matrix_backward: torch.Tensor
@dataclass
class RelationType(RelationTypeBase):
pass
count: int
@dataclass
class RelationFamilyBase(object):
data: 'Data'
class EdgeType(object):
name: str
node_type_row: int
node_type_column: int
is_symmetric: bool
decoder_class: Type
@dataclass
class RelationFamily(RelationFamilyBase):
relation_types: List[RelationType] = None
def __post_init__(self) -> None:
if not self.is_symmetric and \
self.decoder_class != DEDICOMDecoder and \
self.decoder_class != BilinearDecoder:
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
self.relation_types = []
def add_relation_type(self,
name: str, adjacency_matrix: torch.Tensor,
adjacency_matrix_backward: torch.Tensor = None) -> None:
name = str(name)
node_type_row = self.node_type_row
node_type_column = self.node_type_column
if adjacency_matrix is None and adjacency_matrix_backward is None:
raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')
if adjacency_matrix is not None and \
not isinstance(adjacency_matrix, torch.Tensor):
raise ValueError('adjacency_matrix must be a torch.Tensor')
if adjacency_matrix_backward is not None \
and not isinstance(adjacency_matrix_backward, torch.Tensor):
raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
if adjacency_matrix is not None and \
adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
self.data.node_types[node_type_column].count):
raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
if adjacency_matrix_backward is not None and \
adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
self.data.node_types[node_type_row].count):
raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')
if node_type_row == node_type_column and \
adjacency_matrix_backward is not None:
raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
if self.is_symmetric and adjacency_matrix_backward is not None:
raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')
if self.is_symmetric and node_type_row == node_type_column and \
not torch.all(_equal(adjacency_matrix,
adjacency_matrix.transpose(0, 1))):
raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')
if not self.is_symmetric and node_type_row != node_type_column and \
adjacency_matrix_backward is None:
raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None')
if self.is_symmetric and node_type_row != node_type_column:
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
self.relation_types.append(RelationType(name,
node_type_row, node_type_column,
adjacency_matrix, adjacency_matrix_backward))
def node_name(self, index):
return self.data.node_types[index].name
def __repr__(self):
s = 'Relation family %s' % self.name
for r in self.relation_types:
s += '\n - %s%s' % (r.name, ' (two-way)' \
if (r.adjacency_matrix is not None \
and r.adjacency_matrix_backward is not None) \
or self.node_type_row == self.node_type_column \
else '%s <- %s' % (self.node_name(self.node_type_row),
self.node_name(self.node_type_column)))
return s
def repr_indented(self):
s = ' - %s' % self.name
for r in self.relation_types:
s += '\n - %s%s' % (r.name, ' (two-way)' \
if (r.adjacency_matrix is not None \
and r.adjacency_matrix_backward is not None) \
or self.node_type_row == self.node_type_column \
else '%s <- %s' % (self.node_name(self.node_type_row),
self.node_name(self.node_type_column)))
return s
vertex_type_row: int
vertex_type_column: int
adjacency_matrices: List[torch.Tensor]
decoder_factory: Callable[[], DecodingMatrices]
total_connectivity: torch.Tensor
class Data(object):
node_types: List[NodeType]
relation_families: List[RelationFamily]
vertex_types: List[VertexType]
edge_types: List[EdgeType]
def __init__(self) -> None:
self.node_types = []
self.relation_families = []
self.vertex_types = []
self.edge_types = {}
def add_node_type(self, name: str, count: int) -> None:
def add_vertex_type(self, name: str, count: int) -> None:
name = str(name)
count = int(count)
if not name:
raise ValueError('You must provide a non-empty node type name')
raise ValueError('You must provide a non-empty vertex type name')
if count <= 0:
raise ValueError('You must provide a positive node count')
self.node_types.append(NodeType(name, count))
raise ValueError('You must provide a positive vertex count')
self.vertex_types.append(VertexType(name, count))
def add_relation_family(self, name: str, node_type_row: int,
node_type_column: int, is_symmetric: bool,
decoder_class: Type = DEDICOMDecoder):
def add_edge_type(self, name: str,
vertex_type_row: int, vertex_type_column: int,
adjacency_matrices: List[torch.Tensor],
decoder_factory: Callable[[], DecodingMatrices]) -> None:
name = str(name)
node_type_row = int(node_type_row)
node_type_column = int(node_type_column)
is_symmetric = bool(is_symmetric)
if node_type_row < 0 or node_type_row >= len(self.node_types):
raise ValueError('node_type_row outside of the valid range of node types')
if node_type_column < 0 or node_type_column >= len(self.node_types):
raise ValueError('node_type_column outside of the valid range of node types')
fam = RelationFamily(self, name, node_type_row, node_type_column,
is_symmetric, decoder_class)
self.relation_families.append(fam)
return fam
def __repr__(self):
n = len(self.node_types)
if n == 0:
return 'Empty Icosagon Data'
s = ''
s += 'Icosagon Data with:\n'
s += '- ' + str(n) + ' node type(s):\n'
for nt in self.node_types:
s += ' - ' + nt.name + '\n'
if len(self.relation_families) == 0:
s += '- No relation families\n'
return s.strip()
s += '- %d relation families:\n' % len(self.relation_families)
for fam in self.relation_families:
s += fam.repr_indented() + '\n'
return s.strip()
vertex_type_row = int(vertex_type_row)
vertex_type_column = int(vertex_type_column)
if not isinstance(adjacency_matrices, list):
raise TypeError('adjacency_matrices must be a list of tensors')
if not isinstance(decoder_factory, types.FunctionType):
raise TypeError('decoder_factory must be a function')
if (vertex_type_row, vertex_type_column) in self.edge_types:
raise KeyError('Edge type for given combination of row and column already exists')
total_connectivity = _nonzero_sum(adjacency_matrices)
self.edges_types[vertex_type_row, vertex_type_column] = \
VertexType(name, vertex_type_row, vertex_type_column,
adjacency_matrices, decoder_factory, total_connectivity)

+ 32
- 102
src/triacontagon/decode.py Прегледај датотеку

@@ -7,117 +7,47 @@
import torch
from .weights import init_glorot
from .dropout import dropout
from typing import Tuple, \
List
class DEDICOMDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
activation=torch.sigmoid, **kwargs):
def dedicom_decoder(input_dim: int, num_relation_types: int) ->
Tuple[torch.Tensor, List[torch.Tensor]]:
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.keep_prob = keep_prob
self.activation = activation
global_interaction = init_glorot(input_dim, input_dim)
local_variation = [
torch.diag(torch.flatten(init_glorot(input_dim, 1))) \
for _ in range(num_relation_types)
]
return (global_interaction, local_variation)
self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
self.local_variation = torch.nn.ParameterList([
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
for _ in range(num_relation_types)
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
def dist_mult_decoder(input_dim: int, num_relation_types: int) ->
Tuple[torch.Tensor, List[torch.Tensor]]:
relation = torch.diag(self.local_variation[relation_index])
global_interaction = torch.eye(input_dim, input_dim)
local_variation = [
torch.diag(torch.flatten(init_glorot(input_dim, 1)))) \
for _ in range(num_relation_types)
]
return (global_interaction, local_variation)
product1 = torch.mm(inputs_row, relation)
product2 = torch.mm(product1, self.global_interaction)
product3 = torch.mm(product2, relation)
rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
return self.activation(rec)
def bilinear_decoder(input_dim: int, num_relation_types: int) ->
Tuple[torch.Tensor, List[torch.Tensor]]:
global_interaction = torch.eye(input_dim, input_dim)
local_variation = [
init_glorot(input_dim, input_dim) \
for _ in range(num_relation_types)
]
return (global_interaction, local_variation)
class DistMultDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.keep_prob = keep_prob
self.activation = activation
def inner_product_decoder(input_dim: int, num_relation_types: int) ->
Tuple[torch.Tensor, List[torch.Tensor]]:
self.relation = torch.nn.ParameterList([
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
for _ in range(num_relation_types)
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
relation = torch.diag(self.relation[relation_index])
intermediate_product = torch.mm(inputs_row, relation)
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
return self.activation(rec)
class BilinearDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.keep_prob = keep_prob
self.activation = activation
self.relation = torch.nn.ParameterList([
torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
for _ in range(num_relation_types)
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
intermediate_product = torch.mm(inputs_row, self.relation[relation_index])
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
return self.activation(rec)
class InnerProductDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.keep_prob = keep_prob
self.activation = activation
def forward(self, inputs_row, inputs_col, _):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
return self.activation(rec)
global_interaction = torch.eye(input_dim, input_dim)
local_variation = torch.eye(input_dim, input_dim)
local_variation = [ local_variation ] * num_relation_types
return (global_interaction, local_variation)

+ 129
- 0
src/triacontagon/model.py Прегледај датотеку

@@ -0,0 +1,129 @@
from .data import Data, \
EdgeType
import torch
from dataclasses import dataclass
from .weights import init_glorot
import types
from typing import List, \
Dict, \
Callable
from .util import _sparse_coo_tensor
@dataclass
class TrainingBatch(object):
vertex_type_row: int
vertex_type_column: int
relation_type_index: int
edges: torch.Tensor
class Model(torch.nn.Module):
def __init__(self, data: Data, layer_dimensions: List[int],
keep_prob: float,
conv_activation: Callable[[torch.Tensor], torch.Tensor],
dec_activation: Callable[[torch.Tensor], torch.Tensor],
**kwargs) -> None:
super().__init__(**kwargs)
if not isinstance(data, Data):
raise TypeError('data must be an instance of Data')
if not isinstance(conv_activation, types.FunctionType):
raise TypeError('conv_activation must be a function')
if not isinstance(dec_activation, types.FunctionType):
raise TypeError('dec_activation must be a function')
self.data = data
self.layer_dimensions = list(layer_dimensions)
self.keep_prob = float(keep_prob)
self.conv_activation = conv_activation
self.dec_activation = dec_activation
self.conv_weights = None
self.dec_weights = None
self.build()
def build(self) -> None:
self.conv_weights = torch.nn.ParameterDict()
for i in range(len(self.layer_dimensions) - 1):
in_dimension = self.layer_dimensions[i]
out_dimension = self.layer_dimensions[i + 1]
for _, et in self.data.edge_types.items():
weight = init_glorot(in_dimension, out_dimension)
self.conv_weights[et.vertex_type_row, et.vertex_type_column, i] = \
torch.nn.Parameter(weight)
self.dec_weights = torch.nn.ParameterDict()
for _, et in self.data.edge_types.items():
global_interaction, local_variation = \
et.decoder_factory(self.layer_dimensions[-1],
len(et.adjacency_matrices))
self.dec_weights[et.vertex_type_row, et.vertex_type_column] = \
torch.nn.ParameterList([
torch.nn.Parameter(global_interaction),
torch.nn.Parameter(local_variation)
])
def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor,
rows: torch.Tensor) -> torch.Tensor:
adj_mat = adjacency_matrix.coalesce()
adj_mat = torch.index_select(adj_mat, 0, rows)
adj_mat = adj_mat.coalesce()
indices = adj_mat.indices()
indices[0] = rows
adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape)
def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
col = batch.vertex_type_column
rows = batch.edges[:, 0]
columns = batch.edges[:, 1].sum(dim=0).flatten()
columns = torch.nonzero(columns)
for i in range(len(self.layer_dimensions) - 1):
columns =
def temporary_adjacency_matrices(self, batch: TrainingBatch) ->
Dict[Tuple[int, int], List[List[torch.Tensor]]]:
col = batch.vertex_type_column
batch.edges[:, 1]
res = {}
for _, et in self.data.edge_types.items():
sum_nonzero = _nonzero_sum(et.adjacency_matrices)
res[et.vertex_type_row, et.vertex_type_column] = \
[ self.temporary_adjacency_matrix(adj_mat, batch,
et.total_connectivity) \
for adj_mat in et.adjacency_matrices ]
return res
def forward(self, initial_repr: List[torch.Tensor],
batch: TrainingBatch) -> torch.Tensor:
if not isinstance(initial_repr, list):
raise TypeError('initial_repr must be a list')
if len(initial_repr) != len(self.data.vertex_types):
raise ValueError('initial_repr must contain representations for all vertex types')
if not isinstance(batch, TrainingBatch):
raise TypeError('batch must be an instance of TrainingBatch')
adj_matrices = self.temporary_adjacency_matrices(batch)
row_vertices = initial_repr[batch.vertex_type_row]
column_vertices = initial_repr[batch.vertex_type_column]

+ 174
- 0
src/triacontagon/util.py Прегледај датотеку

@@ -0,0 +1,174 @@
import torch
from typing import List, \
Set
import time
def _equal(x: torch.Tensor, y: torch.Tensor):
if x.is_sparse ^ y.is_sparse:
raise ValueError('Cannot mix sparse and dense tensors')
if not x.is_sparse:
return (x == y)
return ((x - y).coalesce().values() == 0)
def _sparse_coo_tensor(indices, values, size):
ctor = { torch.float32: torch.sparse.FloatTensor,
torch.float32: torch.sparse.DoubleTensor,
torch.uint8: torch.sparse.ByteTensor,
torch.long: torch.sparse.LongTensor,
torch.int: torch.sparse.IntTensor,
torch.short: torch.sparse.ShortTensor,
torch.bool: torch.sparse.ByteTensor }[values.dtype]
return ctor(indices, values, size)
def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
if len(adjacency_matrices) == 0:
raise ValueError('adjacency_matrices must be non-empty')
if not all([x.is_sparse for x in adjacency_matrices]):
raise ValueError('All adjacency matrices must be sparse')
indices = [ x.indices() for x in adjacency_matrices ]
indices = torch.cat(indices, dim=1)
values = torch.ones(indices.shape[1])
res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
res = res.coalesce()
indices = res.indices()
res = _sparse_coo_tensor(indices,
torch.ones(indices.shape[1], dtype=torch.uint8))
return res
def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:
if not adjacency_matrix.is_sparse:
raise ValueError('adjacency_matrix must be sparse')
if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')
t = time.time()
rows = [ rows + row_vertex_count * i \
for i in range(num_relation_types) ]
print('rows took:', time.time() - t)
t = time.time()
rows = torch.cat(rows)
print('cat took:', time.time() - t)
# print('rows:', rows)
rows = set(rows.tolist())
# print('rows:', rows)
t = time.time()
adj_mat = adjacency_matrix.coalesce()
indices = adj_mat.indices()
values = adj_mat.values()
print('indices[0]:', indices[0])
print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
# print('selection:', selection)
selection = torch.nonzero(selection, as_tuple=True)[0]
# print('selection:', selection)
indices = indices[:, selection]
values = values[selection]
print('"index_select()" took:', time.time() - t)
t = time.time()
res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
print('_sparse_coo_tensor() took:', time.time() - t)
return res
# t = time.time()
# adj_mat = torch.index_select(adjacency_matrix, 0, rows)
# print('index_select took:', time.time() - t)
t = time.time()
adj_mat = adj_mat.coalesce()
print('coalesce() took:', time.time() - t)
indices = adj_mat.indices()
# print('indices:', indices)
values = adj_mat.values()
t = time.time()
indices[0] = rows[indices[0]]
print('Lookup took:', time.time() - t)
t = time.time()
adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
print('_sparse_coo_tensor() took:', time.time() - t)
return adj_mat
def _sparse_diag_cat(matrices: List[torch.Tensor]):
if len(matrices) == 0:
raise ValueError('The list of matrices must be non-empty')
if not all(m.is_sparse for m in matrices):
raise ValueError('All matrices must be sparse')
if not all(len(m.shape) == 2 for m in matrices):
raise ValueError('All matrices must be 2D')
indices = []
values = []
row_offset = 0
col_offset = 0
for m in matrices:
ind = m._indices().clone()
ind[0] += row_offset
ind[1] += col_offset
indices.append(ind)
values.append(m._values())
row_offset += m.shape[0]
col_offset += m.shape[1]
indices = torch.cat(indices, dim=1)
values = torch.cat(values)
return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
def _cat(matrices: List[torch.Tensor]):
if len(matrices) == 0:
raise ValueError('Empty list passed to _cat()')
n = sum(a.is_sparse for a in matrices)
if n != 0 and n != len(matrices):
raise ValueError('All matrices must have the same layout (dense or sparse)')
if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
raise ValueError('All matrices must have the same dimensions apart from dimension 0')
if not matrices[0].is_sparse:
return torch.cat(matrices)
total_rows = sum(a.shape[0] for a in matrices)
indices = []
values = []
row_offset = 0
for a in matrices:
ind = a._indices().clone()
val = a._values()
ind[0] += row_offset
ind = ind.transpose(0, 1)
indices.append(ind)
values.append(val)
row_offset += a.shape[0]
indices = torch.cat(indices).transpose(0, 1)
values = torch.cat(values)
res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
return res

+ 95
- 0
tests/triacontagon/test_util.py Прегледај датотеку

@@ -0,0 +1,95 @@
from triacontagon.util import \
_clear_adjacency_matrix_except_rows, \
_sparse_diag_cat, \
_equal
import torch
import time
def test_clear_adjacency_matrix_except_rows_01():
adj_mat = torch.tensor([
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 1],
[1, 0, 1, 0, 0],
[1, 1, 0, 0, 0]
], dtype=torch.uint8).to_sparse()
adj_mat = _sparse_diag_cat([ adj_mat, adj_mat ])
res = _clear_adjacency_matrix_except_rows(adj_mat,
torch.tensor([1, 3]), 4, 2)
res = res.to_dense()
truth = torch.tensor([
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 0, 0, 0]
], dtype=torch.uint8)
print('res:', res)
assert torch.all(res == truth)
def test_clear_adjacency_matrix_except_rows_02():
adj_mat = torch.rand(6, 10).round().to(torch.uint8)
t = time.time()
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
print('_sparse_diag_cat() took:', time.time() - t)
t = time.time()
res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
6, 130)
print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
adj_mat[0] = adj_mat[2] = adj_mat[4] = \
torch.zeros(10)
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
assert _equal(res, truth).all()
def test_clear_adjacency_matrix_except_rows_03():
adj_mat = torch.rand(6, 10).round().to(torch.uint8)
t = time.time()
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
print('_sparse_diag_cat() took:', time.time() - t)
t = time.time()
res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
6, 1300)
print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
adj_mat[0] = adj_mat[2] = adj_mat[4] = \
torch.zeros(10)
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
assert _equal(res, truth).all()
def test_clear_adjacency_matrix_except_rows_04():
adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8)
t = time.time()
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
print('_sparse_diag_cat() took:', time.time() - t)
t = time.time()
res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
2000, 1300)
print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
adj_mat[0] = adj_mat[2] = adj_mat[4] = \
torch.zeros(2000)
adj_mat[6:] = torch.zeros(2000)
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
assert _equal(res, truth).all()

Loading…
Откажи
Сачувај