|
@@ -9,7 +9,8 @@ from icosagon.trainprep import TrainValTest, \ |
|
|
get_edges_and_degrees, \
|
|
|
get_edges_and_degrees, \
|
|
|
prepare_adj_mat, \
|
|
|
prepare_adj_mat, \
|
|
|
prepare_relation_type, \
|
|
|
prepare_relation_type, \
|
|
|
prep_rel_one_node_type
|
|
|
|
|
|
|
|
|
prep_rel_one_node_type, \
|
|
|
|
|
|
prep_rel_two_node_types_sym
|
|
|
import torch
|
|
|
import torch
|
|
|
import pytest
|
|
|
import pytest
|
|
|
import numpy as np
|
|
|
import numpy as np
|
|
@@ -139,6 +140,33 @@ def test_prep_rel_one_node_type_01(): |
|
|
assert len(prep_rel.edges_back_neg.test) == 0
|
|
|
assert len(prep_rel.edges_back_neg.test) == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_prep_rel_two_node_types_sym_01():
|
|
|
|
|
|
adj_mat = torch.zeros(200)
|
|
|
|
|
|
perm = torch.randperm(100)
|
|
|
|
|
|
adj_mat[perm[:10]] = 1
|
|
|
|
|
|
adj_mat = adj_mat.view(10, 20)
|
|
|
|
|
|
rel = RelationType('Dummy Relation', 0, 1, adj_mat, None)
|
|
|
|
|
|
ratios = TrainValTest(.8, .1, .1)
|
|
|
|
|
|
prep_rel = prep_rel_two_node_types_sym(rel, ratios)
|
|
|
|
|
|
assert prep_rel.name == rel.name
|
|
|
|
|
|
assert prep_rel.node_type_row == rel.node_type_row
|
|
|
|
|
|
assert prep_rel.node_type_column == rel.node_type_column
|
|
|
|
|
|
assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
|
|
|
|
|
|
assert prep_rel.adjacency_matrix_backward.shape == (20, 10)
|
|
|
|
|
|
assert len(prep_rel.edges_pos.train) == 8
|
|
|
|
|
|
assert len(prep_rel.edges_pos.val) == 1
|
|
|
|
|
|
assert len(prep_rel.edges_pos.test) == 1
|
|
|
|
|
|
assert len(prep_rel.edges_neg.train) == 8
|
|
|
|
|
|
assert len(prep_rel.edges_neg.val) == 1
|
|
|
|
|
|
assert len(prep_rel.edges_neg.test) == 1
|
|
|
|
|
|
|
|
|
|
|
|
assert len(prep_rel.edges_back_pos.train) == 0
|
|
|
|
|
|
assert len(prep_rel.edges_back_pos.val) == 0
|
|
|
|
|
|
assert len(prep_rel.edges_back_pos.test) == 0
|
|
|
|
|
|
assert len(prep_rel.edges_back_neg.train) == 0
|
|
|
|
|
|
assert len(prep_rel.edges_back_neg.val) == 0
|
|
|
|
|
|
assert len(prep_rel.edges_back_neg.test) == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def prepare_relation(r, ratios):
|
|
|
# def prepare_relation(r, ratios):
|
|
|
# adj_mat = r.adjacency_matrix
|
|
|
# adj_mat = r.adjacency_matrix
|
|
|