|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
import random
from typing import Iterator

import torch

from icosagon.trainprep import PreparedData, \
    PreparedRelationFamily, \
    PreparedRelationType, \
    _empty_edge_list_tvt
-
-
class BatchedData(PreparedData):
    """A PreparedData instance that carries one mini-batch of edges.

    Construction is exactly that of PreparedData (the constructor is
    inherited unchanged); the subclass adds no behavior of its own and
    presumably exists so batches can be distinguished by type.
    """
-
-
class BatchedDataPointer:
    """Plain holder exposing a single public ``batched_data`` attribute.

    The argument is stored as-is (no copy, no type check); presumably it is
    a BatchedData instance — confirm against callers.
    """

    def __init__(self, batched_data):
        self.batched_data = batched_data
-
-
def batched_data_skeleton(data: PreparedData) -> BatchedData:
    """Return a BatchedData mirroring *data*'s structure with no edges.

    Every relation family and relation type of *data* is re-created in the
    same order. Names, node types, family attributes and adjacency matrices
    are passed through by reference (not copied); each relation type's four
    edge slots are filled with fresh ``_empty_edge_list_tvt()`` values.

    :param data: the prepared dataset to mirror.
    :returns: a new BatchedData sharing *data*'s node types.
    :raises TypeError: if *data* is not a PreparedData instance.
    """
    if not isinstance(data, PreparedData):
        raise TypeError('data must be an instance of PreparedData')

    def _empty_relation(rel):
        # Keep the relation's identity and adjacency; replace all edge sets.
        return PreparedRelationType(
            rel.name,
            rel.node_type_row, rel.node_type_column,
            rel.adjacency_matrix, rel.adjacency_matrix_backward,
            _empty_edge_list_tvt(), _empty_edge_list_tvt(),
            _empty_edge_list_tvt(), _empty_edge_list_tvt())

    def _empty_family(fam):
        # Build the relation skeletons first, then wrap them in the family.
        relations = [_empty_relation(rel) for rel in fam.relation_types]
        return PreparedRelationFamily(
            fam.data, fam.name,
            fam.node_type_row, fam.node_type_column,
            fam.is_symmetric, fam.decoder_class,
            relations)

    families = [_empty_family(fam) for fam in data.relation_families]
    return BatchedData(data.node_types, families)
-
-
class DataBatcher(object):
    """Split the edges of a PreparedData into mini-batches.

    Each yielded item is a fresh skeleton from ``batched_data_skeleton`` in
    which exactly one (relation type, edge attribute, train/val/test part)
    slot holds a slice of at most ``batch_size`` edges; every other edge
    slot keeps the value produced by ``_empty_edge_list_tvt()``.
    """

    # Edge-set attributes of each relation type, visited in this order.
    _EDGE_TYPES = ('edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg')
    # Parts of each edge set, visited in this order.
    _PART_TYPES = ('train', 'val', 'test')

    def __init__(self, data: PreparedData, batch_size: int,
                 shuffle: bool = True) -> None:
        """
        :param data: the prepared dataset whose edges are batched.
        :param batch_size: maximum number of edges per batch; must be a
            positive int.
        :param shuffle: when True, each edge list is permuted before slicing
            and the resulting batches are yielded in random order.
        :raises TypeError: if data is not a PreparedData or batch_size is
            not an int.
        :raises ValueError: if batch_size is not positive.
        """
        self._check_params(data, batch_size)

        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[BatchedData]:
        """Yield batches in random order if ``self.shuffle``, else in fixed order."""
        if self.shuffle:
            yield from self.shuffle_iter()
        else:
            yield from self.iter_base()

    def iter_base(self) -> Iterator[BatchedData]:
        """Yield batches in a fixed traversal order.

        Order: relation family, relation type, edge attribute (see
        ``_EDGE_TYPES``), then part (see ``_PART_TYPES``). When
        ``self.shuffle`` is set, each edge list is permuted with
        ``torch.randperm`` before slicing, so batch contents are random even
        though the traversal order is fixed.
        """
        for i, fam in enumerate(self.data.relation_families):
            for k, rel in enumerate(fam.relation_types):
                for edge_type in self._EDGE_TYPES:
                    for part_type in self._PART_TYPES:
                        edges = getattr(getattr(rel, edge_type), part_type)
                        if self.shuffle:
                            # assumes edges is a torch.Tensor indexable by a
                            # LongTensor permutation — TODO confirm
                            perm = torch.randperm(len(edges))
                            edges = edges[perm]
                        for m in range(0, len(edges), self.batch_size):
                            # A new skeleton per batch: each yielded object
                            # owns its edge containers, which matters because
                            # shuffle_iter() keeps all batches alive at once.
                            batched_data = batched_data_skeleton(self.data)
                            target = getattr(
                                batched_data.relation_families[i].relation_types[k],
                                edge_type)
                            setattr(target, part_type,
                                    edges[m:m + self.batch_size])
                            yield batched_data

    def shuffle_iter(self) -> Iterator[BatchedData]:
        """Yield every batch from ``iter_base()`` in a random order.

        All batches are materialized in memory first so they can be shuffled.
        Batch order uses Python's ``random`` module while edge permutation in
        ``iter_base()`` uses torch's RNG, so both must be seeded for
        reproducible runs.
        """
        batches = list(self.iter_base())
        random.shuffle(batches)
        yield from batches

    @staticmethod
    def _check_params(data, batch_size):
        """Validate constructor arguments.

        :raises TypeError: if data is not a PreparedData or batch_size is
            not an int.
        :raises ValueError: if batch_size < 1. Otherwise a zero step would
            only fail later inside ``range()`` during iteration, and a
            negative one would make iteration silently yield nothing.
        """
        if not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of PreparedData')

        if not isinstance(batch_size, int):
            raise TypeError('batch_size must be an int')

        if batch_size < 1:
            raise ValueError('batch_size must be positive')
|