| @@ -18,7 +18,7 @@ def _make_symmetric(x: torch.Tensor): | |||||
| def _symmetric_random(n_rows, n_columns): | def _symmetric_random(n_rows, n_columns): | ||||
| return _make_symmetric(torch.rand((n_rows, n_columns), | return _make_symmetric(torch.rand((n_rows, n_columns), | ||||
| dtype=torch.float32).round()) | |||||
| dtype=torch.float32).round().to_sparse()) | |||||
| def _some_data_with_interactions(): | def _some_data_with_interactions(): | ||||
| @@ -28,7 +28,7 @@ def _some_data_with_interactions(): | |||||
| fam = d.add_relation_family('Drug-Gene', 1, 0, True) | fam = d.add_relation_family('Drug-Gene', 1, 0, True) | ||||
| fam.add_relation_type('Target', | fam.add_relation_type('Target', | ||||
| torch.rand((100, 1000), dtype=torch.float32).round()) | |||||
| torch.rand((100, 1000), dtype=torch.float32).round().to_sparse()) | |||||
| fam = d.add_relation_family('Gene-Gene', 0, 0, True) | fam = d.add_relation_family('Gene-Gene', 0, 0, True) | ||||
| fam.add_relation_type('Interaction', | fam.add_relation_type('Interaction', | ||||
| @@ -164,10 +164,47 @@ def test_fast_conv_layer_01(): | |||||
| d = _some_data_with_interactions() | d = _some_data_with_interactions() | ||||
| in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
| d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d) | |||||
| seq_1 = torch.nn.Sequential(in_layer, d_layer) | |||||
| _ = seq_1(None) | |||||
| conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d) | |||||
| seq_2 = torch.nn.Sequential(in_layer, conv_layer) | |||||
| _ = seq_2(None) | |||||
| def test_fast_conv_layer_02(): | |||||
| d = _some_data_with_interactions() | |||||
| in_layer = OneHotInputLayer(d) | |||||
| d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d) | d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d) | ||||
| seq_1 = torch.nn.Sequential(in_layer, d_layer) | seq_1 = torch.nn.Sequential(in_layer, d_layer) | ||||
| out_repr_1 = seq_1(None) | out_repr_1 = seq_1(None) | ||||
| assert len(d_layer.next_layer_repr[0]) == 2 | |||||
| assert len(d_layer.next_layer_repr[1]) == 2 | |||||
| conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d) | conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d) | ||||
| assert len(conv_layer.next_layer_repr[1]) == 2 | |||||
| conv_layer.next_layer_repr[1][0].weights = torch.cat([ | |||||
| d_layer.next_layer_repr[1][0].convolutions[0].graph_conv.weight, | |||||
| ], dim=1) | |||||
| conv_layer.next_layer_repr[1][1].weights = torch.cat([ | |||||
| d_layer.next_layer_repr[1][1].convolutions[0].graph_conv.weight, | |||||
| d_layer.next_layer_repr[1][1].convolutions[1].graph_conv.weight, | |||||
| d_layer.next_layer_repr[1][1].convolutions[2].graph_conv.weight, | |||||
| ], dim=1) | |||||
| assert len(conv_layer.next_layer_repr[0]) == 2 | |||||
| conv_layer.next_layer_repr[0][0].weights = torch.cat([ | |||||
| d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight, | |||||
| ], dim=1) | |||||
| conv_layer.next_layer_repr[0][1].weights = torch.cat([ | |||||
| d_layer.next_layer_repr[0][1].convolutions[0].graph_conv.weight, | |||||
| ], dim=1) | |||||
| seq_2 = torch.nn.Sequential(in_layer, conv_layer) | seq_2 = torch.nn.Sequential(in_layer, conv_layer) | ||||
| out_repr_2 = seq_2(None) | out_repr_2 = seq_2(None) | ||||
| assert len(out_repr_1) == len(out_repr_2) | |||||
| for i in range(len(out_repr_1)): | |||||
| assert torch.all(out_repr_1[i] == out_repr_2[i]) | |||||