IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Selaa lähdekoodia

Add test_fast_conv_layer_01() and test_fast_conv_layer_02().

master
Stanislaw Adaszewski 4 vuotta sitten
vanhempi
commit
b9b97f3dd7
1 muutettua tiedostoa jossa 39 lisäystä ja 2 poistoa
  1. +39
    -2
      tests/icosagon/test_fastconv.py

+ 39
- 2
tests/icosagon/test_fastconv.py Näytä tiedosto

@@ -18,7 +18,7 @@ def _make_symmetric(x: torch.Tensor):
def _symmetric_random(n_rows, n_columns):
return _make_symmetric(torch.rand((n_rows, n_columns),
dtype=torch.float32).round())
dtype=torch.float32).round().to_sparse())
def _some_data_with_interactions():
@@ -28,7 +28,7 @@ def _some_data_with_interactions():
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
fam.add_relation_type('Target',
torch.rand((100, 1000), dtype=torch.float32).round())
torch.rand((100, 1000), dtype=torch.float32).round().to_sparse())
fam = d.add_relation_family('Gene-Gene', 0, 0, True)
fam.add_relation_type('Interaction',
@@ -164,10 +164,47 @@ def test_fast_conv_layer_01():
d = _some_data_with_interactions()
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
seq_1 = torch.nn.Sequential(in_layer, d_layer)
_ = seq_1(None)
conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
seq_2 = torch.nn.Sequential(in_layer, conv_layer)
_ = seq_2(None)
def test_fast_conv_layer_02():
d = _some_data_with_interactions()
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
seq_1 = torch.nn.Sequential(in_layer, d_layer)
out_repr_1 = seq_1(None)
assert len(d_layer.next_layer_repr[0]) == 2
assert len(d_layer.next_layer_repr[1]) == 2
conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
assert len(conv_layer.next_layer_repr[1]) == 2
conv_layer.next_layer_repr[1][0].weights = torch.cat([
d_layer.next_layer_repr[1][0].convolutions[0].graph_conv.weight,
], dim=1)
conv_layer.next_layer_repr[1][1].weights = torch.cat([
d_layer.next_layer_repr[1][1].convolutions[0].graph_conv.weight,
d_layer.next_layer_repr[1][1].convolutions[1].graph_conv.weight,
d_layer.next_layer_repr[1][1].convolutions[2].graph_conv.weight,
], dim=1)
assert len(conv_layer.next_layer_repr[0]) == 2
conv_layer.next_layer_repr[0][0].weights = torch.cat([
d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight,
], dim=1)
conv_layer.next_layer_repr[0][1].weights = torch.cat([
d_layer.next_layer_repr[0][1].convolutions[0].graph_conv.weight,
], dim=1)
seq_2 = torch.nn.Sequential(in_layer, conv_layer)
out_repr_2 = seq_2(None)
assert len(out_repr_1) == len(out_repr_2)
for i in range(len(out_repr_1)):
assert torch.all(out_repr_1[i] == out_repr_2[i])

Loading…
Peruuta
Tallenna