|
|
@@ -18,7 +18,7 @@ def _make_symmetric(x: torch.Tensor): |
|
|
|
|
|
|
|
def _symmetric_random(n_rows, n_columns):
|
|
|
|
return _make_symmetric(torch.rand((n_rows, n_columns),
|
|
|
|
dtype=torch.float32).round())
|
|
|
|
dtype=torch.float32).round().to_sparse())
|
|
|
|
|
|
|
|
|
|
|
|
def _some_data_with_interactions():
|
|
|
@@ -28,7 +28,7 @@ def _some_data_with_interactions(): |
|
|
|
|
|
|
|
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
|
|
|
|
fam.add_relation_type('Target',
|
|
|
|
torch.rand((100, 1000), dtype=torch.float32).round())
|
|
|
|
torch.rand((100, 1000), dtype=torch.float32).round().to_sparse())
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Gene-Gene', 0, 0, True)
|
|
|
|
fam.add_relation_type('Interaction',
|
|
|
@@ -164,10 +164,47 @@ def test_fast_conv_layer_01(): |
|
|
|
d = _some_data_with_interactions()
|
|
|
|
in_layer = OneHotInputLayer(d)
|
|
|
|
|
|
|
|
d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
|
|
|
|
seq_1 = torch.nn.Sequential(in_layer, d_layer)
|
|
|
|
_ = seq_1(None)
|
|
|
|
|
|
|
|
conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
|
|
|
|
seq_2 = torch.nn.Sequential(in_layer, conv_layer)
|
|
|
|
_ = seq_2(None)
|
|
|
|
|
|
|
|
|
|
|
|
def test_fast_conv_layer_02():
|
|
|
|
d = _some_data_with_interactions()
|
|
|
|
in_layer = OneHotInputLayer(d)
|
|
|
|
|
|
|
|
d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
|
|
|
|
seq_1 = torch.nn.Sequential(in_layer, d_layer)
|
|
|
|
out_repr_1 = seq_1(None)
|
|
|
|
|
|
|
|
assert len(d_layer.next_layer_repr[0]) == 2
|
|
|
|
assert len(d_layer.next_layer_repr[1]) == 2
|
|
|
|
|
|
|
|
conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
|
|
|
|
assert len(conv_layer.next_layer_repr[1]) == 2
|
|
|
|
conv_layer.next_layer_repr[1][0].weights = torch.cat([
|
|
|
|
d_layer.next_layer_repr[1][0].convolutions[0].graph_conv.weight,
|
|
|
|
], dim=1)
|
|
|
|
conv_layer.next_layer_repr[1][1].weights = torch.cat([
|
|
|
|
d_layer.next_layer_repr[1][1].convolutions[0].graph_conv.weight,
|
|
|
|
d_layer.next_layer_repr[1][1].convolutions[1].graph_conv.weight,
|
|
|
|
d_layer.next_layer_repr[1][1].convolutions[2].graph_conv.weight,
|
|
|
|
], dim=1)
|
|
|
|
assert len(conv_layer.next_layer_repr[0]) == 2
|
|
|
|
conv_layer.next_layer_repr[0][0].weights = torch.cat([
|
|
|
|
d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight,
|
|
|
|
], dim=1)
|
|
|
|
conv_layer.next_layer_repr[0][1].weights = torch.cat([
|
|
|
|
d_layer.next_layer_repr[0][1].convolutions[0].graph_conv.weight,
|
|
|
|
], dim=1)
|
|
|
|
|
|
|
|
seq_2 = torch.nn.Sequential(in_layer, conv_layer)
|
|
|
|
out_repr_2 = seq_2(None)
|
|
|
|
|
|
|
|
assert len(out_repr_1) == len(out_repr_2)
|
|
|
|
for i in range(len(out_repr_1)):
|
|
|
|
assert torch.all(out_repr_1[i] == out_repr_2[i])
|