diff --git a/tests/icosagon/test_trainprep.py b/tests/icosagon/test_trainprep.py index eb78859..73ec624 100644 --- a/tests/icosagon/test_trainprep.py +++ b/tests/icosagon/test_trainprep.py @@ -17,6 +17,7 @@ import pytest import numpy as np from itertools import chain from icosagon.data import RelationType +import icosagon.trainprep def test_train_val_test_split_edges_01(): @@ -203,6 +204,57 @@ def test_prep_rel_two_node_types_asym_01(): assert len(prep_rel.edges_back_neg.test) == 1 +def test_prepare_relation_type_01(): + with pytest.raises(ValueError): + prepare_relation_type(None, None, True) + + adj_mat = torch.rand(10, 10).round() + rel = RelationType('Dummy Relation', 0, 0, adj_mat, None) + with pytest.raises(ValueError): + prepare_relation_type(rel, None, True) + + ratios = TrainValTest(.8, .1, .1) + with pytest.raises(ValueError): + prepare_relation_type(None, ratios, True) + + _ = prepare_relation_type(rel, ratios, True) + + +def test_prepare_relation_type_02(monkeypatch): + a = 0 + b = 0 + c = 0 + def fake_prep_rel_one_node_type(*args, **kwargs): + nonlocal a + a += 1 + def fake_prep_rel_two_node_types_sym(*args, **kwargs): + nonlocal b + b += 1 + def fake_prep_rel_two_node_types_asym(*args, **kwargs): + nonlocal c + c += 1 + monkeypatch.setattr(icosagon.trainprep, 'prep_rel_one_node_type', + fake_prep_rel_one_node_type) + monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_sym', + fake_prep_rel_two_node_types_sym) + monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_asym', + fake_prep_rel_two_node_types_asym) + ratios = TrainValTest(.8, .1, .1) + rel = RelationType('Dummy Relation', 0, 0, None, None) + prepare_relation_type(rel, ratios, False) + assert a == 1 + rel = RelationType('Dummy Relation', 0, 0, None, None) + prepare_relation_type(rel, ratios, True) + assert a == 2 + rel = RelationType('Dummy Relation', 0, 1, None, None) + prepare_relation_type(rel, ratios, True) + assert b == 1 + rel = RelationType('Dummy Relation', 0, 1, None, None) + prepare_relation_type(rel, ratios, False) + assert c == 1 + assert a == 2 and b == 1 and c == 1 + + # def prepare_relation(r, ratios): # adj_mat = r.adjacency_matrix # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)