|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- from decagon_pytorch.decode.cartesian import DEDICOMDecoder, \
- DistMultDecoder, \
- BilinearDecoder, \
- InnerProductDecoder
- import torch
-
-
- def _common(decoder_class):
- decoder = decoder_class(input_dim=10, num_relation_types=1)
- inputs = torch.rand((20, 10))
- pred = decoder(inputs, inputs)
-
- assert isinstance(pred, list)
- assert len(pred) == 1
-
- assert isinstance(pred[0], torch.Tensor)
- assert pred[0].shape == (20, 20)
-
-
-
def test_dedicom_decoder():
    """DEDICOMDecoder must pass the shared checks in ``_common``."""
    decoder_class = DEDICOMDecoder
    _common(decoder_class)
-
-
def test_dist_mult_decoder():
    """DistMultDecoder must pass the shared checks in ``_common``."""
    decoder_class = DistMultDecoder
    _common(decoder_class)
-
-
def test_bilinear_decoder():
    """BilinearDecoder must pass the shared checks in ``_common``."""
    decoder_class = BilinearDecoder
    _common(decoder_class)
-
-
def test_inner_product_decoder():
    """InnerProductDecoder must pass the shared checks in ``_common``."""
    decoder_class = InnerProductDecoder
    _common(decoder_class)
-
-
def test_batch_matrix_multiplication():
    """Compare DEDICOMDecoder output with an explicit chain of matrix products.

    The expected result is rebuilt by hand from the decoder's own attributes
    as ``act(X @ diag(d) @ G @ diag(d) @ X^T)``, where ``d`` is
    ``decoder.local_variation[0]``, ``G`` is ``decoder.global_interaction``
    and ``act`` is ``decoder.activation``. It is then compared element-wise
    with the first entry of the decoder's output.
    """
    input_dim = 10
    inputs = torch.rand((20, 10))

    decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
    out_dec = decoder(inputs, inputs)

    relation = decoder.local_variation[0]
    global_interaction = decoder.global_interaction
    act = decoder.activation
    # Expand the per-relation vector into a diagonal matrix so plain
    # matrix products can be used for the reference computation.
    relation = torch.diag(relation)
    product1 = torch.mm(inputs, relation)
    product2 = torch.mm(product1, global_interaction)
    product3 = torch.mm(product2, relation)
    rec = torch.mm(product3, torch.transpose(inputs, 0, 1))
    rec = act(rec)

    # Printed so that both matrices show up in the captured output on failure.
    print('rec:', rec)
    print('out_dec:', out_dec)

    # Compare with the same absolute tolerance as the other tests in this
    # file: exact float equality breaks as soon as the decoder evaluates the
    # products in a different (but mathematically equivalent) order.
    assert torch.all(torch.abs(rec - out_dec[0]) < 0.000001)
-
-
def test_single_prediction_01():
    """Decoding node 0 against itself must match entry (0, 0) of the full decode."""
    input_dim = 10
    embeddings = torch.rand((20, input_dim))

    model = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
    full_scores = model(embeddings, embeddings)
    first_node = embeddings[0:1]  # slice keeps the leading dimension
    single_score = model(first_node, first_node)

    difference = torch.abs(full_scores[0][0, 0] - single_score[0][0, 0])
    assert difference < 0.000001
-
-
def test_single_prediction_02():
    """Decoding node 0 against node 1 must match entry (0, 1) of the full decode.

    Also checks that the single-pair output has shape ``(1, 1)``.
    """
    input_dim = 10
    embeddings = torch.rand((20, input_dim))

    model = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
    full_scores = model(embeddings, embeddings)
    # Slices keep the leading dimension, so each argument is a one-row matrix.
    row_node = embeddings[0:1]
    col_node = embeddings[1:2]
    pair_score = model(row_node, col_node)

    difference = torch.abs(full_scores[0][0, 1] - pair_score[0][0, 0])
    assert difference < 0.000001
    assert pair_score[0].shape == (1, 1)
-
-
def test_pairwise_prediction():
    """Row-by-row paired scores must equal the diagonal of the full decode.

    The decoder is applied to two independent random inputs. The reference
    value for pair ``i`` is computed by hand as
    ``act((X_row @ diag(d) @ G @ diag(d))[i] . X_col[i])`` using a batched
    matrix product, and compared with ``diag(dec_all[0])``.
    """
    n_nodes = 20
    input_dim = 10
    inputs_row = torch.rand((n_nodes, input_dim))
    inputs_col = torch.rand((n_nodes, input_dim))

    decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
    dec_all = decoder(inputs_row, inputs_col)

    local = torch.diag(decoder.local_variation[0])
    left = inputs_row.mm(local).mm(decoder.global_interaction).mm(local)

    # Treat each row as a (1, d) matrix and each column input as a (d, 1)
    # matrix, so bmm yields one scalar product per node pair.
    paired = torch.bmm(left.unsqueeze(1), inputs_col.unsqueeze(2))

    assert paired.shape == (n_nodes, 1, 1)

    expected = decoder.activation(paired.flatten())
    actual = torch.diag(dec_all[0])

    assert torch.all(torch.abs(expected - actual) < 0.000001)
|