From 4ed6626e02879b742a9c0de800ff782d97a305a7 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 6 Aug 2020 10:24:06 +0200 Subject: [PATCH] Work on required vertices per layer. --- docs/required-vertices-per-layer.svg | 1463 ++++++++++++++++++++++++++ src/triacontagon/data.py | 5 +- src/triacontagon/decode.py | 10 +- src/triacontagon/model.py | 67 +- src/triacontagon/util.py | 4 +- tests/triacontagon/test_model.py | 40 + 6 files changed, 1568 insertions(+), 21 deletions(-) create mode 100644 docs/required-vertices-per-layer.svg create mode 100644 tests/triacontagon/test_model.py diff --git a/docs/required-vertices-per-layer.svg b/docs/required-vertices-per-layer.svg new file mode 100644 index 0000000..f7de037 --- /dev/null +++ b/docs/required-vertices-per-layer.svg @@ -0,0 +1,1463 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + 0 + + 1 + + 2 + + 4 + + 3 + + 0 + + 1 + + 2 + + 3 + + + + + + + + + + + + + + 0 + + 1 + + 2 + + 4 + + 3 + + 0 + + 1 + + 2 + + 3 + + + + + + + + + + + + + + + + 0 + + 1 + + 2 + + 4 + + 3 + + 0 + + 1 + + 2 + + 3 + + + + + + + + + + + + + + + + + + + 0 + + 1 + + 2 + + 4 + + 3 + + 0 + + 1 + + 2 + + 3 + + + + + + + + + + + + + + + + + + + + 0 + + 1 + + 2 + + 4 + + 3 + + 0 + + 1 + + 2 + + 3 + + + + + + + + + + + + + + + + + + + + + diff --git a/src/triacontagon/data.py b/src/triacontagon/data.py index 22a4c89..ba2b7f8 100644 --- a/src/triacontagon/data.py +++ b/src/triacontagon/data.py @@ -10,6 +10,7 @@ from typing import Callable, \ List import types from .util import _nonzero_sum +import torch @dataclass @@ -66,6 +67,6 @@ class Data(object): if (vertex_type_row, vertex_type_column) in self.edge_types: raise KeyError('Edge type for given combination of row and column already exists') total_connectivity = _nonzero_sum(adjacency_matrices) - self.edges_types[vertex_type_row, vertex_type_column] = \ - VertexType(name, vertex_type_row, vertex_type_column, + self.edge_types[vertex_type_row, vertex_type_column] = \ + EdgeType(name, vertex_type_row, vertex_type_column, adjacency_matrices, decoder_factory, total_connectivity) diff --git a/src/triacontagon/decode.py b/src/triacontagon/decode.py index 25ae822..d82d29d 100644 --- a/src/triacontagon/decode.py +++ b/src/triacontagon/decode.py @@ -11,7 +11,7 @@ from typing import Tuple, \ List -def dedicom_decoder(input_dim: int, num_relation_types: int) -> +def dedicom_decoder(input_dim: int, num_relation_types: int) -> \ Tuple[torch.Tensor, List[torch.Tensor]]: global_interaction = init_glorot(input_dim, input_dim) @@ -22,18 +22,18 @@ def dedicom_decoder(input_dim: int, num_relation_types: int) -> return (global_interaction, local_variation) -def dist_mult_decoder(input_dim: int, num_relation_types: int) -> +def dist_mult_decoder(input_dim: int, num_relation_types: int) -> \ Tuple[torch.Tensor, List[torch.Tensor]]: global_interaction = torch.eye(input_dim, input_dim) local_variation = [ - torch.diag(torch.flatten(init_glorot(input_dim, 1)))) \ + torch.diag(torch.flatten(init_glorot(input_dim, 1))) \ for _ in range(num_relation_types) ] return (global_interaction, local_variation) -def bilinear_decoder(input_dim: int, num_relation_types: int) -> +def bilinear_decoder(input_dim: int, num_relation_types: int) -> \ Tuple[torch.Tensor, List[torch.Tensor]]: global_interaction = torch.eye(input_dim, input_dim) @@ -44,7 +44,7 @@ def bilinear_decoder(input_dim: int, num_relation_types: int) -> return (global_interaction, local_variation) -def inner_product_decoder(input_dim: int, num_relation_types: int) -> +def inner_product_decoder(input_dim: int, num_relation_types: int) -> \ Tuple[torch.Tensor, List[torch.Tensor]]: global_interaction = torch.eye(input_dim, input_dim) diff --git a/src/triacontagon/model.py b/src/triacontagon/model.py index 1b98931..387c5e3 100644 --- a/src/triacontagon/model.py +++ b/src/triacontagon/model.py @@ -6,7 +6,8 @@ from .weights import init_glorot import types from typing import List, \ Dict, \ - Callable + Callable, \ + Tuple from .util import _sparse_coo_tensor @@ -18,6 +19,46 @@ class TrainingBatch(object): edges: torch.Tensor +def _per_layer_required_rows(data: Data, batch: TrainingBatch, + num_layers: int) -> List[List[EdgeType]]: + + Q = [ + ( batch.vertex_type_row, batch.edges[:, 0] ), + ( batch.vertex_type_column, batch.edges[:, 1] ) + ] + print('Q:', Q) + res = [] + + for _ in range(num_layers): + R = [] + required_rows = [ [] for _ in range(len(data.vertex_types)) ] + + for vertex_type, vertices in Q: + for et in data.edge_types.values(): + if et.vertex_type_row == vertex_type: + required_rows[vertex_type].append(vertices) + indices = et.total_connectivity.indices() + mask = torch.zeros(et.total_connectivity.shape[0]) + mask[vertices] = 1 + mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] + R.append((et.vertex_type_column, + indices[1, mask])) + else: + pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) + + required_rows = [ torch.unique(torch.cat(x)) \ + if len(x) > 0 \ + else None \ + for x in required_rows ] + + res.append(required_rows) + Q = R + + return res + + + + class Model(torch.nn.Module): def __init__(self, data: Data, layer_dimensions: List[int], keep_prob: float, @@ -68,17 +109,16 @@ class Model(torch.nn.Module): torch.nn.Parameter(local_variation) ]) - def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor, - rows: torch.Tensor) -> torch.Tensor: - - adj_mat = adjacency_matrix.coalesce() - adj_mat = torch.index_select(adj_mat, 0, rows) - adj_mat = adj_mat.coalesce() - indices = adj_mat.indices() - indices[0] = rows + def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]: + edges = [] + cur_edges = batch.edges + for _ in range(len(self.layer_dimensions) - 1): + edges.append(cur_edges) + key = (batch.vertex_type_row, batch.vertex_type_column) + tot_conn = self.data.relation_types[key].total_connectivity + cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1]) - adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape) def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor, batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor: @@ -90,12 +130,13 @@ class Model(torch.nn.Module): columns = torch.nonzero(columns) for i in range(len(self.layer_dimensions) - 1): + pass # columns = + # TODO: finish - columns = + return None - def temporary_adjacency_matrices(self, batch: TrainingBatch) -> - Dict[Tuple[int, int], List[List[torch.Tensor]]]: + def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]: col = batch.vertex_type_column batch.edges[:, 1] diff --git a/src/triacontagon/util.py b/src/triacontagon/util.py index e6fefc8..51e6d07 100644 --- a/src/triacontagon/util.py +++ b/src/triacontagon/util.py @@ -41,7 +41,9 @@ def _nonzero_sum(adjacency_matrices: List[torch.Tensor]): indices = res.indices() res = _sparse_coo_tensor(indices, - torch.ones(indices.shape[1], dtype=torch.uint8)) + torch.ones(indices.shape[1], dtype=torch.uint8), + adjacency_matrices[0].shape) + res = res.coalesce() return res diff --git a/tests/triacontagon/test_model.py b/tests/triacontagon/test_model.py new file mode 100644 index 0000000..b887713 --- /dev/null +++ b/tests/triacontagon/test_model.py @@ -0,0 +1,40 @@ +from triacontagon.model import _per_layer_required_rows, \ + TrainingBatch +from triacontagon.decode import dedicom_decoder +from triacontagon.data import Data +import torch + + +def test_per_layer_required_rows_01(): + d = Data() + d.add_vertex_type('Gene', 4) + d.add_vertex_type('Drug', 5) + + d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ + [1, 0, 0, 1], + [0, 1, 1, 0], + [0, 0, 1, 0], + [0, 1, 0, 1] + ]).to_sparse() ], dedicom_decoder) + + d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ + [0, 1, 0, 0, 1], + [0, 0, 1, 0, 0], + [1, 0, 0, 0, 1], + [0, 0, 1, 1, 0] + ]).to_sparse() ], dedicom_decoder) + + d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1] + ]).to_sparse() ], dedicom_decoder) + + batch = TrainingBatch(0, 1, 0, torch.tensor([ + [0, 1] + ])) + + res = _per_layer_required_rows(d, batch, 5) + print('res:', res)