diff --git a/tests/icosagon/test_fastconv.py b/tests/icosagon/test_fastconv.py index 742173d..2003316 100644 --- a/tests/icosagon/test_fastconv.py +++ b/tests/icosagon/test_fastconv.py @@ -18,7 +18,7 @@ def _make_symmetric(x: torch.Tensor): def _symmetric_random(n_rows, n_columns): return _make_symmetric(torch.rand((n_rows, n_columns), - dtype=torch.float32).round()) + dtype=torch.float32).round().to_sparse()) def _some_data_with_interactions(): @@ -28,7 +28,7 @@ def _some_data_with_interactions(): fam = d.add_relation_family('Drug-Gene', 1, 0, True) fam.add_relation_type('Target', - torch.rand((100, 1000), dtype=torch.float32).round()) + torch.rand((100, 1000), dtype=torch.float32).round().to_sparse()) fam = d.add_relation_family('Gene-Gene', 0, 0, True) fam.add_relation_type('Interaction', @@ -164,10 +164,47 @@ def test_fast_conv_layer_01(): d = _some_data_with_interactions() in_layer = OneHotInputLayer(d) + d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d) + seq_1 = torch.nn.Sequential(in_layer, d_layer) + _ = seq_1(None) + + conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d) + seq_2 = torch.nn.Sequential(in_layer, conv_layer) + _ = seq_2(None) + + +def test_fast_conv_layer_02(): + d = _some_data_with_interactions() + in_layer = OneHotInputLayer(d) + d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d) seq_1 = torch.nn.Sequential(in_layer, d_layer) out_repr_1 = seq_1(None) + assert len(d_layer.next_layer_repr[0]) == 2 + assert len(d_layer.next_layer_repr[1]) == 2 + conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d) + assert len(conv_layer.next_layer_repr[1]) == 2 + conv_layer.next_layer_repr[1][0].weights = torch.cat([ + d_layer.next_layer_repr[1][0].convolutions[0].graph_conv.weight, + ], dim=1) + conv_layer.next_layer_repr[1][1].weights = torch.cat([ + d_layer.next_layer_repr[1][1].convolutions[0].graph_conv.weight, + d_layer.next_layer_repr[1][1].convolutions[1].graph_conv.weight, + d_layer.next_layer_repr[1][1].convolutions[2].graph_conv.weight, + ], dim=1) + assert len(conv_layer.next_layer_repr[0]) == 2 + conv_layer.next_layer_repr[0][0].weights = torch.cat([ + d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight, + ], dim=1) + conv_layer.next_layer_repr[0][1].weights = torch.cat([ + d_layer.next_layer_repr[0][1].convolutions[0].graph_conv.weight, + ], dim=1) + seq_2 = torch.nn.Sequential(in_layer, conv_layer) out_repr_2 = seq_2(None) + + assert len(out_repr_1) == len(out_repr_2) + for i in range(len(out_repr_1)): + assert torch.all(out_repr_1[i] == out_repr_2[i])