| @@ -17,6 +17,7 @@ import pytest | |||
| import numpy as np | |||
| from itertools import chain | |||
| from icosagon.data import RelationType | |||
| import icosagon.trainprep | |||
| def test_train_val_test_split_edges_01(): | |||
| @@ -203,6 +204,57 @@ def test_prep_rel_two_node_types_asym_01(): | |||
| assert len(prep_rel.edges_back_neg.test) == 1 | |||
| def test_prepare_relation_type_01(): | |||
| with pytest.raises(ValueError): | |||
| prepare_relation_type(None, None, True) | |||
| adj_mat = torch.rand(10, 10).round() | |||
| rel = RelationType('Dummy Relation', 0, 0, adj_mat, None) | |||
| with pytest.raises(ValueError): | |||
| prepare_relation_type(rel, None, True) | |||
| ratios = TrainValTest(.8, .1, .1) | |||
| with pytest.raises(ValueError): | |||
| prepare_relation_type(None, ratios, True) | |||
| _ = prepare_relation_type(rel, ratios, True) | |||
| def test_prepare_relation_type_02(monkeypatch): | |||
| a = 0 | |||
| b = 0 | |||
| c = 0 | |||
| def fake_prep_rel_one_node_type(*args, **kwargs): | |||
| nonlocal a | |||
| a += 1 | |||
| def fake_prep_rel_two_node_types_sym(*args, **kwargs): | |||
| nonlocal b | |||
| b += 1 | |||
| def fake_prep_rel_two_node_types_asym(*args, **kwargs): | |||
| nonlocal c | |||
| c += 1 | |||
| monkeypatch.setattr(icosagon.trainprep, 'prep_rel_one_node_type', | |||
| fake_prep_rel_one_node_type) | |||
| monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_sym', | |||
| fake_prep_rel_two_node_types_sym) | |||
| monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_asym', | |||
| fake_prep_rel_two_node_types_asym) | |||
| ratios = TrainValTest(.8, .1, .1) | |||
| rel = RelationType('Dummy Relation', 0, 0, None, None) | |||
| prepare_relation_type(rel, ratios, False) | |||
| assert a == 1 | |||
| rel = RelationType('Dummy Relation', 0, 0, None, None) | |||
| prepare_relation_type(rel, ratios, True) | |||
| assert a == 2 | |||
| rel = RelationType('Dummy Relation', 0, 1, None, None) | |||
| prepare_relation_type(rel, ratios, True) | |||
| assert b == 1 | |||
| rel = RelationType('Dummy Relation', 0, 1, None, None) | |||
| prepare_relation_type(rel, ratios, False) | |||
| assert c == 1 | |||
| assert a == 2 and b == 1 and c == 1 | |||
| # def prepare_relation(r, ratios): | |||
| # adj_mat = r.adjacency_matrix | |||
| # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat) | |||