| @@ -17,6 +17,7 @@ import pytest | |||||
| import numpy as np | import numpy as np | ||||
| from itertools import chain | from itertools import chain | ||||
| from icosagon.data import RelationType | from icosagon.data import RelationType | ||||
| import icosagon.trainprep | |||||
| def test_train_val_test_split_edges_01(): | def test_train_val_test_split_edges_01(): | ||||
| @@ -203,6 +204,57 @@ def test_prep_rel_two_node_types_asym_01(): | |||||
| assert len(prep_rel.edges_back_neg.test) == 1 | assert len(prep_rel.edges_back_neg.test) == 1 | ||||
| def test_prepare_relation_type_01(): | |||||
| with pytest.raises(ValueError): | |||||
| prepare_relation_type(None, None, True) | |||||
| adj_mat = torch.rand(10, 10).round() | |||||
| rel = RelationType('Dummy Relation', 0, 0, adj_mat, None) | |||||
| with pytest.raises(ValueError): | |||||
| prepare_relation_type(rel, None, True) | |||||
| ratios = TrainValTest(.8, .1, .1) | |||||
| with pytest.raises(ValueError): | |||||
| prepare_relation_type(None, ratios, True) | |||||
| _ = prepare_relation_type(rel, ratios, True) | |||||
| def test_prepare_relation_type_02(monkeypatch): | |||||
| a = 0 | |||||
| b = 0 | |||||
| c = 0 | |||||
| def fake_prep_rel_one_node_type(*args, **kwargs): | |||||
| nonlocal a | |||||
| a += 1 | |||||
| def fake_prep_rel_two_node_types_sym(*args, **kwargs): | |||||
| nonlocal b | |||||
| b += 1 | |||||
| def fake_prep_rel_two_node_types_asym(*args, **kwargs): | |||||
| nonlocal c | |||||
| c += 1 | |||||
| monkeypatch.setattr(icosagon.trainprep, 'prep_rel_one_node_type', | |||||
| fake_prep_rel_one_node_type) | |||||
| monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_sym', | |||||
| fake_prep_rel_two_node_types_sym) | |||||
| monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_asym', | |||||
| fake_prep_rel_two_node_types_asym) | |||||
| ratios = TrainValTest(.8, .1, .1) | |||||
| rel = RelationType('Dummy Relation', 0, 0, None, None) | |||||
| prepare_relation_type(rel, ratios, False) | |||||
| assert a == 1 | |||||
| rel = RelationType('Dummy Relation', 0, 0, None, None) | |||||
| prepare_relation_type(rel, ratios, True) | |||||
| assert a == 2 | |||||
| rel = RelationType('Dummy Relation', 0, 1, None, None) | |||||
| prepare_relation_type(rel, ratios, True) | |||||
| assert b == 1 | |||||
| rel = RelationType('Dummy Relation', 0, 1, None, None) | |||||
| prepare_relation_type(rel, ratios, False) | |||||
| assert c == 1 | |||||
| assert a == 2 and b == 1 and c == 1 | |||||
| # def prepare_relation(r, ratios): | # def prepare_relation(r, ratios): | ||||
| # adj_mat = r.adjacency_matrix | # adj_mat = r.adjacency_matrix | ||||
| # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat) | # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat) | ||||