From b607a727aeff18fd8c0e8952e769f7bd3611af46 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 6 Aug 2020 11:58:10 +0200 Subject: [PATCH] Rename _per_layer_required_rows to _vertices. --- src/triacontagon/model.py | 40 -------------------------------- src/triacontagon/util.py | 38 ++++++++++++++++++++++++++++++ tests/triacontagon/test_util.py | 41 ++++++++++++++++++++++++++++++++- 3 files changed, 78 insertions(+), 41 deletions(-) diff --git a/src/triacontagon/model.py b/src/triacontagon/model.py index 387c5e3..8059a89 100644 --- a/src/triacontagon/model.py +++ b/src/triacontagon/model.py @@ -19,46 +19,6 @@ class TrainingBatch(object): edges: torch.Tensor -def _per_layer_required_rows(data: Data, batch: TrainingBatch, - num_layers: int) -> List[List[EdgeType]]: - - Q = [ - ( batch.vertex_type_row, batch.edges[:, 0] ), - ( batch.vertex_type_column, batch.edges[:, 1] ) - ] - print('Q:', Q) - res = [] - - for _ in range(num_layers): - R = [] - required_rows = [ [] for _ in range(len(data.vertex_types)) ] - - for vertex_type, vertices in Q: - for et in data.edge_types.values(): - if et.vertex_type_row == vertex_type: - required_rows[vertex_type].append(vertices) - indices = et.total_connectivity.indices() - mask = torch.zeros(et.total_connectivity.shape[0]) - mask[vertices] = 1 - mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] - R.append((et.vertex_type_column, - indices[1, mask])) - else: - pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) - - required_rows = [ torch.unique(torch.cat(x)) \ - if len(x) > 0 \ - else None \ - for x in required_rows ] - - res.append(required_rows) - Q = R - - return res - - - - class Model(torch.nn.Module): def __init__(self, data: Data, layer_dimensions: List[int], keep_prob: float, diff --git a/src/triacontagon/util.py b/src/triacontagon/util.py index 51e6d07..eb31bbc 100644 --- a/src/triacontagon/util.py +++ b/src/triacontagon/util.py @@ -190,3 +190,41 @@ def _cat(matrices: List[torch.Tensor]): res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) return res + + +def _per_layer_required_vertices(data: Data, batch: TrainingBatch, + num_layers: int) -> List[List[EdgeType]]: + + Q = [ + ( batch.vertex_type_row, batch.edges[:, 0] ), + ( batch.vertex_type_column, batch.edges[:, 1] ) + ] + print('Q:', Q) + res = [] + + for _ in range(num_layers): + R = [] + required_rows = [ [] for _ in range(len(data.vertex_types)) ] + + for vertex_type, vertices in Q: + for et in data.edge_types.values(): + if et.vertex_type_row == vertex_type: + required_rows[vertex_type].append(vertices) + indices = et.total_connectivity.indices() + mask = torch.zeros(et.total_connectivity.shape[0]) + mask[vertices] = 1 + mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] + R.append((et.vertex_type_column, + indices[1, mask])) + else: + pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) + + required_rows = [ torch.unique(torch.cat(x)) \ + if len(x) > 0 \ + else None \ + for x in required_rows ] + + res.append(required_rows) + Q = R + + return res diff --git a/tests/triacontagon/test_util.py b/tests/triacontagon/test_util.py index e4f7f9d..21f25ef 100644 --- a/tests/triacontagon/test_util.py +++ b/tests/triacontagon/test_util.py @@ -1,7 +1,11 @@ from triacontagon.util import \ _clear_adjacency_matrix_except_rows, \ _sparse_diag_cat, \ - _equal + _equal, \ + _per_layer_required_vertices +from triacontagon.model import TrainingBatch +from triacontagon.decode import dedicom_decoder +from triacontagon.data import Data import torch import time @@ -121,3 +125,38 @@ def test_clear_adjacency_matrix_except_rows_05(): truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) assert _equal(res, truth).all() + + +def test_per_layer_required_vertices_01(): + d = Data() + d.add_vertex_type('Gene', 4) + d.add_vertex_type('Drug', 5) + + d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ + [1, 0, 0, 1], + [0, 1, 1, 0], + [0, 0, 1, 0], + [0, 1, 0, 1] + ]).to_sparse() ], dedicom_decoder) + + d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ + [0, 1, 0, 0, 1], + [0, 0, 1, 0, 0], + [1, 0, 0, 0, 1], + [0, 0, 1, 1, 0] + ]).to_sparse() ], dedicom_decoder) + + d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1] + ]).to_sparse() ], dedicom_decoder) + + batch = TrainingBatch(0, 1, 0, torch.tensor([ + [0, 1] + ])) + + res = _per_layer_required_vertices(d, batch, 5) + print('res:', res)