"""Tests for the batching helpers in ``icosagon.batch``.

Covers ``FlatPredictions``, ``flatten_predictions``, ``BatchIndices``,
``gather_batch_indices`` and ``PredictionsBatch``.
"""

import pytest
import torch

from icosagon.batch import (
    PredictionsBatch,
    FlatPredictions,
    flatten_predictions,
    BatchIndices,
    gather_batch_indices,
)
from icosagon.declayer import (
    Predictions,
    RelationPredictions,
    RelationFamilyPredictions,
)
from icosagon.trainprep import (
    prepare_training,
    TrainValTest,
)
from icosagon.data import Data


def _make_predictions(values=None):
    """Build a one-family, one-relation ``Predictions`` fixture.

    ``values`` is used as the 'train' part of both the first and the second
    ``TrainValTest`` passed to ``RelationPredictions``; every other part is an
    empty tensor. When ``values`` is None, a fresh float32 tensor
    ``[1, 0, 1, 0, 1]`` is used.
    """
    if values is None:
        values = torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32)
    empty = torch.zeros(0)
    rel_pred = RelationPredictions(
        TrainValTest(values, empty, empty),
        TrainValTest(values, empty, empty),
        TrainValTest(empty, empty, empty),
        TrainValTest(empty, empty, empty),
    )
    fam_pred = RelationFamilyPredictions([rel_pred])
    return Predictions([fam_pred])


def _assert_tensor_equal(actual, expected):
    """Assert ``actual`` has the same shape and elements as ``expected``.

    The shape check stops ``torch.all(actual == expected)`` from passing
    merely because the two tensors broadcast against each other.
    """
    assert actual.shape == expected.shape
    assert torch.all(actual == expected)


def test_flat_predictions_01():
    pred = FlatPredictions(torch.tensor([0, 1, 0, 1]),
                           torch.tensor([1, 0, 1, 0]), 'train')
    _assert_tensor_equal(pred.predictions, torch.tensor([0, 1, 0, 1]))
    _assert_tensor_equal(pred.truth, torch.tensor([1, 0, 1, 0]))
    assert pred.part_type == 'train'


def test_flatten_predictions_01():
    pred_flat = flatten_predictions(_make_predictions(), part_type='train')
    # The first TrainValTest's values come first with truth 1, followed by
    # the second TrainValTest's values with truth 0.
    _assert_tensor_equal(
        pred_flat.predictions,
        torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1], dtype=torch.float32))
    _assert_tensor_equal(
        pred_flat.truth,
        torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=torch.float32))
    assert pred_flat.part_type == 'train'


def test_flatten_predictions_02():
    # The fixture only has data in the 'train' part, so 'val' is empty.
    pred_flat = flatten_predictions(_make_predictions(), part_type='val')
    assert len(pred_flat.predictions) == 0
    assert len(pred_flat.truth) == 0
    assert pred_flat.part_type == 'val'


def test_flatten_predictions_03():
    # The fixture only has data in the 'train' part, so 'test' is empty.
    pred_flat = flatten_predictions(_make_predictions(), part_type='test')
    assert len(pred_flat.predictions) == 0
    assert len(pred_flat.truth) == 0
    assert pred_flat.part_type == 'test'


def test_flatten_predictions_04():
    pred = _make_predictions()
    # A non-Predictions argument is rejected with TypeError ...
    with pytest.raises(TypeError):
        flatten_predictions(1, part_type='test')
    # ... and an unknown part name with ValueError.
    with pytest.raises(ValueError):
        flatten_predictions(pred, part_type='x')


def test_flatten_predictions_05():
    x = torch.rand(5000)
    expected_predictions = torch.cat([x, x])
    expected_truth = torch.cat([torch.ones(5000), torch.zeros(5000)])
    pred = _make_predictions(x)
    # Flattening the same input repeatedly must give the same result each time.
    for _ in range(10):
        pred_flat = flatten_predictions(pred, part_type='train')
        _assert_tensor_equal(pred_flat.predictions, expected_predictions)
        _assert_tensor_equal(pred_flat.truth, expected_truth)
        assert pred_flat.part_type == 'train'


def test_batch_indices_01():
    indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train')
    _assert_tensor_equal(indices.indices, torch.tensor([0, 1, 2, 3, 4]))
    assert indices.part_type == 'train'


def test_gather_batch_indices_01():
    pred_flat = flatten_predictions(_make_predictions(), part_type='train')
    # Flattened predictions: [1, 0, 1, 0, 1, 1, 0, 1, 0, 1]
    # Flattened truth:       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
    indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train')
    (batch_input, batch_target) = gather_batch_indices(pred_flat, indices)
    _assert_tensor_equal(
        batch_input,
        torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32))
    _assert_tensor_equal(
        batch_target,
        torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32))


def test_predictions_batch_01():
    d = Data()
    d.add_node_type('Dummy', 5)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    fam.add_relation_type('Dummy Rel', torch.tensor([
        [0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0],
    ], dtype=torch.float32))

    prep_d = prepare_training(d, TrainValTest(1., 0., 0.))

    assert len(prep_d.relation_families) == 1
    assert len(prep_d.relation_families[0].relation_types) == 1
    rel_type = prep_d.relation_families[0].relation_types[0]
    # With a (1, 0, 0) split all 5 positive edges land in 'train',
    # alongside 5 negative edges.
    assert len(rel_type.edges_pos.train) == 5
    assert len(rel_type.edges_neg.train) == 5
    assert len(rel_type.edges_pos.val) == 0
    assert len(rel_type.edges_pos.test) == 0

    pred_flat = flatten_predictions(_make_predictions(), part_type='train')
    batch = PredictionsBatch(prep_d, part_type='train', batch_size=1)

    pairs = []
    for indices in batch:
        (batch_input, batch_target) = gather_batch_indices(pred_flat, indices)
        assert len(batch_input) == 1
        assert len(batch_target) == 1
        # .item() turns the 0-d tensors into Python numbers so the list
        # comparison below compares plain values rather than tensors.
        pairs.append((batch_input[0].item(), batch_target[0].item()))

    # The expected sequence is pred_flat in index order, so this also pins
    # PredictionsBatch to yield the 10 training indices in order.
    assert pairs == [
        (1, 1), (0, 1), (1, 1), (0, 1), (1, 1),
        (1, 0), (0, 0), (1, 0), (0, 0), (1, 0),
    ]
    assert len(pairs) == 10