|
|
@@ -17,6 +17,7 @@ import pytest |
|
|
|
import numpy as np
|
|
|
|
from itertools import chain
|
|
|
|
from icosagon.data import RelationType
|
|
|
|
import icosagon.trainprep
|
|
|
|
|
|
|
|
|
|
|
|
def test_train_val_test_split_edges_01():
|
|
|
@@ -203,6 +204,57 @@ def test_prep_rel_two_node_types_asym_01(): |
|
|
|
assert len(prep_rel.edges_back_neg.test) == 1
|
|
|
|
|
|
|
|
|
|
|
|
def test_prepare_relation_type_01():
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
prepare_relation_type(None, None, True)
|
|
|
|
|
|
|
|
adj_mat = torch.rand(10, 10).round()
|
|
|
|
rel = RelationType('Dummy Relation', 0, 0, adj_mat, None)
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
prepare_relation_type(rel, None, True)
|
|
|
|
|
|
|
|
ratios = TrainValTest(.8, .1, .1)
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
prepare_relation_type(None, ratios, True)
|
|
|
|
|
|
|
|
_ = prepare_relation_type(rel, ratios, True)
|
|
|
|
|
|
|
|
|
|
|
|
def test_prepare_relation_type_02(monkeypatch):
|
|
|
|
a = 0
|
|
|
|
b = 0
|
|
|
|
c = 0
|
|
|
|
def fake_prep_rel_one_node_type(*args, **kwargs):
|
|
|
|
nonlocal a
|
|
|
|
a += 1
|
|
|
|
def fake_prep_rel_two_node_types_sym(*args, **kwargs):
|
|
|
|
nonlocal b
|
|
|
|
b += 1
|
|
|
|
def fake_prep_rel_two_node_types_asym(*args, **kwargs):
|
|
|
|
nonlocal c
|
|
|
|
c += 1
|
|
|
|
monkeypatch.setattr(icosagon.trainprep, 'prep_rel_one_node_type',
|
|
|
|
fake_prep_rel_one_node_type)
|
|
|
|
monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_sym',
|
|
|
|
fake_prep_rel_two_node_types_sym)
|
|
|
|
monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_asym',
|
|
|
|
fake_prep_rel_two_node_types_asym)
|
|
|
|
ratios = TrainValTest(.8, .1, .1)
|
|
|
|
rel = RelationType('Dummy Relation', 0, 0, None, None)
|
|
|
|
prepare_relation_type(rel, ratios, False)
|
|
|
|
assert a == 1
|
|
|
|
rel = RelationType('Dummy Relation', 0, 0, None, None)
|
|
|
|
prepare_relation_type(rel, ratios, True)
|
|
|
|
assert a == 2
|
|
|
|
rel = RelationType('Dummy Relation', 0, 1, None, None)
|
|
|
|
prepare_relation_type(rel, ratios, True)
|
|
|
|
assert b == 1
|
|
|
|
rel = RelationType('Dummy Relation', 0, 1, None, None)
|
|
|
|
prepare_relation_type(rel, ratios, False)
|
|
|
|
assert c == 1
|
|
|
|
assert a == 2 and b == 1 and c == 1
|
|
|
|
|
|
|
|
|
|
|
|
# def prepare_relation(r, ratios):
|
|
|
|
# adj_mat = r.adjacency_matrix
|
|
|
|
# adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)
|
|
|
|