| @@ -9,7 +9,8 @@ from icosagon.trainprep import TrainValTest, \ | |||||
| get_edges_and_degrees, \ | get_edges_and_degrees, \ | ||||
| prepare_adj_mat, \ | prepare_adj_mat, \ | ||||
| prepare_relation_type, \ | prepare_relation_type, \ | ||||
| prep_rel_one_node_type | |||||
| prep_rel_one_node_type, \ | |||||
| prep_rel_two_node_types_sym | |||||
| import torch | import torch | ||||
| import pytest | import pytest | ||||
| import numpy as np | import numpy as np | ||||
| @@ -139,6 +140,33 @@ def test_prep_rel_one_node_type_01(): | |||||
| assert len(prep_rel.edges_back_neg.test) == 0 | assert len(prep_rel.edges_back_neg.test) == 0 | ||||
| def test_prep_rel_two_node_types_sym_01(): | |||||
| adj_mat = torch.zeros(200) | |||||
| perm = torch.randperm(100) | |||||
| adj_mat[perm[:10]] = 1 | |||||
| adj_mat = adj_mat.view(10, 20) | |||||
| rel = RelationType('Dummy Relation', 0, 1, adj_mat, None) | |||||
| ratios = TrainValTest(.8, .1, .1) | |||||
| prep_rel = prep_rel_two_node_types_sym(rel, ratios) | |||||
| assert prep_rel.name == rel.name | |||||
| assert prep_rel.node_type_row == rel.node_type_row | |||||
| assert prep_rel.node_type_column == rel.node_type_column | |||||
| assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape | |||||
| assert prep_rel.adjacency_matrix_backward.shape == (20, 10) | |||||
| assert len(prep_rel.edges_pos.train) == 8 | |||||
| assert len(prep_rel.edges_pos.val) == 1 | |||||
| assert len(prep_rel.edges_pos.test) == 1 | |||||
| assert len(prep_rel.edges_neg.train) == 8 | |||||
| assert len(prep_rel.edges_neg.val) == 1 | |||||
| assert len(prep_rel.edges_neg.test) == 1 | |||||
| assert len(prep_rel.edges_back_pos.train) == 0 | |||||
| assert len(prep_rel.edges_back_pos.val) == 0 | |||||
| assert len(prep_rel.edges_back_pos.test) == 0 | |||||
| assert len(prep_rel.edges_back_neg.train) == 0 | |||||
| assert len(prep_rel.edges_back_neg.val) == 0 | |||||
| assert len(prep_rel.edges_back_neg.test) == 0 | |||||
| # def prepare_relation(r, ratios): | # def prepare_relation(r, ratios): | ||||
| # adj_mat = r.adjacency_matrix | # adj_mat = r.adjacency_matrix | ||||