diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py index 3929062..97c7c2a 100644 --- a/src/decagon_pytorch/layer.py +++ b/src/decagon_pytorch/layer.py @@ -21,6 +21,7 @@ import torch +from .convole import SparseMultiDGCA class InputLayer(torch.nn.Module): @@ -52,9 +53,48 @@ class InputLayer(torch.nn.Module): class DecagonLayer(torch.nn.Module): - def __init__(self, data, **kwargs): + def __init__(self, data, + input_dim, output_dim, + keep_prob=1., + rel_activation=lambda x: x, + layer_activation=torch.nn.functional.relu, + **kwargs): super().__init__(**kwargs) self.data = data + self.input_dim = input_dim + self.output_dim = output_dim + self.keep_prob = keep_prob + self.rel_activation = rel_activation + self.layer_activation = layer_activation + self.convolutions = None + self.build() + + def build(self): + self.convolutions = {} + for key in self.data.relation_types.keys(): + adjacency_matrices = \ + self.data.get_adjacency_matrices(*key) + self.convolutions[key] = SparseMultiDGCA(self.input_dim, + self.output_dim, adjacency_matrices, + self.keep_prob, self.rel_activation) - def __call__(self, previous_layer): - pass + # for node_type_row, node_type_col in enumerate(self.data.node_ + # if rt.node_type_row == i or rt.node_type_col == i: + + def __call__(self, prev_layer_repr): + new_layer_repr = [] + for i, nt in enumerate(self.data.node_types): + new_repr = [] + for key in self.data.relation_types.keys(): + nt_row, nt_col = key + if nt_row != i and nt_col != i: + continue + if nt_row == i: + x = prev_layer_repr[nt_col] + else: + x = prev_layer_repr[nt_row] + conv = self.convolutions[key] + new_repr.append(conv(x)) + new_repr = sum(new_repr) + new_layer_repr.append(new_repr) + return new_layer_repr