|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- from decagon_pytorch.data import AdjListData, \
- AdjListRelationType
- import torch
- import pytest
-
-
- def _get_list():
- lst = torch.tensor([
- [0, 1],
- [0, 3],
- [0, 5],
- [0, 7]
- ])
- return lst
-
-
def test_adj_list_relation_type_01():
    """A relation built with (0, 0) returns its input list for lookup (0, 0)."""
    pairs = _get_list()
    relation = AdjListRelationType('Test', 0, 0, pairs)
    fetched = relation.get_adjacency_list(0, 0)
    assert (fetched == pairs).all()
-
-
def test_adj_list_relation_type_02():
    """Check lookups in both directions when only one list is supplied.

    The relation is built with (0, 1) and a single list.  The (0, 1)
    lookup must equal that list, and the (1, 0) lookup must equal the
    same pairs with their two columns swapped.
    """
    forward = _get_list()
    relation = AdjListRelationType('Test', 0, 1, forward)
    assert (relation.get_adjacency_list(0, 1) == forward).all()

    # Same pairs as `forward`, with the columns exchanged.
    swapped = torch.tensor([[row, 0] for row in (1, 3, 5, 7)])
    assert (relation.get_adjacency_list(1, 0) == swapped).all()
-
-
def test_adj_list_relation_type_03():
    """Check that an explicitly supplied second list is used for (1, 0).

    The relation is built with (0, 1) and two lists.  The (0, 1) lookup
    must equal the first list and the (1, 0) lookup must equal the second,
    which is not simply the first with its columns swapped.
    """
    forward = _get_list()
    backward = torch.tensor([[row, 0] for row in (2, 4, 6, 8)])
    relation = AdjListRelationType('Test', 0, 1, forward, backward)
    assert (relation.get_adjacency_list(0, 1) == forward).all()
    assert (relation.get_adjacency_list(1, 0) == backward).all()
-
-
def test_adj_list_data_01():
    """Check when ``add_relation_type`` raises ``AssertionError``.

    Three cases are exercised in order:

    1. No node type has been added: adding a (0, 1) relation must raise.
    2. 'Drugs' has been added with 5: adding a (0, 0) relation with the
       sample list (entries up to 7) must raise.
       NOTE(review): presumably 5 is a node count that the indices
       exceed -- confirm against AdjListData.
    3. A new container with 'Drugs' added with 8: the same (0, 0)
       relation must be accepted.
    """
    pairs = _get_list()

    data = AdjListData()
    with pytest.raises(AssertionError):
        data.add_relation_type('Test', 0, 1, pairs)

    data.add_node_type('Drugs', 5)
    with pytest.raises(AssertionError):
        data.add_relation_type('Test', 0, 0, pairs)

    fresh = AdjListData()
    fresh.add_node_type('Drugs', 8)
    fresh.add_relation_type('Test', 0, 0, pairs)
-
-
def test_adj_list_data_02():
    """Smoke test: a (0, 1) relation is accepted once two node types exist.

    'Drugs' and 'Proteins' are each added with 10, then a 'Target'
    relation is added with the sample list.  There are no assertions; the
    test passes as long as nothing raises.
    """
    data = AdjListData()
    for type_name in ('Drugs', 'Proteins'):
        data.add_node_type(type_name, 10)
    data.add_relation_type('Target', 0, 1, _get_list())
|