diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py index 76f2721..6cd91c6 100644 --- a/src/decagon_pytorch/layer.py +++ b/src/decagon_pytorch/layer.py @@ -91,16 +91,17 @@ class DecagonLayer(Layer): def build(self): self.next_layer_repr = defaultdict(list) - for (nt_row, nt_col), rel in self.data.relation_types.items(): - conv = SparseDropoutGraphConvActivation(self.input_dim[nt_col], - self.output_dim[nt_row], rel.adjacency_matrix, - self.keep_prob, self.rel_activation) - self.next_layer_repr[nt_row].append((conv, nt_col)) - - conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row], - self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1), - self.keep_prob, self.rel_activation) - self.next_layer_repr[nt_col].append((conv, nt_row)) + for (nt_row, nt_col), relation_types in self.data.relation_types.items(): + for rel in relation_types: + conv = SparseDropoutGraphConvActivation(self.input_dim[nt_col], + self.output_dim[nt_row], rel.adjacency_matrix, + self.keep_prob, self.rel_activation) + self.next_layer_repr[nt_row].append((conv, nt_col)) + + conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row], + self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1), + self.keep_prob, self.rel_activation) + self.next_layer_repr[nt_col].append((conv, nt_row)) def __call__(self): prev_layer_repr = self.previous_layer() diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py index 1497fe8..873ae6b 100644 --- a/tests/decagon_pytorch/test_layer.py +++ b/tests/decagon_pytorch/test_layer.py @@ -70,8 +70,7 @@ def test_input_layer_03(): assert layer.node_reps[1].device == device -@pytest.mark.skip() def test_decagon_layer_01(): d = _some_data_with_interactions() in_layer = InputLayer(d) - d_layer = DecagonLayer(in_layer, output_dim=32) + d_layer = DecagonLayer(d, in_layer, output_dim=32)