@@ -9,7 +9,8 @@ from typing import Callable, \ | |||
Tuple, \ | |||
List | |||
import types | |||
from .util import _nonzero_sum | |||
from .util import _nonzero_sum, \ | |||
_diag | |||
import torch | |||
@@ -61,13 +62,28 @@ class Data(object): | |||
name = str(name) | |||
vertex_type_row = int(vertex_type_row) | |||
vertex_type_column = int(vertex_type_column) | |||
if not isinstance(adjacency_matrices, list): | |||
raise TypeError('adjacency_matrices must be a list of tensors') | |||
if not isinstance(decoder_factory, types.FunctionType): | |||
raise TypeError('decoder_factory must be a function') | |||
if not callable(decoder_factory): | |||
raise TypeError('decoder_factory must be callable') | |||
if (vertex_type_row, vertex_type_column) in self.edge_types: | |||
raise KeyError('Edge type for given combination of row and column already exists') | |||
if vertex_type_row == vertex_type_column and \ | |||
any(torch.any(_diag(adj_mat).to(torch.bool)) \ | |||
for adj_mat in adjacency_matrices): | |||
raise ValueError('Adjacency matrices for same row/column vertex types must have empty diagonals') | |||
if any(adj_mat.shape[0] != self.vertex_types[vertex_type_row].count \ | |||
or adj_mat.shape[1] != self.vertex_types[vertex_type_column].count \ | |||
for adj_mat in adjacency_matrices): | |||
raise ValueError('Adjacency matrices must have as many rows as row vertex type count and as many columns as column vertex type count') | |||
total_connectivity = _nonzero_sum(adjacency_matrices) | |||
self.edge_types[vertex_type_row, vertex_type_column] = \ | |||
EdgeType(name, vertex_type_row, vertex_type_column, | |||
adjacency_matrices, decoder_factory, total_connectivity) |
@@ -120,13 +120,18 @@ def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] | |||
return true_classes, row_count | |||
def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | |||
def negative_sample_adj_mat(adj_mat: torch.Tensor, | |||
remove_diagonal: bool=False) -> torch.Tensor: | |||
if not isinstance(adj_mat, torch.Tensor): | |||
raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__) | |||
edges_pos, degrees = get_edges_and_degrees(adj_mat) | |||
true_classes, row_count = get_true_classes(adj_mat) | |||
if remove_diagonal: | |||
true_classes = torch.cat([ torch.arange(len(adj_mat)).view(-1, 1), | |||
true_classes ], dim=1) | |||
# true_classes = edges_pos[:, 1].view(-1, 1) | |||
# print('true_classes:', true_classes) | |||
@@ -164,7 +169,10 @@ def negative_sample_data(data: Data) -> Data: | |||
for key, et in data.edge_types.items(): | |||
adjacency_matrices_neg = [] | |||
for adj_mat in et.adjacency_matrices: | |||
adj_mat_neg = negative_sample_adj_mat(adj_mat) | |||
remove_diagonal = True \ | |||
if et.vertex_type_row == et.vertex_type_column \ | |||
else False | |||
adj_mat_neg = negative_sample_adj_mat(adj_mat, remove_diagonal) | |||
adjacency_matrices_neg.append(adj_mat_neg) | |||
res.add_edge_type(et.name, | |||
et.vertex_type_row, et.vertex_type_column, | |||
@@ -4,6 +4,29 @@ from typing import List, \ | |||
import time | |||
def _diag(x: torch.Tensor, make_sparse: bool=False): | |||
if len(x.shape) < 1 or len(x.shape) > 2: | |||
raise ValueError('Matrix or vector expected') | |||
if not x.is_sparse and not make_sparse: | |||
return torch.diag(x) | |||
if len(x.shape) == 1: | |||
indices = torch.arange(len(x)).view(1, -1) | |||
indices = torch.cat([ indices, indices ]) | |||
return _sparse_coo_tensor(indices, x.to_dense(), (len(x),) * 2) | |||
values = x.values() | |||
indices = x.indices() | |||
mask = torch.nonzero(indices[0] == indices[1], as_tuple=True)[0] | |||
indices = torch.flatten(indices[0, mask]) | |||
order = torch.argsort(indices) | |||
values = values[mask][order] | |||
res = torch.zeros(min(x.shape[0], x.shape[1]), dtype=values.dtype) | |||
res[indices] = values | |||
return res | |||
def _equal(x: torch.Tensor, y: torch.Tensor): | |||
if x.is_sparse ^ y.is_sparse: | |||
raise ValueError('Cannot mix sparse and dense tensors') | |||
@@ -33,7 +33,7 @@ def test_same_data_org_02(): | |||
torch.tensor([ | |||
[0, 0, 0, 1], | |||
[1, 0, 0, 0], | |||
[0, 1, 1, 0], | |||
[0, 1, 0, 1], | |||
[1, 0, 1, 0] | |||
]).to_sparse() | |||
], dedicom_decoder) | |||
@@ -46,7 +46,7 @@ def test_same_data_org_02(): | |||
torch.tensor([ | |||
[0, 0, 0, 1], | |||
[1, 0, 0, 0], | |||
[0, 1, 1, 0], | |||
[0, 1, 0, 1], | |||
[1, 0, 0, 0] | |||
]).to_sparse() | |||
], dedicom_decoder) | |||
@@ -94,7 +94,7 @@ def test_batcher_02(): | |||
]).to_sparse(), | |||
torch.tensor([ | |||
[1, 0, 1, 0, 0], | |||
[0, 0, 1, 0, 1], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[0, 1, 0, 0, 0], | |||
@@ -113,7 +113,7 @@ def test_batcher_02(): | |||
assert visited == { (0, 0, 1), (0, 0, 3), | |||
(0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3), | |||
(1, 0, 0), (1, 0, 2), (1, 1, 3), (1, 2, 4), | |||
(1, 0, 2), (1, 0, 4), (1, 1, 3), (1, 2, 4), | |||
(1, 3, 1), (1, 4, 2) } | |||
@@ -132,7 +132,7 @@ def test_batcher_03(): | |||
]).to_sparse(), | |||
torch.tensor([ | |||
[1, 0, 1, 0, 0], | |||
[0, 0, 1, 0, 1], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[0, 1, 0, 0, 0], | |||
@@ -162,7 +162,7 @@ def test_batcher_03(): | |||
assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | |||
(0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | |||
(0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||
(0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||
(0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | |||
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | |||
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | |||
@@ -211,7 +211,7 @@ def test_batcher_05(): | |||
]).to_sparse(), | |||
torch.tensor([ | |||
[1, 0, 1, 0, 0], | |||
[0, 0, 1, 0, 1], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[0, 1, 0, 0, 0], | |||
@@ -242,7 +242,7 @@ def test_batcher_05(): | |||
assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | |||
(0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | |||
(0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||
(0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||
(0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | |||
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | |||
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | |||
@@ -264,7 +264,7 @@ def test_dual_batcher_01(): | |||
]).to_sparse(), | |||
torch.tensor([ | |||
[1, 0, 1, 0, 0], | |||
[0, 0, 1, 0, 1], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[0, 1, 0, 0, 0], | |||
@@ -306,7 +306,7 @@ def test_dual_batcher_01(): | |||
expected = { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | |||
(0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | |||
(0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||
(0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||
(0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | |||
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | |||
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | |||
@@ -13,10 +13,10 @@ def test_per_layer_required_vertices_01(): | |||
d.add_vertex_type('Drug', 5) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
[1, 0, 0, 1], | |||
[0, 1, 1, 0], | |||
[0, 0, 0, 1], | |||
[0, 0, 1, 0], | |||
[0, 1, 0, 1] | |||
[1, 0, 0, 0], | |||
[0, 1, 0, 0] | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
@@ -27,11 +27,11 @@ def test_per_layer_required_vertices_01(): | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
[0, 0, 1, 0, 1], | |||
[0, 0, 0, 1, 1], | |||
[1, 0, 0, 0, 0], | |||
[0, 1, 0, 0, 0], | |||
[0, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1] | |||
[0, 1, 0, 0, 1], | |||
[1, 1, 0, 1, 0] | |||
]).to_sparse() ], dedicom_decoder) | |||
batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||
@@ -48,10 +48,10 @@ def test_model_convolve_01(): | |||
d.add_vertex_type('Drug', 5) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
[1, 0, 0, 1], | |||
[0, 1, 1, 0], | |||
[0, 0, 0, 1], | |||
[0, 0, 1, 0], | |||
[0, 1, 0, 1] | |||
[1, 0, 0, 0], | |||
[0, 1, 0, 0] | |||
], dtype=torch.float).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
@@ -62,11 +62,11 @@ def test_model_convolve_01(): | |||
], dtype=torch.float).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
[1, 0, 0, 0, 0], | |||
[0, 1, 0, 0, 0], | |||
[0, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1] | |||
[0, 0, 0, 0, 1], | |||
[0, 1, 0, 0, 0], | |||
[1, 0, 0, 0, 0], | |||
[0, 1, 0, 1, 0] | |||
], dtype=torch.float).to_sparse() ], dedicom_decoder) | |||
model = Model(d, [9, 32, 64], keep_prob=1.0, | |||
@@ -90,8 +90,10 @@ def test_model_decode_01(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 100) | |||
gene_gene = torch.rand(100, 100).round() | |||
gene_gene = gene_gene - torch.diag(torch.diag(gene_gene)) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ | |||
torch.rand(100, 100).round().to_sparse() | |||
gene_gene.to_sparse() | |||
], dedicom_decoder) | |||
b = TrainingBatch(0, 0, 0, torch.tensor([ | |||