diff --git a/src/decagon_pytorch/__init__.py b/src/decagon_pytorch/__init__.py index f95e8e8..8271fcc 100644 --- a/src/decagon_pytorch/__init__.py +++ b/src/decagon_pytorch/__init__.py @@ -6,4 +6,4 @@ from .weights import * from .convolve import * from .model import * -from .layer import * +from .layer.decode import * diff --git a/tests/decagon_pytorch/test_decode_dims.py b/tests/decagon_pytorch/test_decode_dims.py new file mode 100644 index 0000000..03d73cc --- /dev/null +++ b/tests/decagon_pytorch/test_decode_dims.py @@ -0,0 +1,106 @@ +from decagon_pytorch.decode import DEDICOMDecoder, \ + DistMultDecoder, \ + BilinearDecoder, \ + InnerProductDecoder +import torch + + +def _common(decoder_class): + decoder = decoder_class(input_dim=10, num_relation_types=1) + inputs = torch.rand((20, 10)) + pred = decoder(inputs, inputs) + + assert isinstance(pred, list) + assert len(pred) == 1 + + assert isinstance(pred[0], torch.Tensor) + assert pred[0].shape == (20, 20) + + + +def test_dedicom_decoder(): + _common(DEDICOMDecoder) + + +def test_dist_mult_decoder(): + _common(DistMultDecoder) + + +def test_bilinear_decoder(): + _common(BilinearDecoder) + + +def test_inner_product_decoder(): + _common(InnerProductDecoder) + + +def test_batch_matrix_multiplication(): + input_dim = 10 + inputs = torch.rand((20, 10)) + + decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1) + out_dec = decoder(inputs, inputs) + + relation = decoder.local_variation[0] + global_interaction = decoder.global_interaction + act = decoder.activation + relation = torch.diag(relation) + product1 = torch.mm(inputs, relation) + product2 = torch.mm(product1, global_interaction) + product3 = torch.mm(product2, relation) + rec = torch.mm(product3, torch.transpose(inputs, 0, 1)) + rec = act(rec) + + print('rec:', rec) + print('out_dec:', out_dec) + + assert torch.all(rec == out_dec[0]) + + +def test_single_prediction_01(): + input_dim = 10 + inputs = torch.rand((20, 10)) + + decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1) + dec_all = decoder(inputs, inputs) + dec_one = decoder(inputs[0:1], inputs[0:1]) + + assert torch.abs(dec_all[0][0, 0] - dec_one[0][0, 0]) < 0.000001 + + +def test_single_prediction_02(): + input_dim = 10 + inputs = torch.rand((20, 10)) + + decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1) + dec_all = decoder(inputs, inputs) + dec_one = decoder(inputs[0:1], inputs[1:2]) + + assert torch.abs(dec_all[0][0, 1] - dec_one[0][0, 0]) < 0.000001 + assert dec_one[0].shape == (1, 1) + + +def test_pairwise_prediction(): + n_nodes = 20 + input_dim = 10 + inputs_row = torch.rand((n_nodes, input_dim)) + inputs_col = torch.rand((n_nodes, input_dim)) + + decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1) + dec_all = decoder(inputs_row, inputs_col) + + relation = torch.diag(decoder.local_variation[0]) + global_interaction = decoder.global_interaction + act = decoder.activation + product1 = torch.mm(inputs_row, relation) + product2 = torch.mm(product1, global_interaction) + product3 = torch.mm(product2, relation) + rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + + assert rec.shape == (n_nodes, 1, 1) + + rec = torch.flatten(rec) + rec = act(rec) + + assert torch.all(torch.abs(rec - torch.diag(dec_all[0])) < 0.000001)