|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- from icosagon.batch import PredictionsBatch, \
- FlatPredictions, \
- flatten_predictions, \
- BatchIndices, \
- gather_batch_indices
- from icosagon.declayer import Predictions, \
- RelationPredictions, \
- RelationFamilyPredictions
- from icosagon.trainprep import prepare_training, \
- TrainValTest
- from icosagon.data import Data
- import torch
- import pytest
-
-
def test_flat_predictions_01():
    """FlatPredictions keeps the tensors and part type it was built with."""
    expected_predictions = [0, 1, 0, 1]
    expected_truth = [1, 0, 1, 0]

    pred = FlatPredictions(torch.tensor(expected_predictions),
                           torch.tensor(expected_truth), 'train')

    assert torch.all(pred.predictions == torch.tensor(expected_predictions))
    assert torch.all(pred.truth == torch.tensor(expected_truth))
    assert pred.part_type == 'train'
-
-
def test_flatten_predictions_01():
    """Flattening the 'train' split concatenates the first two train tensors.

    The expected ``truth`` shows that entries taken from the first
    ``RelationPredictions`` argument are labelled 1 and entries taken from
    the second are labelled 0 (presumably positive vs. negative edges --
    confirm against ``RelationPredictions``).
    """
    def train_split(values):
        # Only the train part is populated; val/test are empty tensors.
        return TrainValTest(torch.tensor(values, dtype=torch.float32),
                            torch.zeros(0), torch.zeros(0))

    def empty_split():
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(
        train_split([1, 0, 1, 0, 1]),
        train_split([1, 0, 1, 0, 1]),
        empty_split(),
        empty_split(),
    )
    pred = Predictions([RelationFamilyPredictions([rel_pred])])

    flat = flatten_predictions(pred, part_type='train')

    expected_predictions = torch.tensor([1, 0, 1, 0, 1] * 2,
                                        dtype=torch.float32)
    assert torch.all(flat.predictions == expected_predictions)
    expected_truth = torch.tensor([1] * 5 + [0] * 5, dtype=torch.float32)
    assert torch.all(flat.truth == expected_truth)
    assert flat.part_type == 'train'
-
-
def test_flatten_predictions_02():
    """Flattening the 'val' split, which holds no data here, yields empty results."""
    def train_split(values):
        # Only the train part is populated; val/test are empty tensors.
        return TrainValTest(torch.tensor(values, dtype=torch.float32),
                            torch.zeros(0), torch.zeros(0))

    def empty_split():
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(
        train_split([1, 0, 1, 0, 1]),
        train_split([1, 0, 1, 0, 1]),
        empty_split(),
        empty_split(),
    )
    pred = Predictions([RelationFamilyPredictions([rel_pred])])

    flat = flatten_predictions(pred, part_type='val')

    assert len(flat.predictions) == 0
    assert len(flat.truth) == 0
    assert flat.part_type == 'val'
-
-
def test_flatten_predictions_03():
    """Flattening the 'test' split, which holds no data here, yields empty results."""
    def train_split(values):
        # Only the train part is populated; val/test are empty tensors.
        return TrainValTest(torch.tensor(values, dtype=torch.float32),
                            torch.zeros(0), torch.zeros(0))

    def empty_split():
        return TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))

    rel_pred = RelationPredictions(
        train_split([1, 0, 1, 0, 1]),
        train_split([1, 0, 1, 0, 1]),
        empty_split(),
        empty_split(),
    )
    pred = Predictions([RelationFamilyPredictions([rel_pred])])

    flat = flatten_predictions(pred, part_type='test')

    assert len(flat.predictions) == 0
    assert len(flat.truth) == 0
    assert flat.part_type == 'test'
-
-
def test_flatten_predictions_04():
    """flatten_predictions validates its arguments.

    A first argument that is not a ``Predictions`` object raises
    ``TypeError``; an unknown ``part_type`` raises ``ValueError``.
    """
    rel_pred = RelationPredictions(
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                     torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                     torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
    )
    fam_pred = RelationFamilyPredictions([rel_pred])
    pred = Predictions([fam_pred])

    # The calls are expected to raise, so their results are not bound to
    # names (the original assigned them to an unused local).
    with pytest.raises(TypeError):
        flatten_predictions(1, part_type='test')

    with pytest.raises(ValueError):
        flatten_predictions(pred, part_type='x')
-
-
def test_batch_indices_01():
    """BatchIndices exposes the index tensor and part type unchanged."""
    positions = list(range(5))
    indices = BatchIndices(torch.tensor(positions), 'train')

    assert torch.all(indices.indices == torch.tensor(positions))
    assert indices.part_type == 'train'
-
-
def test_gather_batch_indices_01():
    """gather_batch_indices selects the given positions from the flattened data.

    The flattened train predictions are [1, 0, 1, 0, 1, 1, 0, 1, 0, 1] and
    the truth is five 1s followed by five 0s (see test_flatten_predictions_01),
    so picking positions 0, 2, 4, 5, 7, 9 gives all-ones predictions and
    targets [1, 1, 1, 0, 0, 0].
    """
    rel_pred = RelationPredictions(
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                     torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                     torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
    )
    fam_pred = RelationFamilyPredictions([rel_pred])
    pred = Predictions([fam_pred])

    pred_flat = flatten_predictions(pred, part_type='train')

    indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train')

    # `inputs` rather than `input` to avoid shadowing the builtin.
    inputs, target = gather_batch_indices(pred_flat, indices)
    assert torch.all(
        inputs == torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32))
    assert torch.all(
        target == torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32))
-
-
def test_predictions_batch_01():
    """Iterate PredictionsBatch one entry at a time over a tiny training set.

    A single 5-node relation with 5 edges is prepared so that all edges end
    up in the 'train' split, with 5 negative train edges alongside them.
    With ``batch_size=1`` the batches are expected to visit all 10 flattened
    entries in flattened order: first the 5 labelled 1, then the 5
    labelled 0.
    """
    d = Data()
    d.add_node_type('Dummy', 5)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    fam.add_relation_type('Dummy Rel', torch.tensor([
        [0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0]
    ], dtype=torch.float32))

    # Presumably train/val/test ratios; the asserts below confirm that every
    # edge lands in 'train'.
    prep_d = prepare_training(d, TrainValTest(1., 0., 0.))

    assert len(prep_d.relation_families) == 1
    assert len(prep_d.relation_families[0].relation_types) == 1
    rel_type = prep_d.relation_families[0].relation_types[0]
    assert len(rel_type.edges_pos.train) == 5
    assert len(rel_type.edges_neg.train) == 5
    assert len(rel_type.edges_pos.val) == 0
    assert len(rel_type.edges_pos.test) == 0

    rel_pred = RelationPredictions(
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                     torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                     torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
    )
    fam_pred = RelationFamilyPredictions([rel_pred])
    pred = Predictions([fam_pred])

    pred_flat = flatten_predictions(pred, part_type='train')

    batch = PredictionsBatch(prep_d, part_type='train', batch_size=1)
    pairs = []
    for indices in batch:
        # `inputs` rather than `input` to avoid shadowing the builtin.
        inputs, target = gather_batch_indices(pred_flat, indices)
        assert len(inputs) == 1
        assert len(target) == 1
        # .item() converts the 0-d tensors to Python numbers so the list
        # comparison below is an exact value check rather than relying on
        # tensor truthiness inside tuple equality.
        pairs.append((inputs[0].item(), target[0].item()))

    assert len(pairs) == 10
    assert pairs == [(1, 1), (0, 1), (1, 1), (0, 1), (1, 1),
                     (1, 0), (0, 0), (1, 0), (0, 0), (1, 0)]
|