|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- from icosagon.batch import PredictionsBatch
- from icosagon.declayer import Predictions, \
- RelationPredictions, \
- RelationFamilyPredictions
- from icosagon.trainprep import prepare_training, \
- TrainValTest
- from icosagon.data import Data
- import torch
-
-
def test_predictions_batch_01():
    """Check that PredictionsBatch yields (prediction, target) pairs in order.

    Builds a graph with one node type ('Dummy', 5 nodes) and one relation
    type whose adjacency matrix has 5 non-zero entries. It prepares that
    graph with an all-train split. It then wraps hand-made predictions in
    ``PredictionsBatch(part_type='train', batch_size=1)``.

    The test asserts that iteration yields exactly ten single-element
    batches: five pairs with target 1 followed by five pairs with target 0.
    These presumably correspond to the positive and then the negative train
    edges.
    """
    d = Data()
    d.add_node_type('Dummy', 5)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    fam.add_relation_type('Dummy Rel', torch.tensor([
        [0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0]
    ], dtype=torch.float32))

    # Split ratios (train, val, test) = (1, 0, 0): every edge goes to train.
    prep_d = prepare_training(d, TrainValTest(1., 0., 0.))

    assert len(prep_d.relation_families) == 1
    assert len(prep_d.relation_families[0].relation_types) == 1
    # Safe to index now that the lengths above have been checked.
    rel_type = prep_d.relation_families[0].relation_types[0]
    assert len(rel_type.edges_pos.train) == 5
    assert len(rel_type.edges_neg.train) == 5
    assert len(rel_type.edges_pos.val) == 0
    assert len(rel_type.edges_pos.test) == 0

    # Four TrainValTest groups; presumably (edges_pos, edges_neg,
    # edges_back_pos, edges_back_neg) -- confirm against RelationPredictions.
    # Only the first two carry train predictions, one per prepared edge.
    rel_pred = RelationPredictions(
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
        TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
    )
    fam_pred = RelationFamilyPredictions([rel_pred])
    pred = Predictions([fam_pred])

    batch = PredictionsBatch(pred, part_type='train', batch_size=1)
    pairs = []
    # `pred_input` rather than `input` so the builtin is not shadowed.
    for pred_input, target in batch:
        # batch_size=1, so each batch holds exactly one element.
        assert len(pred_input) == 1
        assert len(target) == 1
        pairs.append((pred_input[0], target[0]))
    assert pairs == [(1, 1), (0, 1), (1, 1), (0, 1), (1, 1),
                     (1, 0), (0, 0), (1, 0), (0, 0), (1, 0)]

    # One pair per batch, so the pair count equals the batch count.
    assert len(pairs) == 10
|