|
|
@@ -4,7 +4,9 @@ from decagon_pytorch.layer import InputLayer, \ |
|
|
|
from decagon_pytorch.data import Data
|
|
|
|
import torch
|
|
|
|
import pytest
|
|
|
|
from decagon_pytorch.convolve import SparseDropoutGraphConvActivation
|
|
|
|
from decagon_pytorch.convolve import SparseDropoutGraphConvActivation, \
|
|
|
|
SparseMultiDGCA, \
|
|
|
|
DropoutGraphConvActivation
|
|
|
|
|
|
|
|
|
|
|
|
def _some_data():
|
|
|
@@ -136,9 +138,9 @@ def test_decagon_layer_03(): |
|
|
|
assert len(d_layer.next_layer_repr) == 2
|
|
|
|
assert len(d_layer.next_layer_repr[0]) == 2
|
|
|
|
assert len(d_layer.next_layer_repr[1]) == 4
|
|
|
|
assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation),
|
|
|
|
assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation),
|
|
|
|
d_layer.next_layer_repr[0]))
|
|
|
|
assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation),
|
|
|
|
assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation),
|
|
|
|
d_layer.next_layer_repr[1]))
|
|
|
|
assert all(map(lambda a: a[0].output_dim == 32,
|
|
|
|
d_layer.next_layer_repr[0]))
|
|
|
@@ -147,7 +149,38 @@ def test_decagon_layer_03(): |
|
|
|
|
|
|
|
|
|
|
|
def test_decagon_layer_04():
|
|
|
|
d = _some_data_with_interactions()
|
|
|
|
# check if it is equivalent to MultiDGCA, as it should be
|
|
|
|
|
|
|
|
d = Data()
|
|
|
|
d.add_node_type('Dummy', 100)
|
|
|
|
d.add_relation_type('Dummy Relation', 0, 0,
|
|
|
|
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
|
|
|
|
|
|
|
|
in_layer = OneHotInputLayer(d)
|
|
|
|
d_layer = DecagonLayer(d, in_layer, output_dim=32)
|
|
|
|
_ = d_layer()
|
|
|
|
|
|
|
|
multi_dgca = SparseMultiDGCA([10], 32,
|
|
|
|
[r.adjacency_matrix for r in d.relation_types[0, 0]],
|
|
|
|
keep_prob=1., activation=lambda x: x)
|
|
|
|
|
|
|
|
d_layer = DecagonLayer(d, in_layer, output_dim=32,
|
|
|
|
keep_prob=1., rel_activation=lambda x: x,
|
|
|
|
layer_activation=lambda x: x)
|
|
|
|
|
|
|
|
assert isinstance(d_layer.next_layer_repr[0][0][0],
|
|
|
|
DropoutGraphConvActivation)
|
|
|
|
|
|
|
|
weight = d_layer.next_layer_repr[0][0][0].graph_conv.weight
|
|
|
|
assert isinstance(weight, torch.Tensor)
|
|
|
|
|
|
|
|
assert len(multi_dgca.sparse_dgca) == 1
|
|
|
|
assert isinstance(multi_dgca.sparse_dgca[0], SparseDropoutGraphConvActivation)
|
|
|
|
|
|
|
|
multi_dgca.sparse_dgca[0].sparse_graph_conv.weight = weight
|
|
|
|
|
|
|
|
out_d_layer = d_layer()
|
|
|
|
out_multi_dgca = multi_dgca(in_layer())
|
|
|
|
|
|
|
|
assert isinstance(out_d_layer, list)
|
|
|
|
assert len(out_d_layer) == 1
|
|
|
|
|
|
|
|
assert torch.all(out_d_layer[0] == out_multi_dgca)
|