|
|
@@ -0,0 +1,91 @@ |
|
|
|
from typing import Iterator

from icosagon.trainprep import PreparedData, \
    PreparedRelationFamily, \
    PreparedRelationType, \
    _empty_edge_list_tvt
|
|
|
|
|
|
|
|
|
|
|
|
class BatchedData(PreparedData):
    """A ``PreparedData`` subclass used as the type of a single batch.

    It adds no fields or behaviour of its own; construction is delegated
    unchanged to ``PreparedData``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all positional and keyword arguments to ``PreparedData``."""
        super(BatchedData, self).__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def batched_data_skeleton(data: PreparedData) -> BatchedData:
    """Build a ``BatchedData`` mirroring ``data`` but with empty edge slots.

    Every relation family and relation type of ``data`` is re-created with
    the same name, node types and adjacency matrices, while each of the four
    trailing ``PreparedRelationType`` arguments is replaced by a fresh
    ``_empty_edge_list_tvt()`` value.

    Args:
        data: The prepared data whose structure is copied.

    Returns:
        A new ``BatchedData`` sharing ``data.node_types`` and holding the
        newly built relation families.

    Raises:
        TypeError: If ``data`` is not a ``PreparedData`` instance.
    """
    if not isinstance(data, PreparedData):
        raise TypeError('data must be an instance of PreparedData')

    families = []
    for family in data.relation_families:
        # One emptied copy per relation type; the four edge slots each get
        # their own new object so batches never share mutable state.
        emptied_relations = [
            PreparedRelationType(
                relation.name,
                relation.node_type_row,
                relation.node_type_column,
                relation.adjacency_matrix,
                relation.adjacency_matrix_backward,
                *(_empty_edge_list_tvt() for _ in range(4)))
            for relation in family.relation_types
        ]
        families.append(PreparedRelationFamily(
            family.data,
            family.name,
            family.node_type_row,
            family.node_type_column,
            family.is_symmetric,
            family.decoder_class,
            emptied_relations))

    return BatchedData(data.node_types, families)
|
|
|
|
|
|
|
|
|
|
|
|
class DataBatcher(object):
    """Iterate over a ``PreparedData`` instance in fixed-size edge batches.

    Iterating yields one ``BatchedData`` per batch. Each yielded object is a
    fresh skeleton produced by ``batched_data_skeleton(self.data)`` in which
    exactly one slot -- one part (``train``/``val``/``test``) of one edge
    container of one relation type -- holds a slice of at most
    ``batch_size`` edges from the corresponding slot of ``self.data``.

    Batches are produced in nested order: relation family, relation type,
    edge container, part, then offset within the edge list.
    """

    # Attribute names of the edge containers read from each relation type.
    _EDGE_TYPES = ('edges_pos', 'edges_neg',
                   'edges_back_pos', 'edges_back_neg')

    # Attribute names of the parts read from each edge container.
    _PART_TYPES = ('train', 'val', 'test')

    def __init__(self, data: PreparedData, batch_size: int) -> None:
        """Validate and store the data and the batch size.

        Args:
            data: The prepared data to split into batches.
            batch_size: Maximum number of edges per batch; must be >= 1.

        Raises:
            TypeError: If ``data`` is not a ``PreparedData`` instance or
                ``batch_size`` is not an ``int``.
            ValueError: If ``batch_size`` is less than 1.
        """
        self._check_params(data, batch_size)

        self.data = data
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[BatchedData]:
        """Yield one ``BatchedData`` per slice of every edge list."""
        for i, fam in enumerate(self.data.relation_families):
            for k, rel in enumerate(fam.relation_types):
                for edge_type in self._EDGE_TYPES:
                    for part_type in self._PART_TYPES:
                        edges = getattr(getattr(rel, edge_type), part_type)
                        for m in range(0, len(edges), self.batch_size):
                            # A new skeleton per batch so that every yielded
                            # object carries only this batch's edges.
                            batched_data = batched_data_skeleton(self.data)
                            target = getattr(
                                batched_data.relation_families[i]
                                .relation_types[k],
                                edge_type)
                            setattr(target, part_type,
                                    edges[m : m + self.batch_size])
                            yield batched_data

    @staticmethod
    def _check_params(data, batch_size):
        """Validate constructor arguments.

        Raises:
            TypeError: If ``data`` is not a ``PreparedData`` instance or
                ``batch_size`` is not an ``int``.
            ValueError: If ``batch_size`` is less than 1.
        """
        if not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of PreparedData')

        if not isinstance(batch_size, int):
            raise TypeError('batch_size must be an int')

        # A zero step makes range() raise during iteration and a negative
        # step silently produces no batches; reject both up front.
        if batch_size < 1:
            raise ValueError('batch_size must be a positive int')
|