IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Quellcode durchsuchen

Add test_decagon_layer_05().

master
Stanislaw Adaszewski vor 4 Jahren
Ursprung
Commit
4ac42e5b3c
1 geänderte Dateien mit 50 neuen und 0 gelöschten Zeilen
  1. +50
    -0
      tests/decagon_pytorch/test_layer.py

+ 50
- 0
tests/decagon_pytorch/test_layer.py Datei anzeigen

@@ -189,3 +189,53 @@ def test_decagon_layer_04():
assert len(out_d_layer) == 1
assert torch.all(out_d_layer[0] == out_multi_dgca)
def test_decagon_layer_05():
# check if it is equivalent to MultiDGCA, as it should be
# this time for two relations, same edge type
d = Data()
d.add_node_type('Dummy', 100)
d.add_relation_type('Dummy Relation 1', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
d.add_relation_type('Dummy Relation 2', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
in_layer = OneHotInputLayer(d)
multi_dgca = SparseMultiDGCA([100, 100], 32,
[r.adjacency_matrix for r in d.relation_types[0, 0]],
keep_prob=1., activation=lambda x: x)
d_layer = DecagonLayer(d, in_layer, output_dim=32,
keep_prob=1., rel_activation=lambda x: x,
layer_activation=lambda x: x)
assert all([
isinstance(dgca, DropoutGraphConvActivation) \
for dgca in d_layer.next_layer_repr[0][0][0]
])
weight = [ dgca.graph_conv.weight \
for dgca in d_layer.next_layer_repr[0][0][0] ]
assert all([
isinstance(w, torch.Tensor) \
for w in weight
])
assert len(multi_dgca.sparse_dgca) == 2
for i in range(2):
assert isinstance(multi_dgca.sparse_dgca[i], SparseDropoutGraphConvActivation)
for i in range(2):
multi_dgca.sparse_dgca[i].sparse_graph_conv.weight = weight[i]
out_d_layer = d_layer()
x = in_layer()
out_multi_dgca = multi_dgca([ x[0], x[0] ])
assert isinstance(out_d_layer, list)
assert len(out_d_layer) == 1
assert torch.all(out_d_layer[0] == out_multi_dgca)

Laden…
Abbrechen
Speichern