diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py index 69ca74c..d198328 100644 --- a/src/decagon_pytorch/layer.py +++ b/src/decagon_pytorch/layer.py @@ -124,11 +124,14 @@ class DecagonLayer(Layer): self.next_layer_repr = defaultdict(list) for (nt_row, nt_col), relation_types in self.data.relation_types.items(): + row_convs = [] + col_convs = [] + for rel in relation_types: conv = DropoutGraphConvActivation(self.input_dim[nt_col], self.output_dim[nt_row], rel.adjacency_matrix, self.keep_prob, self.rel_activation) - self.next_layer_repr[nt_row].append((conv, nt_col)) + row_convs.append(conv) if nt_row == nt_col: continue @@ -136,21 +139,27 @@ class DecagonLayer(Layer): conv = DropoutGraphConvActivation(self.input_dim[nt_row], self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1), self.keep_prob, self.rel_activation) - self.next_layer_repr[nt_col].append((conv, nt_row)) + col_convs.append(conv) + + self.next_layer_repr[nt_row].append((row_convs, nt_col)) + + if nt_row == nt_col: + continue + + self.next_layer_repr[nt_col].append((col_convs, nt_row)) def __call__(self): prev_layer_repr = self.previous_layer() - next_layer_repr = [None] * len(self.data.node_types) + next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] print('next_layer_repr:', next_layer_repr) for i in range(len(self.data.node_types)): - next_layer_repr[i] = [ - conv(prev_layer_repr[neighbor_type]) \ - for (conv, neighbor_type) in \ - self.next_layer_repr[i] - ] + for convs, neighbor_type in self.next_layer_repr[i]: + convs = [ conv(prev_layer_repr[neighbor_type]) \ + for conv in convs ] + convs = sum(convs) + convs = torch.nn.functional.normalize(convs, p=2, dim=1) + next_layer_repr[i].append(convs) next_layer_repr[i] = sum(next_layer_repr[i]) - next_layer_repr[i] = torch.nn.functional.normalize(next_layer_repr[i], p=2, dim=1) - + next_layer_repr[i] = self.layer_activation(next_layer_repr[i]) print('next_layer_repr:', next_layer_repr) - # next_layer_repr = list(map(sum, next_layer_repr)) return next_layer_repr diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py index cc0e5b7..5e7d6b0 100644 --- a/tests/decagon_pytorch/test_layer.py +++ b/tests/decagon_pytorch/test_layer.py @@ -136,16 +136,21 @@ def test_decagon_layer_03(): x = torch.tensor([-1, 0, 0.5, 1]) assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all() assert len(d_layer.next_layer_repr) == 2 - assert len(d_layer.next_layer_repr[0]) == 2 - assert len(d_layer.next_layer_repr[1]) == 4 - assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation), - d_layer.next_layer_repr[0])) - assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation), - d_layer.next_layer_repr[1])) - assert all(map(lambda a: a[0].output_dim == 32, - d_layer.next_layer_repr[0])) - assert all(map(lambda a: a[0].output_dim == 32, - d_layer.next_layer_repr[1])) + + for i in range(2): + assert len(d_layer.next_layer_repr[i]) == 2 + assert isinstance(d_layer.next_layer_repr[i], list) + assert isinstance(d_layer.next_layer_repr[i][0], tuple) + assert isinstance(d_layer.next_layer_repr[i][0][0], list) + assert isinstance(d_layer.next_layer_repr[i][0][1], int) + assert all([ + isinstance(dgca, DropoutGraphConvActivation) \ + for dgca in d_layer.next_layer_repr[i][0][0] + ]) + assert all([ + dgca.output_dim == 32 \ + for dgca in d_layer.next_layer_repr[i][0][0] + ]) def test_decagon_layer_04(): @@ -166,10 +171,10 @@ def test_decagon_layer_04(): keep_prob=1., rel_activation=lambda x: x, layer_activation=lambda x: x) - assert isinstance(d_layer.next_layer_repr[0][0][0], + assert isinstance(d_layer.next_layer_repr[0][0][0][0], DropoutGraphConvActivation) - weight = d_layer.next_layer_repr[0][0][0].graph_conv.weight + weight = d_layer.next_layer_repr[0][0][0][0].graph_conv.weight assert isinstance(weight, torch.Tensor) assert len(multi_dgca.sparse_dgca) == 1