IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Selaa lähdekoodia

Add PredictionsBatch and test_predictions_btch_01().

master
Stanislaw Adaszewski 4 vuotta sitten
vanhempi
commit
a9f14d14a8
2 muutettua tiedostoa jossa 85 lisäystä ja 0 poistoa
  1. +39
    -0
      src/icosagon/batch.py
  2. +46
    -0
      tests/icosagon/test_batch.py

+ 39
- 0
src/icosagon/batch.py Näytä tiedosto

@@ -0,0 +1,39 @@
from icosagon.declayer import Predictions
import torch
class PredictionsBatch(object):
def __init__(self, pred: Predictions, part_type: str = 'train',
batch_size: int = 100) -> None:
if not isinstance(pred, Predictions):
raise TypeError('pred must be an instance of Predictions')
if part_type not in ['train', 'val', 'test']:
raise ValueError('part_type must be set to train, val or test')
batch_size = int(batch_size)
self.predictions = pred
self.part_type = part_type
self.batch_size = batch_size
def __iter__(self):
edge_types = [('edges_pos', 1), ('edges_neg', 0),
('edges_back_pos', 1), ('edges_back_neg', 0)]
input = []
target = []
for fam in self.predictions.relation_families:
for rel in fam.relation_types:
for (et, tgt) in edge_types:
edge_pred = getattr(getattr(rel, et), self.part_type)
input.append(edge_pred)
target.append(torch.ones_like(edge_pred) * tgt)
input = torch.cat(input)
target = torch.cat(target)
for i in range(0, len(input), self.batch_size):
yield (input[i:i+self.batch_size], target[i:i+self.batch_size])

+ 46
- 0
tests/icosagon/test_batch.py Näytä tiedosto

@@ -0,0 +1,46 @@
from icosagon.batch import PredictionsBatch
from icosagon.declayer import Predictions, \
RelationPredictions, \
RelationFamilyPredictions
from icosagon.trainprep import prepare_training, \
TrainValTest
from icosagon.data import Data
import torch
def test_predictions_batch_01():
d = Data()
d.add_node_type('Dummy', 5)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Rel', torch.tensor([
[0, 1, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0]
], dtype=torch.float32))
prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
assert len(prep_d.relation_families) == 1
assert len(prep_d.relation_families[0].relation_types) == 1
assert len(prep_d.relation_families[0].relation_types[0].edges_pos.train) == 5
assert len(prep_d.relation_families[0].relation_types[0].edges_neg.train) == 5
assert len(prep_d.relation_families[0].relation_types[0].edges_pos.val) == 0
assert len(prep_d.relation_families[0].relation_types[0].edges_pos.test) == 0
rel_pred = RelationPredictions(
TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
)
fam_pred = RelationFamilyPredictions([ rel_pred ])
pred = Predictions([ fam_pred ])
batch = PredictionsBatch(pred, part_type='train', batch_size=1)
count = 0
for (input, target) in batch:
count += 1
assert count == 10

Loading…
Peruuta
Tallenna