diff --git a/src/triacontagon/batch.py b/src/triacontagon/batch.py
index cfb367e..b93ab36 100644
--- a/src/triacontagon/batch.py
+++ b/src/triacontagon/batch.py
@@ -1,6 +1,7 @@
 from .data import Data
 from .model import TrainingBatch
 import torch
+from functools import reduce
 
 
 def _shuffle(x: torch.Tensor) -> torch.Tensor:
@@ -8,6 +9,83 @@ def _shuffle(x: torch.Tensor) -> torch.Tensor:
     return x[order]
 
 
+def _same_data_org(pos_data: Data, neg_data: Data):
+    """Return True if both Data objects are organized identically.
+
+    Compared are the number, names and counts of the vertex types,
+    the set of edge type keys, each edge type's name, row/column vertex
+    types and number of adjacency matrices and, matrix by matrix, the
+    length of ``.values()``.
+    """
+    if len(pos_data.vertex_types) != len(neg_data.vertex_types):
+        return False
+
+    test = [pos_data.vertex_types[i].name == neg_data.vertex_types[i].name
+            and pos_data.vertex_types[i].count ==
+            neg_data.vertex_types[i].count
+            for i in range(len(pos_data.vertex_types))]
+    if not all(test):
+        return False
+
+    if set(pos_data.edge_types.keys()) != set(neg_data.edge_types.keys()):
+        return False
+
+    test = [pos_data.edge_types[i].name == neg_data.edge_types[i].name
+            and pos_data.edge_types[i].vertex_type_row ==
+            neg_data.edge_types[i].vertex_type_row
+            and pos_data.edge_types[i].vertex_type_column ==
+            neg_data.edge_types[i].vertex_type_column
+            and len(pos_data.edge_types[i].adjacency_matrices) ==
+            len(neg_data.edge_types[i].adjacency_matrices)
+            for i in pos_data.edge_types.keys()]
+    if not all(test):
+        return False
+
+    # Index the adjacency matrix lists by their own length (verified
+    # equal above); len() of the edge type object itself is not the
+    # length of the container being indexed.
+    test = [[len(pos_data.edge_types[i].adjacency_matrices[k].values()) ==
+             len(neg_data.edge_types[i].adjacency_matrices[k].values())
+             for k in range(len(pos_data.edge_types[i].adjacency_matrices))]
+            for i in pos_data.edge_types.keys()]
+    # The [] initializer keeps reduce() from raising TypeError when
+    # there are no edge types at all.
+    test = reduce(list.__add__, test, [])
+    if not all(test):
+        return False
+
+    return True
+
+
+class DualBatcher(object):
+    """Pair of positive and negative Data with the same organization.
+
+    The constructor checks the argument types and that both Data
+    objects pass _same_data_org(). Iteration is not implemented yet.
+    """
+    def __init__(self, pos_data: Data, neg_data: Data,
+            batch_size: int=512, shuffle: bool=True) -> None:
+
+        if not isinstance(pos_data, Data):
+            raise TypeError('pos_data must be an instance of Data')
+
+        if not isinstance(neg_data, Data):
+            raise TypeError('neg_data must be an instance of Data')
+
+        if not _same_data_org(pos_data, neg_data):
+            raise ValueError('pos_data and neg_data must have the same organization')
+
+        self.pos_data = pos_data
+        self.neg_data = neg_data
+        self.batch_size = int(batch_size)
+        self.shuffle = bool(shuffle)
+
+    def __iter__(self):
+        # The method was committed without a body, which is a syntax
+        # error; fail explicitly until batching is implemented.
+        raise NotImplementedError('DualBatcher.__iter__ is not implemented')
+
+
 class Batcher(object):
     def __init__(self, data: Data, batch_size: int=512,
             shuffle: bool=True) -> None:
diff --git a/src/triacontagon/loop.py b/src/triacontagon/loop.py
new file mode 100644
index 0000000..18565aa
--- /dev/null
+++ b/src/triacontagon/loop.py
@@ -0,0 +1,54 @@
+from .model import Model
+from .batch import Batcher
+
+
+class TrainLoop(object):
+    """Step through positive and negative batches for several epochs.
+
+    Each epoch draws batches from both batchers in lockstep and checks
+    that paired batches contain the same number of edges. No model
+    update is performed in run_epoch() as written.
+    """
+    def __init__(self, model: Model,
+            pos_batcher: Batcher,
+            neg_batcher: Batcher,
+            max_epochs: int = 50) -> None:
+
+        if not isinstance(model, Model):
+            raise TypeError('model must be an instance of Model')
+
+        if not isinstance(pos_batcher, Batcher):
+            raise TypeError('pos_batcher must be an instance of Batcher')
+
+        if not isinstance(neg_batcher, Batcher):
+            raise TypeError('neg_batcher must be an instance of Batcher')
+
+        self.model = model
+        self.pos_batcher = pos_batcher
+        self.neg_batcher = neg_batcher
+        # Was int(num_epochs): an undefined name that raised NameError.
+        self.max_epochs = int(max_epochs)
+
+    def run_epoch(self) -> None:
+        """Consume both batchers pairwise until either is exhausted.
+
+        Raises:
+            ValueError: if a positive and a negative batch differ in
+                their number of edges.
+        """
+        pos_it = iter(self.pos_batcher)
+        neg_it = iter(self.neg_batcher)
+
+        while True:
+            try:
+                pos_batch = next(pos_it)
+                neg_batch = next(neg_it)
+            except StopIteration:
+                break
+            if len(pos_batch.edges) != len(neg_batch.edges):
+                raise ValueError('Positive and negative batch should have same length')
+
+    def run(self) -> None:
+        """Run run_epoch() max_epochs times."""
+        for epoch in range(self.max_epochs):
+            self.run_epoch()