diff --git a/src/triacontagon/input.py b/src/triacontagon/deprecated/input.py similarity index 100% rename from src/triacontagon/input.py rename to src/triacontagon/deprecated/input.py diff --git a/src/triacontagon/dropout.py b/src/triacontagon/dropout.py index 63cfb58..2fb8728 100644 --- a/src/triacontagon/dropout.py +++ b/src/triacontagon/dropout.py @@ -36,6 +36,8 @@ def dropout_dense(x, keep_prob): def dropout(x, keep_prob): + if keep_prob == 1: + return x if x.is_sparse: return dropout_sparse(x, keep_prob) else: diff --git a/src/triacontagon/model.py b/src/triacontagon/model.py index 49d37bd..58f635d 100644 --- a/src/triacontagon/model.py +++ b/src/triacontagon/model.py @@ -22,6 +22,7 @@ class TrainingBatch(object): vertex_type_column: int relation_type_index: int edges: torch.Tensor + target_values: torch.Tensor def _per_layer_required_vertices(data: Data, batch: TrainingBatch, @@ -133,7 +134,7 @@ class Model(torch.nn.Module): List[torch.Tensor]: cur_layer_repr = in_layer_repr - + for i in range(len(self.layer_dimensions) - 1): next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ] @@ -147,7 +148,7 @@ class Model(torch.nn.Module): if self.keep_prob != 1: x = dropout(x, self.keep_prob) - print('a, Layer:', i, 'x.shape:', x.shape) + # print('a, Layer:', i, 'x.shape:', x.shape) x = _mm(x, conv_weights) x = torch.split(x, @@ -159,12 +160,12 @@ class Model(torch.nn.Module): self.data.vertex_types[vt_row].count, self.layer_dimensions[i + 1]) - print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) + # print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) x = x.sum(dim=0) x = torch.nn.functional.normalize(x, p=2, dim=1) # x = self.rel_activation(x) - print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) + # print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) next_layer_repr[vt_row].append(x) @@ -174,6 +175,32 @@ class Model(torch.nn.Module): cur_layer_repr = next_layer_repr return next_layer_repr + def decode(self, last_layer_repr: List[torch.Tensor], + batch: TrainingBatch) -> torch.Tensor: + + vt_row = batch.vertex_type_row + vt_col = batch.vertex_type_column + rel_idx = batch.relation_type_index + global_interaction = \ + self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col) + local_variation = \ + self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx) + + in_row = dropout(last_layer_repr[vt_row], self.keep_prob) + in_col = dropout(last_layer_repr[vt_col], self.keep_prob) + + x = torch.mm(in_row, local_variation) + x = torch.mm(x, global_interaction) + x = torch.mm(x, local_variation) + x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]), + in_col.view(in_col.shape[0], in_col.shape[1], 1)) + x = torch.flatten(x) + + x = self.dec_activation(x) + + return x + + def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]: edges = [] cur_edges = batch.edges diff --git a/src/triacontagon/util.py b/src/triacontagon/util.py index 134b64c..8d8bad4 100644 --- a/src/triacontagon/util.py +++ b/src/triacontagon/util.py @@ -197,3 +197,25 @@ def _mm(a: torch.Tensor, b: torch.Tensor): return torch.sparse.mm(a, b) else: return torch.mm(a, b) + + +def common_one_hot_encoding(vertex_type_counts: List[int]) -> \ + List[torch.Tensor]: + + tot = sum(vertex_type_counts) + # indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0) + # print('indices.shape:', indices.shape) + ofs = 0 + res = [] + + for cnt in vertex_type_counts: + ind = torch.cat([ + torch.arange(cnt).view(1, -1), + torch.arange(ofs, ofs+cnt).view(1, -1) + ]) + val = torch.ones(cnt) + x = _sparse_coo_tensor(ind, val, size=(cnt, tot)) + res.append(x) + ofs += cnt + + return res diff --git a/tests/triacontagon/test_model.py b/tests/triacontagon/test_model.py index c943d8e..20656c9 100644 --- a/tests/triacontagon/test_model.py +++ b/tests/triacontagon/test_model.py @@ -1,9 +1,46 @@ import torch -from triacontagon.model import Model +from triacontagon.model import Model, \ + TrainingBatch, \ + _per_layer_required_vertices from triacontagon.data import Data from triacontagon.decode import dedicom_decoder +def test_per_layer_required_vertices_01(): + d = Data() + d.add_vertex_type('Gene', 4) + d.add_vertex_type('Drug', 5) + + d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ + [1, 0, 0, 1], + [0, 1, 1, 0], + [0, 0, 1, 0], + [0, 1, 0, 1] + ]).to_sparse() ], dedicom_decoder) + + d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ + [0, 1, 0, 0, 1], + [0, 0, 1, 0, 0], + [1, 0, 0, 0, 1], + [0, 0, 1, 1, 0] + ]).to_sparse() ], dedicom_decoder) + + d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1] + ]).to_sparse() ], dedicom_decoder) + + batch = TrainingBatch(0, 1, 0, torch.tensor([ + [0, 1] + ])) + + res = _per_layer_required_vertices(d, batch, 5) + print('res:', res) + + def test_model_convolve_01(): d = Data() d.add_vertex_type('Gene', 4) diff --git a/tests/triacontagon/test_util.py b/tests/triacontagon/test_util.py index 21f25ef..3593d1d 100644 --- a/tests/triacontagon/test_util.py +++ b/tests/triacontagon/test_util.py @@ -2,7 +2,7 @@ from triacontagon.util import \ _clear_adjacency_matrix_except_rows, \ _sparse_diag_cat, \ _equal, \ - _per_layer_required_vertices + common_one_hot_encoding from triacontagon.model import TrainingBatch from triacontagon.decode import dedicom_decoder from triacontagon.data import Data @@ -127,36 +127,9 @@ def test_clear_adjacency_matrix_except_rows_05(): assert _equal(res, truth).all() -def test_per_layer_required_vertices_01(): - d = Data() - d.add_vertex_type('Gene', 4) - d.add_vertex_type('Drug', 5) +def test_common_one_hot_encoding_01(): + in_repr = common_one_hot_encoding([2000, 200]) - d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ - [1, 0, 0, 1], - [0, 1, 1, 0], - [0, 0, 1, 0], - [0, 1, 0, 1] - ]).to_sparse() ], dedicom_decoder) - - d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ - [0, 1, 0, 0, 1], - [0, 0, 1, 0, 0], - [1, 0, 0, 0, 1], - [0, 0, 1, 1, 0] - ]).to_sparse() ], dedicom_decoder) - - d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ - [1, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 1, 0], - [0, 0, 0, 0, 1] - ]).to_sparse() ], dedicom_decoder) - - batch = TrainingBatch(0, 1, 0, torch.tensor([ - [0, 1] - ])) - - res = _per_layer_required_vertices(d, batch, 5) - print('res:', res) + ref = torch.eye(2200) + assert torch.all(in_repr[0].to_dense() == ref[:2000, :]) + assert torch.all(in_repr[1].to_dense() == ref[2000:, :])