@@ -9,7 +9,8 @@ from typing import Callable, \ | |||||
Tuple, \ | Tuple, \ | ||||
List | List | ||||
import types | import types | ||||
from .util import _nonzero_sum | |||||
from .util import _nonzero_sum, \ | |||||
_diag | |||||
import torch | import torch | ||||
@@ -61,13 +62,28 @@ class Data(object): | |||||
name = str(name) | name = str(name) | ||||
vertex_type_row = int(vertex_type_row) | vertex_type_row = int(vertex_type_row) | ||||
vertex_type_column = int(vertex_type_column) | vertex_type_column = int(vertex_type_column) | ||||
if not isinstance(adjacency_matrices, list): | if not isinstance(adjacency_matrices, list): | ||||
raise TypeError('adjacency_matrices must be a list of tensors') | raise TypeError('adjacency_matrices must be a list of tensors') | ||||
if not isinstance(decoder_factory, types.FunctionType): | |||||
raise TypeError('decoder_factory must be a function') | |||||
if not callable(decoder_factory): | |||||
raise TypeError('decoder_factory must be callable') | |||||
if (vertex_type_row, vertex_type_column) in self.edge_types: | if (vertex_type_row, vertex_type_column) in self.edge_types: | ||||
raise KeyError('Edge type for given combination of row and column already exists') | raise KeyError('Edge type for given combination of row and column already exists') | ||||
if vertex_type_row == vertex_type_column and \ | |||||
any(torch.any(_diag(adj_mat).to(torch.bool)) \ | |||||
for adj_mat in adjacency_matrices): | |||||
raise ValueError('Adjacency matrices for same row/column vertex types must have empty diagonals') | |||||
if any(adj_mat.shape[0] != self.vertex_types[vertex_type_row].count \ | |||||
or adj_mat.shape[1] != self.vertex_types[vertex_type_column].count \ | |||||
for adj_mat in adjacency_matrices): | |||||
raise ValueError('Adjacency matrices must have as many rows as row vertex type count and as many columns as column vertex type count') | |||||
total_connectivity = _nonzero_sum(adjacency_matrices) | total_connectivity = _nonzero_sum(adjacency_matrices) | ||||
self.edge_types[vertex_type_row, vertex_type_column] = \ | self.edge_types[vertex_type_row, vertex_type_column] = \ | ||||
EdgeType(name, vertex_type_row, vertex_type_column, | EdgeType(name, vertex_type_row, vertex_type_column, | ||||
adjacency_matrices, decoder_factory, total_connectivity) | adjacency_matrices, decoder_factory, total_connectivity) |
@@ -120,13 +120,18 @@ def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] | |||||
return true_classes, row_count | return true_classes, row_count | ||||
def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
def negative_sample_adj_mat(adj_mat: torch.Tensor, | |||||
remove_diagonal: bool=False) -> torch.Tensor: | |||||
if not isinstance(adj_mat, torch.Tensor): | if not isinstance(adj_mat, torch.Tensor): | ||||
raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__) | raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__) | ||||
edges_pos, degrees = get_edges_and_degrees(adj_mat) | edges_pos, degrees = get_edges_and_degrees(adj_mat) | ||||
true_classes, row_count = get_true_classes(adj_mat) | true_classes, row_count = get_true_classes(adj_mat) | ||||
if remove_diagonal: | |||||
true_classes = torch.cat([ torch.arange(len(adj_mat)).view(-1, 1), | |||||
true_classes ], dim=1) | |||||
# true_classes = edges_pos[:, 1].view(-1, 1) | # true_classes = edges_pos[:, 1].view(-1, 1) | ||||
# print('true_classes:', true_classes) | # print('true_classes:', true_classes) | ||||
@@ -164,7 +169,10 @@ def negative_sample_data(data: Data) -> Data: | |||||
for key, et in data.edge_types.items(): | for key, et in data.edge_types.items(): | ||||
adjacency_matrices_neg = [] | adjacency_matrices_neg = [] | ||||
for adj_mat in et.adjacency_matrices: | for adj_mat in et.adjacency_matrices: | ||||
adj_mat_neg = negative_sample_adj_mat(adj_mat) | |||||
remove_diagonal = True \ | |||||
if et.vertex_type_row == et.vertex_type_column \ | |||||
else False | |||||
adj_mat_neg = negative_sample_adj_mat(adj_mat, remove_diagonal) | |||||
adjacency_matrices_neg.append(adj_mat_neg) | adjacency_matrices_neg.append(adj_mat_neg) | ||||
res.add_edge_type(et.name, | res.add_edge_type(et.name, | ||||
et.vertex_type_row, et.vertex_type_column, | et.vertex_type_row, et.vertex_type_column, | ||||
@@ -4,6 +4,29 @@ from typing import List, \ | |||||
import time | import time | ||||
def _diag(x: torch.Tensor, make_sparse: bool=False): | |||||
if len(x.shape) < 1 or len(x.shape) > 2: | |||||
raise ValueError('Matrix or vector expected') | |||||
if not x.is_sparse and not make_sparse: | |||||
return torch.diag(x) | |||||
if len(x.shape) == 1: | |||||
indices = torch.arange(len(x)).view(1, -1) | |||||
indices = torch.cat([ indices, indices ]) | |||||
return _sparse_coo_tensor(indices, x.to_dense(), (len(x),) * 2) | |||||
values = x.values() | |||||
indices = x.indices() | |||||
mask = torch.nonzero(indices[0] == indices[1], as_tuple=True)[0] | |||||
indices = torch.flatten(indices[0, mask]) | |||||
order = torch.argsort(indices) | |||||
values = values[mask][order] | |||||
res = torch.zeros(min(x.shape[0], x.shape[1]), dtype=values.dtype) | |||||
res[indices] = values | |||||
return res | |||||
def _equal(x: torch.Tensor, y: torch.Tensor): | def _equal(x: torch.Tensor, y: torch.Tensor): | ||||
if x.is_sparse ^ y.is_sparse: | if x.is_sparse ^ y.is_sparse: | ||||
raise ValueError('Cannot mix sparse and dense tensors') | raise ValueError('Cannot mix sparse and dense tensors') | ||||
@@ -33,7 +33,7 @@ def test_same_data_org_02(): | |||||
torch.tensor([ | torch.tensor([ | ||||
[0, 0, 0, 1], | [0, 0, 0, 1], | ||||
[1, 0, 0, 0], | [1, 0, 0, 0], | ||||
[0, 1, 1, 0], | |||||
[0, 1, 0, 1], | |||||
[1, 0, 1, 0] | [1, 0, 1, 0] | ||||
]).to_sparse() | ]).to_sparse() | ||||
], dedicom_decoder) | ], dedicom_decoder) | ||||
@@ -46,7 +46,7 @@ def test_same_data_org_02(): | |||||
torch.tensor([ | torch.tensor([ | ||||
[0, 0, 0, 1], | [0, 0, 0, 1], | ||||
[1, 0, 0, 0], | [1, 0, 0, 0], | ||||
[0, 1, 1, 0], | |||||
[0, 1, 0, 1], | |||||
[1, 0, 0, 0] | [1, 0, 0, 0] | ||||
]).to_sparse() | ]).to_sparse() | ||||
], dedicom_decoder) | ], dedicom_decoder) | ||||
@@ -94,7 +94,7 @@ def test_batcher_02(): | |||||
]).to_sparse(), | ]).to_sparse(), | ||||
torch.tensor([ | torch.tensor([ | ||||
[1, 0, 1, 0, 0], | |||||
[0, 0, 1, 0, 1], | |||||
[0, 0, 0, 1, 0], | [0, 0, 0, 1, 0], | ||||
[0, 0, 0, 0, 1], | [0, 0, 0, 0, 1], | ||||
[0, 1, 0, 0, 0], | [0, 1, 0, 0, 0], | ||||
@@ -113,7 +113,7 @@ def test_batcher_02(): | |||||
assert visited == { (0, 0, 1), (0, 0, 3), | assert visited == { (0, 0, 1), (0, 0, 3), | ||||
(0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3), | (0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3), | ||||
(1, 0, 0), (1, 0, 2), (1, 1, 3), (1, 2, 4), | |||||
(1, 0, 2), (1, 0, 4), (1, 1, 3), (1, 2, 4), | |||||
(1, 3, 1), (1, 4, 2) } | (1, 3, 1), (1, 4, 2) } | ||||
@@ -132,7 +132,7 @@ def test_batcher_03(): | |||||
]).to_sparse(), | ]).to_sparse(), | ||||
torch.tensor([ | torch.tensor([ | ||||
[1, 0, 1, 0, 0], | |||||
[0, 0, 1, 0, 1], | |||||
[0, 0, 0, 1, 0], | [0, 0, 0, 1, 0], | ||||
[0, 0, 0, 0, 1], | [0, 0, 0, 0, 1], | ||||
[0, 1, 0, 0, 0], | [0, 1, 0, 0, 0], | ||||
@@ -162,7 +162,7 @@ def test_batcher_03(): | |||||
assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | ||||
(0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | ||||
(0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||||
(0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||||
(0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | (0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | ||||
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | ||||
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | ||||
@@ -211,7 +211,7 @@ def test_batcher_05(): | |||||
]).to_sparse(), | ]).to_sparse(), | ||||
torch.tensor([ | torch.tensor([ | ||||
[1, 0, 1, 0, 0], | |||||
[0, 0, 1, 0, 1], | |||||
[0, 0, 0, 1, 0], | [0, 0, 0, 1, 0], | ||||
[0, 0, 0, 0, 1], | [0, 0, 0, 0, 1], | ||||
[0, 1, 0, 0, 0], | [0, 1, 0, 0, 0], | ||||
@@ -242,7 +242,7 @@ def test_batcher_05(): | |||||
assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | ||||
(0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | ||||
(0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||||
(0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||||
(0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | (0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | ||||
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | ||||
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | ||||
@@ -264,7 +264,7 @@ def test_dual_batcher_01(): | |||||
]).to_sparse(), | ]).to_sparse(), | ||||
torch.tensor([ | torch.tensor([ | ||||
[1, 0, 1, 0, 0], | |||||
[0, 0, 1, 0, 1], | |||||
[0, 0, 0, 1, 0], | [0, 0, 0, 1, 0], | ||||
[0, 0, 0, 0, 1], | [0, 0, 0, 0, 1], | ||||
[0, 1, 0, 0, 0], | [0, 1, 0, 0, 0], | ||||
@@ -306,7 +306,7 @@ def test_dual_batcher_01(): | |||||
expected = { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | expected = { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | ||||
(0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | ||||
(0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||||
(0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||||
(0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | (0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | ||||
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | ||||
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | ||||
@@ -13,10 +13,10 @@ def test_per_layer_required_vertices_01(): | |||||
d.add_vertex_type('Drug', 5) | d.add_vertex_type('Drug', 5) | ||||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | ||||
[1, 0, 0, 1], | |||||
[0, 1, 1, 0], | |||||
[0, 0, 0, 1], | |||||
[0, 0, 1, 0], | [0, 0, 1, 0], | ||||
[0, 1, 0, 1] | |||||
[1, 0, 0, 0], | |||||
[0, 1, 0, 0] | |||||
]).to_sparse() ], dedicom_decoder) | ]).to_sparse() ], dedicom_decoder) | ||||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | ||||
@@ -27,11 +27,11 @@ def test_per_layer_required_vertices_01(): | |||||
]).to_sparse() ], dedicom_decoder) | ]).to_sparse() ], dedicom_decoder) | ||||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | ||||
[0, 0, 1, 0, 1], | |||||
[0, 0, 0, 1, 1], | |||||
[1, 0, 0, 0, 0], | [1, 0, 0, 0, 0], | ||||
[0, 1, 0, 0, 0], | |||||
[0, 0, 1, 0, 0], | |||||
[0, 0, 0, 1, 0], | |||||
[0, 0, 0, 0, 1] | |||||
[0, 1, 0, 0, 1], | |||||
[1, 1, 0, 1, 0] | |||||
]).to_sparse() ], dedicom_decoder) | ]).to_sparse() ], dedicom_decoder) | ||||
batch = TrainingBatch(0, 1, 0, torch.tensor([ | batch = TrainingBatch(0, 1, 0, torch.tensor([ | ||||
@@ -48,10 +48,10 @@ def test_model_convolve_01(): | |||||
d.add_vertex_type('Drug', 5) | d.add_vertex_type('Drug', 5) | ||||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | ||||
[1, 0, 0, 1], | |||||
[0, 1, 1, 0], | |||||
[0, 0, 0, 1], | |||||
[0, 0, 1, 0], | [0, 0, 1, 0], | ||||
[0, 1, 0, 1] | |||||
[1, 0, 0, 0], | |||||
[0, 1, 0, 0] | |||||
], dtype=torch.float).to_sparse() ], dedicom_decoder) | ], dtype=torch.float).to_sparse() ], dedicom_decoder) | ||||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | ||||
@@ -62,11 +62,11 @@ def test_model_convolve_01(): | |||||
], dtype=torch.float).to_sparse() ], dedicom_decoder) | ], dtype=torch.float).to_sparse() ], dedicom_decoder) | ||||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | ||||
[1, 0, 0, 0, 0], | |||||
[0, 1, 0, 0, 0], | |||||
[0, 0, 1, 0, 0], | |||||
[0, 0, 0, 1, 0], | [0, 0, 0, 1, 0], | ||||
[0, 0, 0, 0, 1] | |||||
[0, 0, 0, 0, 1], | |||||
[0, 1, 0, 0, 0], | |||||
[1, 0, 0, 0, 0], | |||||
[0, 1, 0, 1, 0] | |||||
], dtype=torch.float).to_sparse() ], dedicom_decoder) | ], dtype=torch.float).to_sparse() ], dedicom_decoder) | ||||
model = Model(d, [9, 32, 64], keep_prob=1.0, | model = Model(d, [9, 32, 64], keep_prob=1.0, | ||||
@@ -90,8 +90,10 @@ def test_model_decode_01(): | |||||
d = Data() | d = Data() | ||||
d.add_vertex_type('Gene', 100) | d.add_vertex_type('Gene', 100) | ||||
gene_gene = torch.rand(100, 100).round() | |||||
gene_gene = gene_gene - torch.diag(torch.diag(gene_gene)) | |||||
d.add_edge_type('Gene-Gene', 0, 0, [ | d.add_edge_type('Gene-Gene', 0, 0, [ | ||||
torch.rand(100, 100).round().to_sparse() | |||||
gene_gene.to_sparse() | |||||
], dedicom_decoder) | ], dedicom_decoder) | ||||
b = TrainingBatch(0, 0, 0, torch.tensor([ | b = TrainingBatch(0, 0, 0, torch.tensor([ | ||||