diff --git a/src/decagon_pytorch/convolve.py b/src/decagon_pytorch/convolve.py index 10475c6..9087f6d 100644 --- a/src/decagon_pytorch/convolve.py +++ b/src/decagon_pytorch/convolve.py @@ -195,9 +195,12 @@ class SparseDropoutGraphConvActivation(torch.nn.Module): activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, **kwargs) -> None: super().__init__(**kwargs) - self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix) + self.input_dim = input_dim + self.output_dim = output_dim + self.adjacency_matrix = adjacency_matrix self.keep_prob = keep_prob self.activation = activation + self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix) def forward(self, x: torch.Tensor) -> torch.Tensor: x = dropout_sparse(x, self.keep_prob) diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py index 16354ab..cfb5063 100644 --- a/tests/decagon_pytorch/test_layer.py +++ b/tests/decagon_pytorch/test_layer.py @@ -4,6 +4,7 @@ from decagon_pytorch.layer import InputLayer, \ from decagon_pytorch.data import Data import torch import pytest +from decagon_pytorch.convolve import SparseDropoutGraphConvActivation def _some_data(): @@ -121,4 +122,32 @@ def test_decagon_layer_02(): def test_decagon_layer_03(): - pass + d = _some_data_with_interactions() + in_layer = OneHotInputLayer(d) + d_layer = DecagonLayer(d, in_layer, output_dim=32) + assert d_layer.data == d + assert d_layer.previous_layer == in_layer + assert d_layer.input_dim == [ 1000, 100 ] + assert not d_layer.is_sparse + assert d_layer.keep_prob == 1. + assert d_layer.rel_activation(0.5) == 0.5 + x = torch.tensor([-1, 0, 0.5, 1]) + assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all() + assert len(d_layer.next_layer_repr) == 2 + assert len(d_layer.next_layer_repr[0]) == 2 + assert len(d_layer.next_layer_repr[1]) == 4 + assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation), + d_layer.next_layer_repr[0])) + assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation), + d_layer.next_layer_repr[1])) + assert all(map(lambda a: a[0].output_dim == 32, + d_layer.next_layer_repr[0])) + assert all(map(lambda a: a[0].output_dim == 32, + d_layer.next_layer_repr[1])) + + +def test_decagon_layer_04(): + d = _some_data_with_interactions() + in_layer = OneHotInputLayer(d) + d_layer = DecagonLayer(d, in_layer, output_dim=32) + _ = d_layer()