IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

68 lines
1.7KB

  1. from decagon_pytorch.data import AdjListData, \
  2. AdjListRelationType
  3. import torch
  4. import pytest
  5. def _get_list():
  6. lst = torch.tensor([
  7. [0, 1],
  8. [0, 3],
  9. [0, 5],
  10. [0, 7]
  11. ])
  12. return lst
  13. def test_adj_list_relation_type_01():
  14. lst = _get_list()
  15. rel = AdjListRelationType('Test', 0, 0, lst)
  16. assert torch.all(rel.get_adjacency_list(0, 0) == lst)
  17. def test_adj_list_relation_type_02():
  18. lst = _get_list()
  19. rel = AdjListRelationType('Test', 0, 1, lst)
  20. assert torch.all(rel.get_adjacency_list(0, 1) == lst)
  21. lst_2 = torch.tensor([
  22. [1, 0],
  23. [3, 0],
  24. [5, 0],
  25. [7, 0]
  26. ])
  27. assert torch.all(rel.get_adjacency_list(1, 0) == lst_2)
  28. def test_adj_list_relation_type_03():
  29. lst = _get_list()
  30. lst_2 = torch.tensor([
  31. [2, 0],
  32. [4, 0],
  33. [6, 0],
  34. [8, 0]
  35. ])
  36. rel = AdjListRelationType('Test', 0, 1, lst, lst_2)
  37. assert torch.all(rel.get_adjacency_list(0, 1) == lst)
  38. assert torch.all(rel.get_adjacency_list(1, 0) == lst_2)
  39. def test_adj_list_data_01():
  40. lst = _get_list()
  41. d = AdjListData()
  42. with pytest.raises(AssertionError):
  43. d.add_relation_type('Test', 0, 1, lst)
  44. d.add_node_type('Drugs', 5)
  45. with pytest.raises(AssertionError):
  46. d.add_relation_type('Test', 0, 0, lst)
  47. d = AdjListData()
  48. d.add_node_type('Drugs', 8)
  49. d.add_relation_type('Test', 0, 0, lst)
  50. def test_adj_list_data_02():
  51. lst = _get_list()
  52. d = AdjListData()
  53. d.add_node_type('Drugs', 10)
  54. d.add_node_type('Proteins', 10)
  55. d.add_relation_type('Target', 0, 1, lst)