"""Tests for ``icosagon.batch.PredictionsBatch``."""

from icosagon.batch import PredictionsBatch
from icosagon.declayer import Predictions, \
    RelationPredictions, \
    RelationFamilyPredictions
from icosagon.trainprep import prepare_training, \
    TrainValTest
from icosagon.data import Data
import torch


def test_predictions_batch_01():
    """Iterate a one-relation ``PredictionsBatch`` one item at a time.

    Builds a 5-node graph with a single relation type whose adjacency
    matrix has 5 nonzero entries, and prepares it with an all-train split.
    It then wraps hand-made prediction scores in ``Predictions`` and checks
    that ``PredictionsBatch(part_type='train', batch_size=1)`` yields ten
    single-element ``(input, target)`` batches in the expected order.
    """
    d = Data()
    d.add_node_type('Dummy', 5)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    fam.add_relation_type('Dummy Rel', torch.tensor([
        [0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0]
    ], dtype=torch.float32))

    # Ratios are (train, val, test): everything goes to the training part.
    prep_d = prepare_training(d, TrainValTest(1., 0., 0.))

    assert len(prep_d.relation_families) == 1
    assert len(prep_d.relation_families[0].relation_types) == 1
    rel = prep_d.relation_families[0].relation_types[0]
    assert len(rel.edges_pos.train) == 5
    assert len(rel.edges_neg.train) == 5
    assert len(rel.edges_pos.val) == 0
    assert len(rel.edges_pos.test) == 0

    # NOTE(review): the four positional arguments are presumably the
    # positive, negative, backward-positive and backward-negative
    # predictions -- confirm against icosagon.declayer.RelationPredictions.
    rel_pred = RelationPredictions(
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                     torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32),
                     torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
    )
    fam_pred = RelationFamilyPredictions([rel_pred])
    pred = Predictions([fam_pred])

    batch = PredictionsBatch(pred, part_type='train', batch_size=1)

    pairs = []
    for (inputs, targets) in batch:
        assert len(inputs) == 1
        assert len(targets) == 1
        # Convert to Python floats so the comparison below is a plain value
        # comparison rather than relying on element-wise tensor truthiness.
        pairs.append((float(inputs[0]), float(targets[0])))

    # The first five items carry target 1 and the last five carry target 0;
    # the inputs echo the [1, 0, 1, 0, 1] scores given above.
    assert pairs == [
        (1, 1), (0, 1), (1, 1), (0, 1), (1, 1),
        (1, 0), (0, 0), (1, 0), (0, 0), (1, 0)
    ]
    assert len(pairs) == 10