|
from typing import Iterator

from icosagon.trainprep import PreparedData, \
    PreparedRelationFamily, \
    PreparedRelationType, \
    _empty_edge_list_tvt
-
-
class BatchedData(PreparedData):
    """A ``PreparedData`` that carries a single mini-batch of edges.

    Adds nothing to :class:`PreparedData`; the separate type only marks
    instances built by ``batched_data_skeleton`` (and yielded by
    ``DataBatcher``) as batches rather than the full prepared dataset.
    Constructor arguments are those of ``PreparedData``, inherited as-is.
    """
-
-
def batched_data_skeleton(data: PreparedData) -> BatchedData:
    """Build an edge-less mirror of ``data`` to be filled with one batch.

    Every relation family and relation type of ``data`` is reproduced with
    the same names, node types, adjacency matrices and family metadata, but
    each relation type's four edge-list arguments are replaced by fresh
    results of ``_empty_edge_list_tvt()``. Adjacency matrices and
    ``fam.data`` are passed through by reference, not copied.

    Args:
        data: the fully prepared dataset to mirror.

    Returns:
        A new ``BatchedData`` sharing ``data.node_types``.

    Raises:
        TypeError: if ``data`` is not an instance of ``PreparedData``.
    """
    if not isinstance(data, PreparedData):
        raise TypeError('data must be an instance of PreparedData')

    def _hollow_relation(rel):
        # Same identity and adjacency as ``rel``, but no edges at all.
        return PreparedRelationType(
            rel.name,
            rel.node_type_row, rel.node_type_column,
            rel.adjacency_matrix, rel.adjacency_matrix_backward,
            _empty_edge_list_tvt(), _empty_edge_list_tvt(),
            _empty_edge_list_tvt(), _empty_edge_list_tvt())

    def _hollow_family(fam):
        # Relation types are built before the family's own fields are read.
        hollow_rels = [_hollow_relation(rel) for rel in fam.relation_types]
        return PreparedRelationFamily(
            fam.data, fam.name,
            fam.node_type_row, fam.node_type_column,
            fam.is_symmetric, fam.decoder_class,
            hollow_rels)

    hollow_families = [_hollow_family(fam) for fam in data.relation_families]
    return BatchedData(data.node_types, hollow_families)
-
-
class DataBatcher(object):
    """Iterate over ``PreparedData`` one edge mini-batch at a time.

    Each yielded item is a fresh ``BatchedData`` built by
    ``batched_data_skeleton``. In it, exactly one edge list holds a slice of
    at most ``batch_size`` edges; every other edge list stays empty. That
    one list is a single relation type, a single edge kind and a single
    train/val/test split.
    """

    # Edge-list attributes of a relation type, visited in this order.
    _EDGE_TYPES = ('edges_pos', 'edges_neg',
                   'edges_back_pos', 'edges_back_neg')
    # Splits inside each edge-list attribute, visited in this order.
    _PART_TYPES = ('train', 'val', 'test')

    def __init__(self, data: PreparedData, batch_size: int) -> None:
        """
        Args:
            data: the prepared dataset to batch.
            batch_size: maximum number of edges per batch; must be positive.

        Raises:
            TypeError: if ``data`` is not ``PreparedData`` or ``batch_size``
                is not an int.
            ValueError: if ``batch_size`` is not positive.
        """
        self._check_params(data, batch_size)

        self.data = data
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[BatchedData]:
        """Yield one ``BatchedData`` per ``batch_size``-slice of every list.

        The loops nest as follows, outermost first: relation families,
        relation types, edge kinds (``_EDGE_TYPES``), splits
        (``_PART_TYPES``), and finally consecutive slices of the edge list.
        The final slice of a list may be shorter than ``batch_size``. Empty
        edge lists produce no batches.
        """
        size = self.batch_size
        for fam_idx, fam in enumerate(self.data.relation_families):
            for rel_idx, rel in enumerate(fam.relation_types):
                for edge_type in self._EDGE_TYPES:
                    for part_type in self._PART_TYPES:
                        edges = getattr(getattr(rel, edge_type), part_type)
                        for start in range(0, len(edges), size):
                            batched_data = batched_data_skeleton(self.data)
                            # Same (family, relation) position in the fresh
                            # skeleton; only this one edge list is filled.
                            target_rel = batched_data \
                                .relation_families[fam_idx] \
                                .relation_types[rel_idx]
                            setattr(getattr(target_rel, edge_type), part_type,
                                    edges[start:start + size])
                            yield batched_data

    @staticmethod
    def _check_params(data, batch_size):
        """Validate constructor arguments, raising on invalid input."""
        if not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of PreparedData')

        if not isinstance(batch_size, int):
            raise TypeError('batch_size must be an int')

        # With a zero step, ``range`` would raise only once iteration
        # reaches the first relation. A negative step would silently yield
        # no batches. Reject both here.
        if batch_size <= 0:
            raise ValueError('batch_size must be positive')
|