From 2cdc76fac7a467e3726f6f20d9f3a9b6808f1535 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 26 May 2020 21:17:26 +0200 Subject: [PATCH] Prepare to make Decagon layer work. --- src/decagon_pytorch/layer.py | 39 +++++++++++++++++------------ tests/decagon_pytorch/test_layer.py | 10 ++++---- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py index f21c182..d48de45 100644 --- a/src/decagon_pytorch/layer.py +++ b/src/decagon_pytorch/layer.py @@ -22,51 +22,58 @@ import torch from .convolve import SparseMultiDGCA +from .data import Data +from typing import List, Union -class InputLayer(torch.nn.Module): - def __init__(self, data, dimensionality=None, **kwargs): +class Layer(torch.nn.Module): + def __init__(self, output_dim: Union[int, List[int]], **kwargs) -> None: super().__init__(**kwargs) - self.data = data - dimensionality = dimensionality or \ + self.output_dim = output_dim + + +class InputLayer(Layer): + def __init__(self, data: Data, output_dim: Union[int, List[int]]= None, **kwargs) -> None: + output_dim = output_dim or \ list(map(lambda a: a.count, data.node_types)) - if not isinstance(dimensionality, list): - dimensionality = [dimensionality,] * len(self.data.node_types) - self.dimensionality = dimensionality + if not isinstance(output_dim, list): + output_dim = [output_dim,] * len(data.node_types) + + super().__init__(output_dim, **kwargs) + self.data = data self.node_reps = None self.build() - def build(self): + def build(self) -> None: self.node_reps = [] for i, nt in enumerate(self.data.node_types): - reps = torch.rand(nt.count, self.dimensionality[i]) + reps = torch.rand(nt.count, self.output_dim[i]) reps = torch.nn.Parameter(reps) self.register_parameter('node_reps[%d]' % i, reps) self.node_reps.append(reps) - def forward(self): + def forward(self) -> List[torch.nn.Parameter]: return self.node_reps - def __repr__(self): + def __repr__(self) -> str: s = '' - s += 'GNN input layer with dimensionality: %d\n' % self.dimensionality + s += 'GNN input layer with output_dim: %s\n' % self.output_dim s += ' # of node types: %d\n' % len(self.data.node_types) for nt in self.data.node_types: s += ' - %s (%d)\n' % (nt.name, nt.count) return s.strip() -class DecagonLayer(torch.nn.Module): - def __init__(self, data, +class DecagonLayer(Layer): + def __init__(self, data: Data, input_dim, output_dim, keep_prob=1., rel_activation=lambda x: x, layer_activation=torch.nn.functional.relu, **kwargs): - super().__init__(**kwargs) + super().__init__(output_dim, **kwargs) self.data = data self.input_dim = input_dim - self.output_dim = output_dim self.keep_prob = keep_prob self.rel_activation = rel_activation self.layer_activation = layer_activation diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py index c171e1c..1497fe8 100644 --- a/tests/decagon_pytorch/test_layer.py +++ b/tests/decagon_pytorch/test_layer.py @@ -36,12 +36,12 @@ def _some_data_with_interactions(): def test_input_layer_01(): d = _some_data() - for dimensionality in [32, 64, 128]: - layer = InputLayer(d, dimensionality) - assert layer.dimensionality[0] == dimensionality + for output_dim in [32, 64, 128]: + layer = InputLayer(d, output_dim) + assert layer.output_dim[0] == output_dim assert len(layer.node_reps) == 2 - assert layer.node_reps[0].shape == (1000, dimensionality) - assert layer.node_reps[1].shape == (100, dimensionality) + assert layer.node_reps[0].shape == (1000, output_dim) + assert layer.node_reps[1].shape == (100, output_dim) assert layer.data == d