| @@ -18,7 +18,7 @@ def _make_symmetric(x: torch.Tensor): | |||
| def _symmetric_random(n_rows, n_columns): | |||
| return _make_symmetric(torch.rand((n_rows, n_columns), | |||
| dtype=torch.float32).round()) | |||
| dtype=torch.float32).round().to_sparse()) | |||
| def _some_data_with_interactions(): | |||
| @@ -28,7 +28,7 @@ def _some_data_with_interactions(): | |||
| fam = d.add_relation_family('Drug-Gene', 1, 0, True) | |||
| fam.add_relation_type('Target', | |||
| torch.rand((100, 1000), dtype=torch.float32).round()) | |||
| torch.rand((100, 1000), dtype=torch.float32).round().to_sparse()) | |||
| fam = d.add_relation_family('Gene-Gene', 0, 0, True) | |||
| fam.add_relation_type('Interaction', | |||
| @@ -164,10 +164,47 @@ def test_fast_conv_layer_01(): | |||
| d = _some_data_with_interactions() | |||
| in_layer = OneHotInputLayer(d) | |||
| d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d) | |||
| seq_1 = torch.nn.Sequential(in_layer, d_layer) | |||
| _ = seq_1(None) | |||
| conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d) | |||
| seq_2 = torch.nn.Sequential(in_layer, conv_layer) | |||
| _ = seq_2(None) | |||
| def test_fast_conv_layer_02(): | |||
| d = _some_data_with_interactions() | |||
| in_layer = OneHotInputLayer(d) | |||
| d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d) | |||
| seq_1 = torch.nn.Sequential(in_layer, d_layer) | |||
| out_repr_1 = seq_1(None) | |||
| assert len(d_layer.next_layer_repr[0]) == 2 | |||
| assert len(d_layer.next_layer_repr[1]) == 2 | |||
| conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d) | |||
| assert len(conv_layer.next_layer_repr[1]) == 2 | |||
| conv_layer.next_layer_repr[1][0].weights = torch.cat([ | |||
| d_layer.next_layer_repr[1][0].convolutions[0].graph_conv.weight, | |||
| ], dim=1) | |||
| conv_layer.next_layer_repr[1][1].weights = torch.cat([ | |||
| d_layer.next_layer_repr[1][1].convolutions[0].graph_conv.weight, | |||
| d_layer.next_layer_repr[1][1].convolutions[1].graph_conv.weight, | |||
| d_layer.next_layer_repr[1][1].convolutions[2].graph_conv.weight, | |||
| ], dim=1) | |||
| assert len(conv_layer.next_layer_repr[0]) == 2 | |||
| conv_layer.next_layer_repr[0][0].weights = torch.cat([ | |||
| d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight, | |||
| ], dim=1) | |||
| conv_layer.next_layer_repr[0][1].weights = torch.cat([ | |||
| d_layer.next_layer_repr[0][1].convolutions[0].graph_conv.weight, | |||
| ], dim=1) | |||
| seq_2 = torch.nn.Sequential(in_layer, conv_layer) | |||
| out_repr_2 = seq_2(None) | |||
| assert len(out_repr_1) == len(out_repr_2) | |||
| for i in range(len(out_repr_1)): | |||
| assert torch.all(out_repr_1[i] == out_repr_2[i]) | |||