IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

107 lines
3.0KB

  1. from decagon_pytorch.decode.cartesian import DEDICOMDecoder, \
  2. DistMultDecoder, \
  3. BilinearDecoder, \
  4. InnerProductDecoder
  5. import torch
  6. def _common(decoder_class):
  7. decoder = decoder_class(input_dim=10, num_relation_types=1)
  8. inputs = torch.rand((20, 10))
  9. pred = decoder(inputs, inputs)
  10. assert isinstance(pred, list)
  11. assert len(pred) == 1
  12. assert isinstance(pred[0], torch.Tensor)
  13. assert pred[0].shape == (20, 20)
  14. def test_dedicom_decoder():
  15. _common(DEDICOMDecoder)
  16. def test_dist_mult_decoder():
  17. _common(DistMultDecoder)
  18. def test_bilinear_decoder():
  19. _common(BilinearDecoder)
  20. def test_inner_product_decoder():
  21. _common(InnerProductDecoder)
  22. def test_batch_matrix_multiplication():
  23. input_dim = 10
  24. inputs = torch.rand((20, 10))
  25. decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
  26. out_dec = decoder(inputs, inputs)
  27. relation = decoder.local_variation[0]
  28. global_interaction = decoder.global_interaction
  29. act = decoder.activation
  30. relation = torch.diag(relation)
  31. product1 = torch.mm(inputs, relation)
  32. product2 = torch.mm(product1, global_interaction)
  33. product3 = torch.mm(product2, relation)
  34. rec = torch.mm(product3, torch.transpose(inputs, 0, 1))
  35. rec = act(rec)
  36. print('rec:', rec)
  37. print('out_dec:', out_dec)
  38. assert torch.all(rec == out_dec[0])
  39. def test_single_prediction_01():
  40. input_dim = 10
  41. inputs = torch.rand((20, 10))
  42. decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
  43. dec_all = decoder(inputs, inputs)
  44. dec_one = decoder(inputs[0:1], inputs[0:1])
  45. assert torch.abs(dec_all[0][0, 0] - dec_one[0][0, 0]) < 0.000001
  46. def test_single_prediction_02():
  47. input_dim = 10
  48. inputs = torch.rand((20, 10))
  49. decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
  50. dec_all = decoder(inputs, inputs)
  51. dec_one = decoder(inputs[0:1], inputs[1:2])
  52. assert torch.abs(dec_all[0][0, 1] - dec_one[0][0, 0]) < 0.000001
  53. assert dec_one[0].shape == (1, 1)
  54. def test_pairwise_prediction():
  55. n_nodes = 20
  56. input_dim = 10
  57. inputs_row = torch.rand((n_nodes, input_dim))
  58. inputs_col = torch.rand((n_nodes, input_dim))
  59. decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
  60. dec_all = decoder(inputs_row, inputs_col)
  61. relation = torch.diag(decoder.local_variation[0])
  62. global_interaction = decoder.global_interaction
  63. act = decoder.activation
  64. product1 = torch.mm(inputs_row, relation)
  65. product2 = torch.mm(product1, global_interaction)
  66. product3 = torch.mm(product2, relation)
  67. rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
  68. inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
  69. assert rec.shape == (n_nodes, 1, 1)
  70. rec = torch.flatten(rec)
  71. rec = act(rec)
  72. assert torch.all(torch.abs(rec - torch.diag(dec_all[0])) < 0.000001)