|
- from icosagon.data import Data, \
- _equal
- from icosagon.model import Model
- from icosagon.trainprep import PreparedData, \
- PreparedRelationFamily, \
- PreparedRelationType, \
- TrainValTest, \
- norm_adj_mat_one_node_type, \
- prepare_training
- import torch
- from icosagon.input import OneHotInputLayer
- from icosagon.convlayer import DecagonLayer
- from icosagon.declayer import DecodeLayer
- import pytest
-
-
- def _is_identity_function(f):
- for x in range(-100, 101):
- if f(x) != x:
- return False
- return True
-
-
def test_model_01():
    """Model built from prepared data alone exposes the expected settings."""
    data = Data()
    data.add_node_type('Dummy', 10)
    family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
    family.add_relation_type('Dummy Rel', torch.rand(10, 10).round())

    prepared = prepare_training(data, TrainValTest(.8, .1, .1))
    model = Model(prepared)

    # Values expected when no keyword arguments are passed to Model.
    assert model.prep_d == prepared
    assert model.layer_dimensions == [32, 64]
    assert model.keep_prob == 1.
    assert _is_identity_function(model.rel_activation)
    assert model.layer_activation == torch.nn.functional.relu
    assert _is_identity_function(model.dec_activation)
    assert isinstance(model.seq, torch.nn.Sequential)
-
-
def test_model_02():
    """Model keeps the prepared data and builds input -> 2 conv -> decode."""
    data = Data()
    data.add_node_type('Dummy', 10)
    family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
    adj = torch.rand(10, 10).round().to_sparse()
    family.add_relation_type('Dummy Rel', adj)

    prepared = prepare_training(data, TrainValTest(1., 0., 0.))
    model = Model(prepared)

    # Structure of the prepared data as seen through the model.
    assert isinstance(model.prep_d, PreparedData)
    families = model.prep_d.relation_families
    assert isinstance(families, list)
    assert len(families) == 1
    assert isinstance(families[0], PreparedRelationFamily)
    rel_types = families[0].relation_types
    assert len(rel_types) == 1
    rel_type = rel_types[0]
    assert isinstance(rel_type, PreparedRelationType)
    assert rel_type.adjacency_matrix_backward is None
    assert torch.all(_equal(rel_type.adjacency_matrix,
                            norm_adj_mat_one_node_type(adj)))

    # Layer stack: indexed access so a too-short stack still raises.
    expected_layers = (OneHotInputLayer, DecagonLayer, DecagonLayer,
                       DecodeLayer)
    for idx, layer_cls in enumerate(expected_layers):
        assert isinstance(model.seq[idx], layer_cls)
    assert len(model.seq) == len(expected_layers)
-
-
def test_model_03():
    """One relation type: parameter-tensor count of each stage in model.seq."""
    data = Data()
    data.add_node_type('Dummy', 10)
    family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
    adj = torch.rand(10, 10).round().to_sparse()
    family.add_relation_type('Dummy Rel', adj)

    model = Model(prepare_training(data, TrainValTest(1., 0., 0.)))

    for stage, expected in enumerate([1, 1, 1, 2]):
        n_params = len(list(model.seq[stage].parameters()))
        assert n_params == expected
-
-
def test_model_04():
    """Two relation types in one family: per-stage parameter-tensor counts."""
    data = Data()
    data.add_node_type('Dummy', 10)
    family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
    adj = torch.rand(10, 10).round().to_sparse()
    family.add_relation_type('Dummy Rel 1', adj)
    family.add_relation_type('Dummy Rel 2', adj.clone())

    model = Model(prepare_training(data, TrainValTest(1., 0., 0.)))

    for stage, expected in enumerate([1, 2, 2, 3]):
        n_params = len(list(model.seq[stage].parameters()))
        assert n_params == expected
