From 69dd9a49e2c46b6df68b4629ef0f2845261cc1dd Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Wed, 27 May 2020 18:50:52 +0200
Subject: [PATCH] Dummy run of DecagonLayer seems to work, that's something.

---
 src/decagon_pytorch/data.py         |  2 ++
 src/decagon_pytorch/layer.py        | 56 +++++++++++++++++++++++++----
 tests/decagon_pytorch/test_layer.py |  8 +++++
 3 files changed, 59 insertions(+), 7 deletions(-)

diff --git a/src/decagon_pytorch/data.py b/src/decagon_pytorch/data.py
index 381d15d..2418371 100644
--- a/src/decagon_pytorch/data.py
+++ b/src/decagon_pytorch/data.py
@@ -46,6 +46,8 @@ class Data(object):
         if node_type_row >= n or node_type_column >= n:
             raise ValueError('Node type index out of bounds, add node type first')
         key = (node_type_row, node_type_column)
+        if adjacency_matrix is not None and not adjacency_matrix.is_sparse:
+            adjacency_matrix = adjacency_matrix.to_sparse()
         self.relation_types[key].append(RelationType(name, node_type_row,
             node_type_column, adjacency_matrix))
         # _ = self.decoder_types[(node_type_row, node_type_column)]
diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py
index 6cd91c6..b1f5d70 100644
--- a/src/decagon_pytorch/layer.py
+++ b/src/decagon_pytorch/layer.py
@@ -30,9 +30,13 @@ from collections import defaultdict
 
 
 class Layer(torch.nn.Module):
-    def __init__(self, output_dim: Union[int, List[int]], **kwargs) -> None:
+    def __init__(self,
+        output_dim: Union[int, List[int]],
+        is_sparse: bool,
+        **kwargs) -> None:
         super().__init__(**kwargs)
         self.output_dim = output_dim
+        self.is_sparse = is_sparse
 
 
 class InputLayer(Layer):
@@ -42,7 +46,7 @@ class InputLayer(Layer):
         if not isinstance(output_dim, list):
             output_dim = [output_dim,] * len(data.node_types)
 
-        super().__init__(output_dim, **kwargs)
+        super().__init__(output_dim, is_sparse=False, **kwargs)
         self.data = data
         self.node_reps = None
         self.build()
@@ -67,6 +71,34 @@ class InputLayer(Layer):
         return s.strip()
 
 
+class OneHotInputLayer(Layer):
+    def __init__(self, data: Data, **kwargs) -> None:
+        output_dim = [ a.count for a in data.node_types ]
+        super().__init__(output_dim, is_sparse=True, **kwargs)
+        self.data = data
+        self.node_reps = None
+        self.build()
+
+    def build(self) -> None:
+        self.node_reps = []
+        for i, nt in enumerate(self.data.node_types):
+            reps = torch.eye(nt.count).to_sparse()
+            reps = torch.nn.Parameter(reps)
+            self.register_parameter('node_reps[%d]' % i, reps)
+            self.node_reps.append(reps)
+
+    def forward(self) -> List[torch.nn.Parameter]:
+        return self.node_reps
+
+    def __repr__(self) -> str:
+        s = ''
+        s += 'One-hot GNN input layer\n'
+        s += '  # of node types: %d\n' % len(self.data.node_types)
+        for nt in self.data.node_types:
+            s += '    - %s (%d)\n' % (nt.name, nt.count)
+        return s.strip()
+
+
 class DecagonLayer(Layer):
     def __init__(self,
         data: Data,
@@ -78,7 +110,7 @@ class DecagonLayer(Layer):
         **kwargs):
         if not isinstance(output_dim, list):
             output_dim = [ output_dim ] * len(data.node_types)
-        super().__init__(output_dim, **kwargs)
+        super().__init__(output_dim, is_sparse=False, **kwargs)
         self.data = data
         self.previous_layer = previous_layer
         self.input_dim = previous_layer.output_dim
@@ -98,6 +130,9 @@ class DecagonLayer(Layer):
                     self.keep_prob, self.rel_activation)
                 self.next_layer_repr[nt_row].append((conv, nt_col))
 
+                if nt_row == nt_col:
+                    continue
+
                 conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row],
                     self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
                     self.keep_prob, self.rel_activation)
@@ -105,9 +140,16 @@ class DecagonLayer(Layer):
                 self.next_layer_repr[nt_col].append((conv, nt_row))
 
     def __call__(self):
         prev_layer_repr = self.previous_layer()
-        next_layer_repr = self.next_layer_repr
+        next_layer_repr = [None] * len(self.data.node_types)
+        print('next_layer_repr:', next_layer_repr)
         for i in range(len(self.data.node_types)):
-            next_layer_repr[i] = map(lambda conv, neighbor_type: \
-                conv(prev_layer_repr[neighbor_type]), next_layer_repr[i])
-        next_layer_repr = list(map(sum, next_layer_repr))
+            next_layer_repr[i] = [
+                conv(prev_layer_repr[neighbor_type]) \
+                    for (conv, neighbor_type) in \
+                        self.next_layer_repr[i]
+            ]
+            next_layer_repr[i] = sum(next_layer_repr[i])
+
+        print('next_layer_repr:', next_layer_repr)
+        # next_layer_repr = list(map(sum, next_layer_repr))
         return next_layer_repr
diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py
index 873ae6b..4dee8b1 100644
--- a/tests/decagon_pytorch/test_layer.py
+++ b/tests/decagon_pytorch/test_layer.py
@@ -1,4 +1,5 @@
 from decagon_pytorch.layer import InputLayer, \
+    OneHotInputLayer, \
     DecagonLayer
 from decagon_pytorch.data import Data
 import torch
@@ -74,3 +75,10 @@ def test_decagon_layer_01():
     d = _some_data_with_interactions()
     in_layer = InputLayer(d)
     d_layer = DecagonLayer(d, in_layer, output_dim=32)
+
+
+def test_decagon_layer_02():
+    d = _some_data_with_interactions()
+    in_layer = OneHotInputLayer(d)
+    d_layer = DecagonLayer(d, in_layer, output_dim=32)
+    _ = d_layer() # dummy call
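
Notes (not part of the patch):

The data.py hunk silently coerces dense adjacency matrices to sparse COO tensors
before they are stored on the RelationType. A minimal, self-contained illustration
of that coercion in plain PyTorch, with no project code involved:

    import torch

    adj = torch.tensor([[0., 1.],
                        [1., 0.]])
    if not adj.is_sparse:        # same guard as in the patched method above
        adj = adj.to_sparse()    # dense -> sparse COO tensor
    print(adj.is_sparse)         # True
    print(adj.to_dense())        # round-trips back to the original matrix

This keeps the convolutions downstream working on sparse inputs regardless of how
the caller built the adjacency matrix.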
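The new test exercises the OneHotInputLayer -> DecagonLayer path only through the
_some_data_with_interactions() helper, which is not shown here. The sketch below
runs the same flow on a small hand-built graph; it assumes Data() takes no
arguments and exposes add_node_type(name, count) and
add_relation_type(name, node_type_row, node_type_column, adjacency_matrix) --
signatures the data.py hunk suggests but does not show in full -- so treat it as
a sketch rather than a drop-in test.

    import torch
    from decagon_pytorch.data import Data
    from decagon_pytorch.layer import OneHotInputLayer, DecagonLayer

    # Toy graph: two node types and one cross-type relation. After this patch,
    # the dense adjacency matrix passed here is converted to sparse inside
    # add_relation_type().
    d = Data()
    d.add_node_type('Gene', 100)   # assumed signature: (name, count)
    d.add_node_type('Drug', 50)
    adj = (torch.rand(100, 50) < 0.05).float()
    d.add_relation_type('Target', 0, 1, adj)

    in_layer = OneHotInputLayer(d)                      # sparse identity features per node type
    d_layer = DecagonLayer(d, in_layer, output_dim=32)  # as in test_decagon_layer_02
    reprs = d_layer()                                   # one representation per node type
    print([r.shape for r in reprs])

If the Data API differs, the same end-to-end flow is what test_decagon_layer_02
exercises via its helper.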