From 8c3e94963777a72d94773b8cd914155841e68342 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Tue, 26 May 2020 21:36:18 +0200
Subject: [PATCH] Baby steps.

---
 src/decagon_pytorch/data.py  | 12 ++++++++++++
 src/decagon_pytorch/layer.py | 21 ++++++++++++++-------
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/src/decagon_pytorch/data.py b/src/decagon_pytorch/data.py
index 0891c8b..852433f 100644
--- a/src/decagon_pytorch/data.py
+++ b/src/decagon_pytorch/data.py
@@ -17,6 +17,18 @@ class RelationType(object):
         self.node_type_column = node_type_column
         self.adjacency_matrix = adjacency_matrix
 
+    def get_adjacency_matrix(self, node_type_row, node_type_column):
+        if self.node_type_row == node_type_row and \
+            self.node_type_column == node_type_column:
+            return self.adjacency_matrix
+
+        elif self.node_type_row == node_type_column and \
+            self.node_type_column == node_type_row:
+            return self.adjacency_matrix.transpose(0, 1)
+
+        else:
+            raise ValueError('Specified row/column types do not correspond to this relation')
+
 
 class Data(object):
     def __init__(self):
diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py
index d48de45..bb95dae 100644
--- a/src/decagon_pytorch/layer.py
+++ b/src/decagon_pytorch/layer.py
@@ -23,7 +23,9 @@ import torch
 
 from .convolve import SparseMultiDGCA
 from .data import Data
-from typing import List, Union
+from typing import List, \
+    Union, \
+    Callable
 
 
 class Layer(torch.nn.Module):
@@ -65,15 +67,20 @@ class InputLayer(Layer):
 
 
 class DecagonLayer(Layer):
-    def __init__(self, data: Data,
-        input_dim, output_dim,
-        keep_prob=1.,
-        rel_activation=lambda x: x,
-        layer_activation=torch.nn.functional.relu,
+    def __init__(self,
+        data: Data,
+        previous_layer: Layer,
+        output_dim: Union[int, List[int]],
+        keep_prob: float = 1.,
+        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
         **kwargs):
+        if not isinstance(output_dim, list):
+            output_dim = [ output_dim ] * len(data.node_types)
         super().__init__(output_dim, **kwargs)
         self.data = data
-        self.input_dim = input_dim
+        self.previous_layer = previous_layer
+        self.input_dim = previous_layer.output_dim
         self.keep_prob = keep_prob
         self.rel_activation = rel_activation
         self.layer_activation = layer_activation
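
A minimal usage sketch of the new RelationType.get_adjacency_matrix() follows. It is not
part of the patch; it assumes RelationType's constructor takes (name, node_type_row,
node_type_column, adjacency_matrix) in that order, which is not visible in the hunk above
and is therefore an assumption.

    import torch
    from decagon_pytorch.data import RelationType

    # Toy 2x3 adjacency matrix relating node type 0 (rows) to node type 1 (columns).
    adj = torch.tensor([[0., 1., 0.],
                        [1., 0., 1.]])

    # Constructor argument order is assumed (see note above).
    rel = RelationType('gene-drug', 0, 1, adj)

    # Same orientation as stored: the matrix is returned as-is.
    assert rel.get_adjacency_matrix(0, 1).shape == (2, 3)

    # Opposite orientation: the matrix is returned transposed.
    assert rel.get_adjacency_matrix(1, 0).shape == (3, 2)

    # Any other pair of node types raises ValueError.

With the DecagonLayer change, the input dimensionality is no longer passed explicitly; it
is read from previous_layer.output_dim, so stacking becomes, roughly,
DecagonLayer(data, input_layer, output_dim=64), where input_layer is an InputLayer built
elsewhere (its constructor is not shown in this patch).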