-
-
def test_model_05():
    """Two Dummy-Dummy families, two relation types each: per-stage counts."""
    data = Data()
    data.add_node_type('Dummy', 10)

    family_specs = (
        ('Dummy-Dummy', ('Dummy Rel 1', 'Dummy Rel 2')),
        ('Dummy-Dummy 2', ('Dummy Rel 2-1', 'Dummy Rel 2-2')),
    )
    for family_name, (first_rel, second_rel) in family_specs:
        family = data.add_relation_family(family_name, 0, 0, False)
        adj = torch.rand(10, 10).round().to_sparse()
        family.add_relation_type(first_rel, adj)
        family.add_relation_type(second_rel, adj.clone())

    model = Model(prepare_training(data, TrainValTest(1., 0., 0.)))

    for stage, expected in enumerate([1, 4, 4, 6]):
        n_params = len(list(model.seq[stage].parameters()))
        assert n_params == expected
-
-
def test_model_06():
    """Two node types ('Dummy' x10, 'Foobar' x20): per-stage counts.

    The graph has a Dummy-Foobar family and a Dummy-Dummy family,
    each with two relation types.
    """
    data = Data()
    data.add_node_type('Dummy', 10)
    data.add_node_type('Foobar', 20)

    cross_family = data.add_relation_family('Dummy-Foobar', 0, 1, True)
    cross_adj = torch.rand(10, 20).round().to_sparse()
    cross_family.add_relation_type('Dummy Rel 1', cross_adj)
    cross_family.add_relation_type('Dummy Rel 2', cross_adj.clone())

    self_family = data.add_relation_family('Dummy-Dummy 2', 0, 0, False)
    self_adj = torch.rand(10, 10).round().to_sparse()
    self_family.add_relation_type('Dummy Rel 2-1', self_adj)
    self_family.add_relation_type('Dummy Rel 2-2', self_adj.clone())

    model = Model(prepare_training(data, TrainValTest(1., 0., 0.)))

    for stage, expected in enumerate([2, 6, 6, 6]):
        n_params = len(list(model.seq[stage].parameters()))
        assert n_params == expected
-
-
def test_model_07():
    """Model rejects constructor arguments of the wrong type or value.

    Each case only needs the constructor to raise, so the result is not
    bound to a name. The original bound it to an unused ``m``.
    """
    d = Data()
    d.add_node_type('Dummy', 10)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())

    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))

    with pytest.raises(TypeError):
        Model(1)

    with pytest.raises(TypeError):
        Model(prep_d, layer_dimensions=1)

    # NOTE(review): if Model has no `ratios` parameter, this TypeError
    # comes from the unexpected keyword itself rather than from argument
    # validation. Confirm against Model's signature.
    with pytest.raises(TypeError):
        Model(prep_d, ratios=1)

    with pytest.raises(ValueError):
        Model(prep_d, keep_prob='x')

    with pytest.raises(TypeError):
        Model(prep_d, rel_activation='x')

    with pytest.raises(TypeError):
        Model(prep_d, layer_activation='x')

    with pytest.raises(TypeError):
        Model(prep_d, dec_activation='x')
-
-
def test_model_08():
    """Two node types and two families: total parameter-tensor count."""
    data = Data()
    data.add_node_type('Dummy', 10)
    data.add_node_type('Foobar', 20)

    cross_family = data.add_relation_family('Dummy-Foobar', 0, 1, True)
    cross_adj = torch.rand(10, 20).round().to_sparse()
    cross_family.add_relation_type('Dummy Rel 1', cross_adj)
    cross_family.add_relation_type('Dummy Rel 2', cross_adj.clone())

    self_family = data.add_relation_family('Dummy-Dummy 2', 0, 0, False)
    self_adj = torch.rand(10, 10).round().to_sparse()
    self_family.add_relation_type('Dummy Rel 2-1', self_adj)
    self_family.add_relation_type('Dummy Rel 2-2', self_adj.clone())

    model = Model(prepare_training(data, TrainValTest(1., 0., 0.)))

    total = sum(1 for _ in model.parameters())
    assert total == 20
|