From 2a2aecb3674ad9db0a1662dcf6beed3fb7e806c6 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Sat, 20 Jun 2020 17:50:13 +0200 Subject: [PATCH] Add test_prep_rel_one_node_type_01(). --- tests/icosagon/test_trainprep.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/icosagon/test_trainprep.py b/tests/icosagon/test_trainprep.py index b3a2d1f..2efcf45 100644 --- a/tests/icosagon/test_trainprep.py +++ b/tests/icosagon/test_trainprep.py @@ -8,7 +8,8 @@ from icosagon.trainprep import TrainValTest, \ train_val_test_split_edges, \ get_edges_and_degrees, \ prepare_adj_mat, \ - prepare_relation_type + prepare_relation_type, \ + prep_rel_one_node_type import torch import pytest import numpy as np @@ -110,6 +111,34 @@ def test_prepare_relation_type_01(): _ = prepare_relation_type(r, ratios, False) +def test_prep_rel_one_node_type_01(): + adj_mat = torch.zeros(100) + perm = torch.randperm(100) + adj_mat[perm[:10]] = 1 + adj_mat = adj_mat.view(10, 10) + rel = RelationType('Dummy Relation', 0, 0, adj_mat, None) + ratios = TrainValTest(.8, .1, .1) + prep_rel = prep_rel_one_node_type(rel, ratios) + assert prep_rel.name == rel.name + assert prep_rel.node_type_row == rel.node_type_row + assert prep_rel.node_type_column == rel.node_type_column + assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape + assert prep_rel.adjacency_matrix_backward is None + assert len(prep_rel.edges_pos.train) == 8 + assert len(prep_rel.edges_pos.val) == 1 + assert len(prep_rel.edges_pos.test) == 1 + assert len(prep_rel.edges_neg.train) == 8 + assert len(prep_rel.edges_neg.val) == 1 + assert len(prep_rel.edges_neg.test) == 1 + + assert len(prep_rel.edges_back_pos.train) == 0 + assert len(prep_rel.edges_back_pos.val) == 0 + assert len(prep_rel.edges_back_pos.test) == 0 + assert len(prep_rel.edges_back_neg.train) == 0 + assert len(prep_rel.edges_back_neg.val) == 0 + assert len(prep_rel.edges_back_neg.test) == 0 + + # def prepare_relation(r, ratios): # adj_mat = r.adjacency_matrix