|
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
- from icosagon.input import OneHotInputLayer
- from icosagon.convolve import DropoutGraphConvActivation
- from icosagon.convlayer import DecagonLayer
- from icosagon.declayer import DecodeLayer, \
- Predictions, \
- RelationFamilyPredictions, \
- RelationPredictions
- from icosagon.decode import DEDICOMDecoder, \
- InnerProductDecoder
- from icosagon.data import Data
- from icosagon.trainprep import prepare_training, \
- TrainValTest
- import torch
-
-
def test_decode_layer_01():
    """Check the structure of DecodeLayer output when it is called directly.

    A single node type with one relation family and one relation type is
    pushed through an input layer and one DecagonLayer; the resulting
    representations are then handed to a stand-alone DecodeLayer. The test
    only verifies the types and lengths of the nested prediction containers,
    not the predicted values.
    """
    d = Data()
    d.add_node_type('Dummy', 100)

    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    # Random 0/1 adjacency matrix stored as a sparse tensor.
    fam.add_relation_type('Dummy Relation 1',
        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())

    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(in_layer.output_dim, 32, d)
    seq = torch.nn.Sequential(in_layer, d_layer)
    # The input layer is called with None here, as in the other tests.
    last_layer_repr = seq(None)

    # keep_prob=1. and an identity activation keep the decoder pass-through.
    dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
        activation=lambda x: x)
    pred = dec(last_layer_repr)

    assert isinstance(pred, Predictions)

    assert isinstance(pred.relation_families, list)
    assert len(pred.relation_families) == 1
    fam_pred = pred.relation_families[0]
    assert isinstance(fam_pred, RelationFamilyPredictions)

    assert isinstance(fam_pred.relation_types, list)
    assert len(fam_pred.relation_types) == 1
    rel_pred = fam_pred.relation_types[0]
    assert isinstance(rel_pred, RelationPredictions)

    # Every edge category must carry its own train/val/test split.
    assert isinstance(rel_pred.edges_pos, TrainValTest)
    assert isinstance(rel_pred.edges_neg, TrainValTest)
    assert isinstance(rel_pred.edges_back_pos, TrainValTest)
    assert isinstance(rel_pred.edges_back_neg, TrainValTest)
-
-
def test_decode_layer_02():
    """Check DecodeLayer output when it is the last stage of a Sequential.

    Uses the same single-family, single-relation data as the first test, but
    chains input layer, DecagonLayer and DecodeLayer in one
    torch.nn.Sequential and verifies the shape of the prediction containers.
    """
    d = Data()
    d.add_node_type('Dummy', 100)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    # Random 0/1 adjacency matrix stored as a sparse tensor.
    fam.add_relation_type('Dummy Relation 1',
        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))

    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(in_layer.output_dim, 32, d)
    dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
        keep_prob=1., activation=lambda x: x)
    seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)

    pred = seq(None)

    assert isinstance(pred, Predictions)
    assert len(pred.relation_families) == 1
    fam_pred = pred.relation_families[0]
    assert isinstance(fam_pred, RelationFamilyPredictions)
    assert isinstance(fam_pred.relation_types, list)
    assert len(fam_pred.relation_types) == 1
    # Same element-type check the sibling tests perform on this container.
    assert isinstance(fam_pred.relation_types[0], RelationPredictions)
-
-
def test_decode_layer_03():
    """Check DecodeLayer output for a relation family joining two node types.

    Two node types of 100 nodes each are connected by one relation family
    (node type 0 to node type 1) holding a single relation type. The full
    input -> DecagonLayer -> DecodeLayer pipeline is run and the structure of
    the prediction containers is verified.
    """
    d = Data()
    d.add_node_type('Dummy 1', 100)
    d.add_node_type('Dummy 2', 100)
    fam = d.add_relation_family('Dummy 1-Dummy 2', 0, 1, True)
    # Random 0/1 adjacency matrix stored as a sparse tensor.
    fam.add_relation_type('Dummy Relation 1',
        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))

    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(in_layer.output_dim, 32, d)
    dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
        keep_prob=1., activation=lambda x: x)
    seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)

    pred = seq(None)
    assert isinstance(pred, Predictions)
    assert len(pred.relation_families) == 1
    fam_pred = pred.relation_families[0]
    assert isinstance(fam_pred, RelationFamilyPredictions)
    assert isinstance(fam_pred.relation_types, list)
    assert len(fam_pred.relation_types) == 1
    assert isinstance(fam_pred.relation_types[0], RelationPredictions)
