IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

60 lines
1.8KB

  1. import decagon_pytorch.decode.cartesian as cart
  2. import decagon_pytorch.decode.pairwise as pair
  3. import torch
  4. def _common(cart_class, pair_class):
  5. input_dim = 10
  6. n_nodes = 20
  7. num_relation_types = 7
  8. inputs_row = torch.rand((n_nodes, input_dim))
  9. inputs_col = torch.rand((n_nodes, input_dim))
  10. cart_dec = cart_class(input_dim=input_dim,
  11. num_relation_types=num_relation_types)
  12. pair_dec = pair_class(input_dim=input_dim,
  13. num_relation_types=num_relation_types)
  14. if isinstance(cart_dec, cart.DEDICOMDecoder):
  15. pair_dec.global_interaction = cart_dec.global_interaction
  16. pair_dec.local_variation = cart_dec.local_variation
  17. elif isinstance(cart_dec, cart.InnerProductDecoder):
  18. pass
  19. else:
  20. pair_dec.relation = cart_dec.relation
  21. cart_pred = cart_dec(inputs_row, inputs_col)
  22. pair_pred = pair_dec(inputs_row, inputs_col)
  23. assert isinstance(cart_pred, list)
  24. assert isinstance(pair_pred, list)
  25. assert len(cart_pred) == len(pair_pred)
  26. assert len(cart_pred) == num_relation_types
  27. for i in range(num_relation_types):
  28. assert isinstance(cart_pred[i], torch.Tensor)
  29. assert isinstance(pair_pred[i], torch.Tensor)
  30. assert cart_pred[i].shape == (n_nodes, n_nodes)
  31. assert pair_pred[i].shape == (n_nodes,)
  32. assert torch.all(torch.abs(pair_pred[i] - torch.diag(cart_pred[i])) < 0.000001)
  33. def test_dedicom_decoder():
  34. _common(cart.DEDICOMDecoder, pair.DEDICOMDecoder)
  35. def test_dist_mult_decoder():
  36. _common(cart.DistMultDecoder, pair.DistMultDecoder)
  37. def test_bilinear_decoder():
  38. _common(cart.BilinearDecoder, pair.BilinearDecoder)
  39. def test_inner_product_decoder():
  40. _common(cart.InnerProductDecoder, pair.InnerProductDecoder)