From a9f14d14a80888fe2d0b644f3f52dcc188132bed Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 16 Jun 2020 15:11:25 +0200 Subject: [PATCH] Add PredictionsBatch and test_predictions_btch_01(). --- src/icosagon/batch.py | 39 ++++++++++++++++++++++++++++++ tests/icosagon/test_batch.py | 46 ++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 src/icosagon/batch.py create mode 100644 tests/icosagon/test_batch.py diff --git a/src/icosagon/batch.py b/src/icosagon/batch.py new file mode 100644 index 0000000..275cac3 --- /dev/null +++ b/src/icosagon/batch.py @@ -0,0 +1,39 @@ +from icosagon.declayer import Predictions +import torch + + +class PredictionsBatch(object): + def __init__(self, pred: Predictions, part_type: str = 'train', + batch_size: int = 100) -> None: + + if not isinstance(pred, Predictions): + raise TypeError('pred must be an instance of Predictions') + + if part_type not in ['train', 'val', 'test']: + raise ValueError('part_type must be set to train, val or test') + + batch_size = int(batch_size) + + self.predictions = pred + self.part_type = part_type + self.batch_size = batch_size + + def __iter__(self): + edge_types = [('edges_pos', 1), ('edges_neg', 0), + ('edges_back_pos', 1), ('edges_back_neg', 0)] + + input = [] + target = [] + + for fam in self.predictions.relation_families: + for rel in fam.relation_types: + for (et, tgt) in edge_types: + edge_pred = getattr(getattr(rel, et), self.part_type) + input.append(edge_pred) + target.append(torch.ones_like(edge_pred) * tgt) + + input = torch.cat(input) + target = torch.cat(target) + + for i in range(0, len(input), self.batch_size): + yield (input[i:i+self.batch_size], target[i:i+self.batch_size]) diff --git a/tests/icosagon/test_batch.py b/tests/icosagon/test_batch.py new file mode 100644 index 0000000..fa6a4a9 --- /dev/null +++ b/tests/icosagon/test_batch.py @@ -0,0 +1,46 @@ +from icosagon.batch import PredictionsBatch +from icosagon.declayer import Predictions, \ + RelationPredictions, \ + RelationFamilyPredictions +from icosagon.trainprep import prepare_training, \ + TrainValTest +from icosagon.data import Data +import torch + + +def test_predictions_batch_01(): + d = Data() + d.add_node_type('Dummy', 5) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Rel', torch.tensor([ + [0, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0] + ], dtype=torch.float32)) + + prep_d = prepare_training(d, TrainValTest(1., 0., 0.)) + + assert len(prep_d.relation_families) == 1 + assert len(prep_d.relation_families[0].relation_types) == 1 + assert len(prep_d.relation_families[0].relation_types[0].edges_pos.train) == 5 + assert len(prep_d.relation_families[0].relation_types[0].edges_neg.train) == 5 + assert len(prep_d.relation_families[0].relation_types[0].edges_pos.val) == 0 + assert len(prep_d.relation_families[0].relation_types[0].edges_pos.test) == 0 + + rel_pred = RelationPredictions( + TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)), + TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)) + ) + fam_pred = RelationFamilyPredictions([ rel_pred ]) + pred = Predictions([ fam_pred ]) + + batch = PredictionsBatch(pred, part_type='train', batch_size=1) + count = 0 + for (input, target) in batch: + count += 1 + + assert count == 10