-
-
def test_decode_layer_04():
    """Check that the pipeline copes with data that has no relation families.

    With only a node type registered, the input -> DecagonLayer ->
    DecodeLayer pipeline must still run and return a Predictions object
    whose list of relation families is empty.
    """
    d = Data()
    d.add_node_type('Dummy', 100)
    # Precondition: nothing but the node type has been registered.
    assert len(d.relation_families) == 0

    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))

    in_layer = OneHotInputLayer(d)
    d_layer = DecagonLayer(in_layer.output_dim, 32, d)
    dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
        keep_prob=1., activation=lambda x: x)
    seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)

    pred = seq(None)

    assert isinstance(pred, Predictions)
    assert len(pred.relation_families) == 0
-
-
def test_decode_layer_05():
    """Cross-check DecodeLayer predictions against a hand-computed pipeline.

    A 10-node graph with a symmetric random 0/1 adjacency matrix is run
    through input layer -> DecagonLayer -> DecodeLayer with identity
    activations and keep_prob=1. The test then:

    1. checks that only the train parts of ``edges_pos`` / ``edges_neg``
       are non-empty when all data is assigned to training;
    2. recomputes the convolution output by hand (weight multiplication,
       adjacency multiplication, L2 row normalisation) and compares it with
       the layer's output;
    3. checks that a fresh InnerProductDecoder applied to all node pairs
       matches the Gram matrix of the representations;
    4. checks that the DecodeLayer predictions on the training edges match
       that decoder evaluated on the same edges.
    """
    d = Data()
    d.add_node_type('Dummy', 10)
    # Average with the transpose before rounding so the 0/1 matrix is
    # symmetric.
    mat = torch.rand((10, 10))
    mat = (mat + mat.transpose(0, 1)) / 2
    mat = mat.round()
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, True,
        decoder_class=InnerProductDecoder)
    fam.add_relation_type('Dummy Rel', mat.to_sparse())
    # Everything goes to the training split; val and test stay empty.
    prep_d = prepare_training(d, TrainValTest(1., 0., 0.))

    in_layer = OneHotInputLayer(d)
    conv_layer = DecagonLayer(in_layer.output_dim, 32, prep_d,
        rel_activation=lambda x: x, layer_activation=lambda x: x)
    dec_layer = DecodeLayer(conv_layer.output_dim, prep_d,
        keep_prob=1., activation=lambda x: x)
    seq = torch.nn.Sequential(in_layer, conv_layer, dec_layer)

    pred = seq(None)
    rel_pred = pred.relation_families[0].relation_types[0]

    # --- 1. Shape of the predictions per edge category and split ----------
    for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']:
        edge_pred = getattr(rel_pred, edge_type)
        assert isinstance(edge_pred, TrainValTest)
        for part_type in ['train', 'val', 'test']:
            part_pred = getattr(edge_pred, part_type)
            assert isinstance(part_pred, torch.Tensor)
            assert len(part_pred.shape) == 1
            print(edge_type, part_type, part_pred.shape)
            if (edge_type, part_type) not in [('edges_pos', 'train'), ('edges_neg', 'train')]:
                assert part_pred.shape[0] == 0
            else:
                assert part_pred.shape[0] > 0

    prep_rel = prep_d.relation_families[0].relation_types[0]
    assert len(rel_pred.edges_pos.train) == len(prep_rel.edges_pos.train)
    assert len(rel_pred.edges_neg.train) == len(prep_rel.edges_neg.train)

    # One positive training edge per non-zero entry of the adjacency matrix.
    assert len(prep_rel.edges_pos.train) == torch.sum(mat)

    # --- 2. Re-derive the convolution output by hand ----------------------
    repr_in = in_layer(None)
    assert isinstance(repr_in, list)
    assert len(repr_in) == 1
    assert isinstance(repr_in[0], torch.Tensor)
    assert torch.all(repr_in[0].to_dense() == torch.eye(10))

    assert len(conv_layer.next_layer_repr[0]) == 1
    assert len(conv_layer.next_layer_repr[0][0].convolutions) == 1
    # The activations passed to the constructor must be stored unchanged,
    # i.e. still behave as the identity.
    assert conv_layer.rel_activation(0) == 0
    assert conv_layer.rel_activation(1) == 1
    assert conv_layer.rel_activation(-1) == -1
    assert conv_layer.layer_activation(0) == 0
    assert conv_layer.layer_activation(1) == 1
    assert conv_layer.layer_activation(-1) == -1

    graph_conv = conv_layer.next_layer_repr[0][0].convolutions[0]
    assert isinstance(graph_conv, DropoutGraphConvActivation)
    assert graph_conv.activation(0) == 0
    assert graph_conv.activation(1) == 1
    assert graph_conv.activation(-1) == -1
    weight = graph_conv.graph_conv.weight
    adj_mat = prep_rel.adjacency_matrix
    repr_conv = torch.sparse.mm(repr_in[0], weight)
    repr_conv = torch.mm(adj_mat, repr_conv)
    repr_conv = torch.nn.functional.normalize(repr_conv, p=2, dim=1)
    repr_conv_expect = conv_layer(repr_in)[0]
    print('repr_conv:\n', repr_conv)
    # Exact equality is intentional: the manual computation is expected to
    # replay the same operations the layer performs.
    assert torch.all(repr_conv == repr_conv_expect)
    assert repr_conv.shape[1] == 32

    # --- 3. Decoder on all node pairs vs. Gram matrix ---------------------
    dec = InnerProductDecoder(32, 1, keep_prob=1., activation=lambda x: x)
    # Relies on meshgrid's default matrix ('ij') indexing so that the
    # flattened pairs line up with a row-major (10, 10) view.
    rows, cols = torch.meshgrid(torch.arange(0, 10), torch.arange(0, 10))
    rows = rows.flatten()
    cols = cols.flatten()
    # Third argument is 0 throughout this test (presumably the relation
    # index within the family - confirm against InnerProductDecoder).
    repr_dec_expect = dec(repr_conv[rows], repr_conv[cols], 0)
    repr_dec_expect = repr_dec_expect.view(10, 10)

    repr_dec = torch.mm(repr_conv, torch.transpose(repr_conv, 0, 1))
    assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)

    # --- 4. DecodeLayer predictions vs. decoder on the training edges -----
    pos_edges = prep_rel.edges_pos.train
    neg_edges = prep_rel.edges_neg.train

    # Scatter the layer's predictions into a dense matrix; pairs that are
    # not training edges stay at zero.
    repr_dec_expect = torch.zeros((10, 10))
    repr_dec_expect[pos_edges[:, 0], pos_edges[:, 1]] = rel_pred.edges_pos.train
    repr_dec_expect[neg_edges[:, 0], neg_edges[:, 1]] = rel_pred.edges_neg.train
    print(repr_dec)
    print(repr_dec_expect)

    # Same scatter, but with scores computed directly by the decoder.
    repr_dec = torch.zeros((10, 10))
    repr_dec[pos_edges[:, 0], pos_edges[:, 1]] = \
        dec(repr_conv[pos_edges[:, 0]], repr_conv[pos_edges[:, 1]], 0)
    repr_dec[neg_edges[:, 0], neg_edges[:, 1]] = \
        dec(repr_conv[neg_edges[:, 0]], repr_conv[neg_edges[:, 1]], 0)

    assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
|