From 7ea82a626756080c0c8ac5a32ead0c1f58341426 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 7 Aug 2020 11:21:58 +0200 Subject: [PATCH] Add test_model_decode_01(). --- src/triacontagon/model.py | 27 +++++++++++++++++++++------ src/triacontagon/util.py | 26 ++++++++++++++++++++++++++ tests/triacontagon/test_model.py | 26 ++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 6 deletions(-) diff --git a/src/triacontagon/model.py b/src/triacontagon/model.py index 58f635d..1a41b4d 100644 --- a/src/triacontagon/model.py +++ b/src/triacontagon/model.py @@ -173,7 +173,7 @@ class Model(torch.nn.Module): for x in next_layer_repr ] cur_layer_repr = next_layer_repr - return next_layer_repr + return cur_layer_repr def decode(self, last_layer_repr: List[torch.Tensor], batch: TrainingBatch) -> torch.Tensor: @@ -182,12 +182,27 @@ class Model(torch.nn.Module): vt_col = batch.vertex_type_column rel_idx = batch.relation_type_index global_interaction = \ - self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col) + self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)] local_variation = \ - self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx) + self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)] - in_row = dropout(last_layer_repr[vt_row], self.keep_prob) - in_col = dropout(last_layer_repr[vt_col], self.keep_prob) + in_row = last_layer_repr[vt_row] + in_col = last_layer_repr[vt_col] + + if in_row.is_sparse or in_col.is_sparse: + raise ValueError('Inputs to Model.decode() must be dense') + + in_row = in_row[batch.edges[:, 0]] + in_col = in_col[batch.edges[:, 1]] + + in_row = dropout(in_row, self.keep_prob) + in_col = dropout(in_col, self.keep_prob) + + # in_row = in_row.to_dense() + # in_col = in_col.to_dense() + + print('in_row.is_sparse:', in_row.is_sparse) + print('in_col.is_sparse:', in_col.is_sparse) x = torch.mm(in_row, local_variation) x = torch.mm(x, global_interaction) @@ -197,7 +212,7 @@ class Model(torch.nn.Module): x = torch.flatten(x) x = self.dec_activation(x) - + return x diff --git a/src/triacontagon/util.py b/src/triacontagon/util.py index 8d8bad4..f268f44 100644 --- a/src/triacontagon/util.py +++ b/src/triacontagon/util.py @@ -199,6 +199,32 @@ def _mm(a: torch.Tensor, b: torch.Tensor): return torch.mm(a, b) +def _select_rows(a: torch.Tensor, rows: torch.Tensor): + if not a.is_sparse: + return a[rows] + + indices = a.indices() + values = a.values() + + mask = torch.zeros(a.shape[0]) + mask[rows] = 1 + if mask.sum() != len(rows): + raise ValueError('Rows must be unique') + mask = mask[indices[0]] + mask = torch.nonzero(mask, as_tuple=True)[0] + + new_rows[rows] = torch.arange(len(rows)) + new_rows = new_rows[indices[0]] + + indices = indices[:, mask] + indices[0] = new_rows + values = values[mask] + + res = _sparse_coo_tensor(indices, values, + size=(len(rows), a.shape[1])) + return res + + def common_one_hot_encoding(vertex_type_counts: List[int]) -> \ List[torch.Tensor]: diff --git a/tests/triacontagon/test_model.py b/tests/triacontagon/test_model.py index 20656c9..0f57578 100644 --- a/tests/triacontagon/test_model.py +++ b/tests/triacontagon/test_model.py @@ -4,6 +4,7 @@ from triacontagon.model import Model, \ _per_layer_required_vertices from triacontagon.data import Data from triacontagon.decode import dedicom_decoder +from triacontagon.util import common_one_hot_encoding def test_per_layer_required_vertices_01(): @@ -83,3 +84,28 @@ def test_model_convolve_01(): ] _ = model.convolve(in_layer_repr) + + +def test_model_decode_01(): + d = Data() + d.add_vertex_type('Gene', 100) + + d.add_edge_type('Gene-Gene', 0, 0, [ + torch.rand(100, 100).round().to_sparse() + ], dedicom_decoder) + + b = TrainingBatch(0, 0, 0, torch.tensor([ + [0, 1], + [10, 51], + [50, 60], + [70, 90], + [98, 99] + ]), torch.ones(5)) + + in_repr = common_one_hot_encoding([100]) + + in_repr = [ in_repr[0].to_dense() ] + + m = Model(d, [100], 1.0, torch.sigmoid, torch.sigmoid) + + _ = m.decode(in_repr, b)