@@ -0,0 +1,106 @@
from decagon_pytorch.decode import DEDICOMDecoder, \
    DistMultDecoder, \
    BilinearDecoder, \
    InnerProductDecoder
import torch


def _common(decoder_class):
    # Shared check: with a single relation type, decoding 20 random
    # 10-dimensional embeddings against themselves should yield a list
    # containing one (20, 20) prediction tensor.
    decoder = decoder_class(input_dim=10, num_relation_types=1)
    inputs = torch.rand((20, 10))
    pred = decoder(inputs, inputs)

    assert isinstance(pred, list)
    assert len(pred) == 1

    assert isinstance(pred[0], torch.Tensor)
    assert pred[0].shape == (20, 20)


def test_dedicom_decoder():
    _common(DEDICOMDecoder)


def test_dist_mult_decoder():
    _common(DistMultDecoder)


def test_bilinear_decoder():
    _common(BilinearDecoder)


def test_inner_product_decoder():
    _common(InnerProductDecoder)


def test_batch_matrix_multiplication():
    input_dim = 10
    inputs = torch.rand((20, 10))

    decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
    out_dec = decoder(inputs, inputs)

    # Recompute the DEDICOM reconstruction by hand:
    # inputs @ diag(local_variation) @ global_interaction
    #        @ diag(local_variation) @ inputs^T, followed by the activation.
    relation = decoder.local_variation[0]
    global_interaction = decoder.global_interaction
    act = decoder.activation
    relation = torch.diag(relation)
    product1 = torch.mm(inputs, relation)
    product2 = torch.mm(product1, global_interaction)
    product3 = torch.mm(product2, relation)
    rec = torch.mm(product3, torch.transpose(inputs, 0, 1))
    rec = act(rec)

    print('rec:', rec)
    print('out_dec:', out_dec)

    assert torch.all(rec == out_dec[0])


def test_single_prediction_01():
    input_dim = 10
    inputs = torch.rand((20, 10))

    decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
    dec_all = decoder(inputs, inputs)
    dec_one = decoder(inputs[0:1], inputs[0:1])

    # Decoding node 0 against itself should agree with entry (0, 0)
    # of the full prediction matrix.
    assert torch.abs(dec_all[0][0, 0] - dec_one[0][0, 0]) < 0.000001


def test_single_prediction_02():
    input_dim = 10
    inputs = torch.rand((20, 10))

    decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
    dec_all = decoder(inputs, inputs)
    dec_one = decoder(inputs[0:1], inputs[1:2])

    # Decoding node 0 against node 1 should agree with entry (0, 1)
    # of the full prediction matrix and come back as a (1, 1) tensor.
    assert torch.abs(dec_all[0][0, 1] - dec_one[0][0, 0]) < 0.000001
    assert dec_one[0].shape == (1, 1)


def test_pairwise_prediction():
    n_nodes = 20
    input_dim = 10
    inputs_row = torch.rand((n_nodes, input_dim))
    inputs_col = torch.rand((n_nodes, input_dim))

    decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
    dec_all = decoder(inputs_row, inputs_col)

    # Recompute one score per (row, col) pair with a batched matrix product;
    # these pairwise scores should match the diagonal of the full matrix.
    relation = torch.diag(decoder.local_variation[0])
    global_interaction = decoder.global_interaction
    act = decoder.activation
    product1 = torch.mm(inputs_row, relation)
    product2 = torch.mm(product1, global_interaction)
    product3 = torch.mm(product2, relation)
    rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
        inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))

    assert rec.shape == (n_nodes, 1, 1)

    rec = torch.flatten(rec)
    rec = act(rec)

    assert torch.all(torch.abs(rec - torch.diag(dec_all[0])) < 0.000001)