From bd8143469c54d99fbfcd7301dc3c3264d888bb8a Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Thu, 6 Aug 2020 15:12:09 +0200
Subject: [PATCH] Make convolve() work in the new model.

---
 src/triacontagon/{ => deprecated}/fastconv.py |   0
 src/triacontagon/{ => deprecated}/fastdec.py  |   0
 src/triacontagon/{ => deprecated}/fastloop.py |   0
 .../{ => deprecated}/fastmodel.py             |   0
 .../{ => deprecated}/trainprep.py             |   0
 src/triacontagon/model.py                     | 126 ++++++++++++++++--
 src/triacontagon/util.py                      |  41 +-----
 tests/triacontagon/test_model.py              |  34 +++--
 8 files changed, 138 insertions(+), 63 deletions(-)
 rename src/triacontagon/{ => deprecated}/fastconv.py (100%)
 rename src/triacontagon/{ => deprecated}/fastdec.py (100%)
 rename src/triacontagon/{ => deprecated}/fastloop.py (100%)
 rename src/triacontagon/{ => deprecated}/fastmodel.py (100%)
 rename src/triacontagon/{ => deprecated}/trainprep.py (100%)

diff --git a/src/triacontagon/fastconv.py b/src/triacontagon/deprecated/fastconv.py
similarity index 100%
rename from src/triacontagon/fastconv.py
rename to src/triacontagon/deprecated/fastconv.py
diff --git a/src/triacontagon/fastdec.py b/src/triacontagon/deprecated/fastdec.py
similarity index 100%
rename from src/triacontagon/fastdec.py
rename to src/triacontagon/deprecated/fastdec.py
diff --git a/src/triacontagon/fastloop.py b/src/triacontagon/deprecated/fastloop.py
similarity index 100%
rename from src/triacontagon/fastloop.py
rename to src/triacontagon/deprecated/fastloop.py
diff --git a/src/triacontagon/fastmodel.py b/src/triacontagon/deprecated/fastmodel.py
similarity index 100%
rename from src/triacontagon/fastmodel.py
rename to src/triacontagon/deprecated/fastmodel.py
diff --git a/src/triacontagon/trainprep.py b/src/triacontagon/deprecated/trainprep.py
similarity index 100%
rename from src/triacontagon/trainprep.py
rename to src/triacontagon/deprecated/trainprep.py
diff --git a/src/triacontagon/model.py b/src/triacontagon/model.py
index 8059a89..49ca1bb 100644
--- a/src/triacontagon/model.py
+++ b/src/triacontagon/model.py
@@ -8,7 +8,12 @@ from typing import List, \
     Dict, \
     Callable, \
     Tuple
-from .util import _sparse_coo_tensor
+from .util import _sparse_coo_tensor, \
+    _sparse_diag_cat, \
+    _mm
+from .normalize import norm_adj_mat_one_node_type, \
+    norm_adj_mat_two_node_types
+from .dropout import dropout
 
 
 @dataclass
@@ -19,6 +24,44 @@ class TrainingBatch(object):
     edges: torch.Tensor
 
 
+def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
+    num_layers: int) -> List[List[EdgeType]]:
+
+    Q = [
+        ( batch.vertex_type_row, batch.edges[:, 0] ),
+        ( batch.vertex_type_column, batch.edges[:, 1] )
+    ]
+    print('Q:', Q)
+    res = []
+
+    for _ in range(num_layers):
+        R = []
+        required_rows = [ [] for _ in range(len(data.vertex_types)) ]
+
+        for vertex_type, vertices in Q:
+            for et in data.edge_types.values():
+                if et.vertex_type_row == vertex_type:
+                    required_rows[vertex_type].append(vertices)
+                    indices = et.total_connectivity.indices()
+                    mask = torch.zeros(et.total_connectivity.shape[0])
+                    mask[vertices] = 1
+                    mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
+                    R.append((et.vertex_type_column,
+                        indices[1, mask]))
+                else:
+                    pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
+
+        required_rows = [ torch.unique(torch.cat(x)) \
+            if len(x) > 0 \
+            else None \
+            for x in required_rows ]
+
+        res.append(required_rows)
+        Q = R
+
+    return res
+
+
 class Model(torch.nn.Module):
     def __init__(self, data: Data, layer_dimensions: List[int],
         keep_prob: float,
@@ -30,11 +73,11 @@ class Model(torch.nn.Module):
         if not isinstance(data, Data):
             raise TypeError('data must be an instance of Data')
 
-        if not isinstance(conv_activation, types.FunctionType):
-            raise TypeError('conv_activation must be a function')
+        if not callable(conv_activation):
+            raise TypeError('conv_activation must be callable')
 
-        if not isinstance(dec_activation, types.FunctionType):
-            raise TypeError('dec_activation must be a function')
+        if not callable(dec_activation):
+            raise TypeError('dec_activation must be callable')
 
         self.data = data
         self.layer_dimensions = list(layer_dimensions)
@@ -42,35 +85,90 @@ class Model(torch.nn.Module):
         self.conv_activation = conv_activation
         self.dec_activation = dec_activation
 
+        self.adj_matrices = None
         self.conv_weights = None
         self.dec_weights = None
         self.build()
 
     def build(self) -> None:
+        self.adj_matrices = torch.nn.ParameterDict()
+        for _, et in self.data.edge_types.items():
+            adj_matrices = [
+                norm_adj_mat_one_node_type(x) \
+                    if et.vertex_type_row == et.vertex_type_column \
+                    else norm_adj_mat_two_node_types(x) \
+                    for x in et.adjacency_matrices
+            ]
+            adj_matrices = _sparse_diag_cat(et.adjacency_matrices)
+            print('adj_matrices:', adj_matrices)
+            self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
+                torch.nn.Parameter(adj_matrices, requires_grad=False)
+
         self.conv_weights = torch.nn.ParameterDict()
         for i in range(len(self.layer_dimensions) - 1):
             in_dimension = self.layer_dimensions[i]
             out_dimension = self.layer_dimensions[i + 1]
 
             for _, et in self.data.edge_types.items():
-                weight = init_glorot(in_dimension, out_dimension)
-                self.conv_weights[et.vertex_type_row, et.vertex_type_column, i] = \
-                    torch.nn.Parameter(weight)
+                weights = [ init_glorot(in_dimension, out_dimension) \
+                    for _ in range(len(et.adjacency_matrices)) ]
+                weights = torch.cat(weights, dim=1)
+                self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
+                    torch.nn.Parameter(weights)
 
         self.dec_weights = torch.nn.ParameterDict()
         for _, et in self.data.edge_types.items():
             global_interaction, local_variation = \
                 et.decoder_factory(self.layer_dimensions[-1],
                     len(et.adjacency_matrices))
-            self.dec_weights[et.vertex_type_row, et.vertex_type_column] = \
-                torch.nn.ParameterList([
-                    torch.nn.Parameter(global_interaction),
-                    torch.nn.Parameter(local_variation)
-                ])
+            self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
+                torch.nn.Parameter(global_interaction)
+            for i in range(len(local_variation)):
+                self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
+                    torch.nn.Parameter(local_variation[i])
+
+    def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
+        List[torch.Tensor]:
 
-    def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]:
+        cur_layer_repr = in_layer_repr
+        next_layer_repr = [ None ] * len(self.data.vertex_types)
+
+        for i in range(len(self.layer_dimensions) - 1):
+            for _, et in self.data.edge_types.items():
+                vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
+                adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
+                conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]
+
+                num_relation_types = len(et.adjacency_matrices)
+                x = cur_layer_repr[vt_col]
+                if self.keep_prob != 1:
+                    x = dropout(x, self.keep_prob)
+
+                print('a, Layer:', i, 'x.shape:', x.shape)
+
+                x = _mm(x, conv_weights)
+                x = torch.split(x,
+                    x.shape[1] // num_relation_types,
+                    dim=1)
+                x = torch.cat(x)
+                x = _mm(adj_matrices, x)
+                x = x.view(num_relation_types,
+                    self.data.vertex_types[vt_row].count,
+                    self.layer_dimensions[i + 1])
+
+                print('b, Layer:', i, 'x.shape:', x.shape)
+
+                x = x.sum(dim=0)
+                x = torch.nn.functional.normalize(x, p=2, dim=1)
+                x = self.conv_activation(x)
+
+                next_layer_repr[vt_row] = x
+            cur_layer_repr = next_layer_repr
+        return next_layer_repr
+
+    def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
         edges = []
         cur_edges = batch.edges
         for _ in range(len(self.layer_dimensions) - 1):
diff --git a/src/triacontagon/util.py b/src/triacontagon/util.py
index eb31bbc..134b64c 100644
--- a/src/triacontagon/util.py
+++ b/src/triacontagon/util.py
@@ -192,39 +192,8 @@ def _cat(matrices: List[torch.Tensor]):
     return res
 
 
-def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
-    num_layers: int) -> List[List[EdgeType]]:
-
-    Q = [
-        ( batch.vertex_type_row, batch.edges[:, 0] ),
-        ( batch.vertex_type_column, batch.edges[:, 1] )
-    ]
-    print('Q:', Q)
-    res = []
-
-    for _ in range(num_layers):
-        R = []
-        required_rows = [ [] for _ in range(len(data.vertex_types)) ]
-
-        for vertex_type, vertices in Q:
-            for et in data.edge_types.values():
-                if et.vertex_type_row == vertex_type:
-                    required_rows[vertex_type].append(vertices)
-                    indices = et.total_connectivity.indices()
-                    mask = torch.zeros(et.total_connectivity.shape[0])
-                    mask[vertices] = 1
-                    mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
-                    R.append((et.vertex_type_column,
-                        indices[1, mask]))
-                else:
-                    pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
-
-        required_rows = [ torch.unique(torch.cat(x)) \
-            if len(x) > 0 \
-            else None \
-            for x in required_rows ]
-
-        res.append(required_rows)
-        Q = R
-
-    return res
+def _mm(a: torch.Tensor, b: torch.Tensor):
+    if a.is_sparse:
+        return torch.sparse.mm(a, b)
+    else:
+        return torch.mm(a, b)
diff --git a/tests/triacontagon/test_model.py b/tests/triacontagon/test_model.py
index b887713..c943d8e 100644
--- a/tests/triacontagon/test_model.py
+++ b/tests/triacontagon/test_model.py
@@ -1,11 +1,10 @@
-from triacontagon.model import _per_layer_required_rows, \
-    TrainingBatch
-from triacontagon.decode import dedicom_decoder
-from triacontagon.data import Data
 import torch
+from triacontagon.model import Model
+from triacontagon.data import Data
+from triacontagon.decode import dedicom_decoder
 
 
-def test_per_layer_required_rows_01():
+def test_model_convolve_01():
     d = Data()
     d.add_vertex_type('Gene', 4)
     d.add_vertex_type('Drug', 5)
@@ -15,14 +14,14 @@ def test_per_layer_required_rows_01():
         [0, 1, 1, 0],
         [0, 0, 1, 0],
         [0, 1, 0, 1]
-    ]).to_sparse() ], dedicom_decoder)
+    ], dtype=torch.float).to_sparse() ], dedicom_decoder)
 
     d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
         [0, 1, 0, 0, 1],
         [0, 0, 1, 0, 0],
         [1, 0, 0, 0, 1],
         [0, 0, 1, 1, 0]
-    ]).to_sparse() ], dedicom_decoder)
+    ], dtype=torch.float).to_sparse() ], dedicom_decoder)
 
     d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
         [1, 0, 0, 0, 0],
@@ -30,11 +29,20 @@ def test_per_layer_required_rows_01():
         [0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 0, 1]
-    ]).to_sparse() ], dedicom_decoder)
+    ], dtype=torch.float).to_sparse() ], dedicom_decoder)
+
+    model = Model(d, [9, 32, 64], keep_prob=1.0,
+        conv_activation = torch.sigmoid,
+        dec_activation = torch.sigmoid)
+
+    repr_1 = torch.eye(9)
+    repr_1[4:, 4:] = 0
+    repr_2 = torch.eye(9)
+    repr_2[:4, :4] = 0
 
-    batch = TrainingBatch(0, 1, 0, torch.tensor([
-        [0, 1]
-    ]))
+    in_layer_repr = [
+        repr_1[:4, :].to_sparse(),
+        repr_2[4:, :].to_sparse()
+    ]
 
-    res = _per_layer_required_rows(d, batch, 5)
-    print('res:', res)
+    _ = model.convolve(in_layer_repr)
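
A minimal sketch (not part of the patch) of the batching trick that the new Model.convolve() relies on: the per-relation weight matrices are concatenated column-wise so one matmul covers all relation types, the result is split and re-stacked row-wise, multiplied by a block-diagonal concatenation of the per-relation adjacency matrices, and the per-relation outputs are summed. The _sparse_diag_cat below is an illustrative stand-in for the helper imported from util.py, whose implementation does not appear in this patch.

import torch

def _mm(a, b):
    # sparse-aware matrix multiply, mirroring the helper added to util.py
    return torch.sparse.mm(a, b) if a.is_sparse else torch.mm(a, b)

def _sparse_diag_cat(matrices):
    # illustrative stand-in: block-diagonal concatenation of sparse matrices
    return torch.block_diag(*[ m.to_dense() for m in matrices ]).to_sparse()

num_relation_types, n_row, n_col, d_in, d_out = 2, 4, 5, 6, 3

adjs = [ (torch.rand(n_row, n_col) < 0.5).float().to_sparse()
         for _ in range(num_relation_types) ]
weights = [ torch.randn(d_in, d_out) for _ in range(num_relation_types) ]
x = torch.randn(n_col, d_in)

# batched path, as in convolve(): one dense matmul, one sparse matmul
h = torch.mm(x, torch.cat(weights, dim=1))       # (n_col, R * d_out)
h = torch.cat(torch.split(h, d_out, dim=1))      # (R * n_col, d_out)
h = _mm(_sparse_diag_cat(adjs), h)               # (R * n_row, d_out)
h = h.view(num_relation_types, n_row, d_out).sum(dim=0)

# reference path: one relation type at a time
ref = sum(torch.sparse.mm(a, torch.mm(x, w)) for a, w in zip(adjs, weights))

assert torch.allclose(h, ref, atol=1e-6)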