From c210be41b3a615162c83d16d2128307d2a23066a Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 21 Jul 2020 14:14:07 +0200 Subject: [PATCH] Add databatch. --- src/icosagon/databatch.py | 91 ++++++++++++++++++++++++++++++++ tests/icosagon/test_databatch.py | 81 ++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 src/icosagon/databatch.py create mode 100644 tests/icosagon/test_databatch.py diff --git a/src/icosagon/databatch.py b/src/icosagon/databatch.py new file mode 100644 index 0000000..b7c4a8f --- /dev/null +++ b/src/icosagon/databatch.py @@ -0,0 +1,91 @@ +from icosagon.trainprep import PreparedData, \ + PreparedRelationFamily, \ + PreparedRelationType, \ + _empty_edge_list_tvt + + +class BatchedData(PreparedData): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +def batched_data_skeleton(data: PreparedData) -> BatchedData: + if not isinstance(data, PreparedData): + raise TypeError('data must be an instance of PreparedData') + + fam_skels = [] + for fam in data.relation_families: + rel_types_skel = [] + for rel in fam.relation_types: + rel_skel = PreparedRelationType(rel.name, + rel.node_type_row, rel.node_type_column, + rel.adjacency_matrix, rel.adjacency_matrix_backward, + _empty_edge_list_tvt(), _empty_edge_list_tvt(), + _empty_edge_list_tvt(), _empty_edge_list_tvt()) + rel_types_skel.append(rel_skel) + fam_skels.append(PreparedRelationFamily(fam.data, fam.name, + fam.node_type_row, fam.node_type_column, + fam.is_symmetric, fam.decoder_class, + rel_types_skel)) + return BatchedData(data.node_types, fam_skels) + + +class DataBatcher(object): + def __init__(self, data: PreparedData, batch_size: int) -> None: + self._check_params(data, batch_size) + + self.data = data + self.batch_size = batch_size + + # def batched_data_iter(self, fam_idx: int, rel_idx: int, + # part_type: str) -> BatchedData: + # + # rel = self.data.relation_families[fam_idx].relation_types[rel_idx] + # + # edges = getattr(rel.edges_pos, part_type) + # for m in range(0, len(edges), self.batch_size): + # batched_data = batched_data_skeleton(self.data) + # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_pos, + # part_type, edges[m : m + self.batch_size]) + # yield batched_data + # + # edges = getattr(rel.edges_neg, part_type) + # for m in range(0, len(edges), self.batch_size): + # batched_data = batched_data_skeleton(self.data) + # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_neg, + # part_type, edges[m : m + self.batch_size]) + # yield batched_data + # + # edges = getattr(rel.edges_pos_back, part_type) + # for m in range(0, len(edges), self.batch_size): + # batched_data = batched_data_skeleton(self.data) + # setattr(batched_data.relation_families[i].relation_types[k].edges_pos_back, + # part_type, edges[m : m + self.batch_size]) + # yield batched_data + # + # edges = getattr(rel.edges_neg_back, part_type) + # for m in range(0, len(), self.batch_size): + # batched_data = batched_data_skeleton(self.data) + # setattr(batched_data.relation_families[i].relation_types[k].edges_neg_back, + # edges[m : m + self.batch_size]) + # yield batched_data + + def __iter__(self) -> BatchedData: + for i, fam in enumerate(self.data.relation_families): + for k, rel in enumerate(fam.relation_types): + for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']: + for part_type in ['train', 'val', 'test']: + edges = getattr(getattr(rel, edge_type), part_type) + for m in range(0, len(edges), self.batch_size): + batched_data = batched_data_skeleton(self.data) + setattr(getattr(batched_data.relation_families[i].relation_types[k], + edge_type), part_type, edges[m : m + self.batch_size]) + yield batched_data + + @staticmethod + def _check_params(data, batch_size): + if not isinstance(data, PreparedData): + raise TypeError('data must be an instance of PreparedData') + + if not isinstance(batch_size, int): + raise TypeError('batch_size must be an int') diff --git a/tests/icosagon/test_databatch.py b/tests/icosagon/test_databatch.py new file mode 100644 index 0000000..2a35843 --- /dev/null +++ b/tests/icosagon/test_databatch.py @@ -0,0 +1,81 @@ +from icosagon.databatch import DataBatcher, \ + BatchedData +from icosagon.data import Data +from icosagon.trainprep import prepare_training, \ + TrainValTest +import torch + + +def _some_data(): + data = Data() + data.add_node_type('Foo', 100) + data.add_node_type('Bar', 500) + fam = data.add_relation_family('Foo-Bar', 0, 1, True) + adj_mat = torch.rand(100, 500).round().to_sparse() + fam.add_relation_type('Foo-Bar', adj_mat) + return data + + +def test_data_batcher_01(): + data = _some_data() + prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) + batcher = DataBatcher(prep_d, 512) + + +def test_data_batcher_02(): + data = _some_data() + prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) + batcher = DataBatcher(prep_d, 512) + for batch_d in batcher: + pass + + +def test_data_batcher_03(): + data = _some_data() + prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) + batcher = DataBatcher(prep_d, 512) + for batch_d in batcher: + edges_list = [] + for fam in batch_d.relation_families: + for rel in fam.relation_types: + for edge_type in ['edges_pos', 'edges_neg', + 'edges_back_pos', 'edges_back_neg']: + for part_type in ['train', 'val', 'test']: + edges = getattr(getattr(rel, edge_type), part_type) + edges_list.append(edges) + assert sum([ 1 for edges in edges_list if len(edges) > 0 ]) == 1 + + +def test_data_batcher_04(): + data = _some_data() + prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) + batcher = DataBatcher(prep_d, 512) + edges_list = [] + for batch_d in batcher: + for fam in batch_d.relation_families: + for rel in fam.relation_types: + for edge_type in ['edges_pos', 'edges_neg', + 'edges_back_pos', 'edges_back_neg']: + for part_type in ['train', 'val', 'test']: + edges = getattr(getattr(rel, edge_type), part_type) + edges_list.append(edges) + assert sum([ len(edges) for edges in edges_list ]) == \ + torch.sum(data.relation_families[0].relation_types[0].adjacency_matrix._values()) * 2 + + +def test_data_batcher_05(): + data = _some_data() + prep_d = prepare_training(data, TrainValTest(.8, .1, .1)) + batcher = DataBatcher(prep_d, 512) + for batch_d in batcher: + edges_list = [] + for fam in batch_d.relation_families: + for rel in fam.relation_types: + for edge_type in ['edges_pos', 'edges_neg', + 'edges_back_pos', 'edges_back_neg']: + for part_type in ['train', 'val', 'test']: + edges = getattr(getattr(rel, edge_type), part_type) + edges_list.append(edges) + assert all([ len(edges) <= 512 for edges in edges_list ]) + assert not all([ len(edges) == 0 for edges in edges_list ]) + print(sum(map(len, edges_list)))