|
|
@@ -2,6 +2,8 @@ from icosagon.trainprep import PreparedData, \ |
|
|
|
PreparedRelationFamily, \
|
|
|
|
PreparedRelationType, \
|
|
|
|
_empty_edge_list_tvt
|
|
|
|
import torch
|
|
|
|
import random
|
|
|
|
|
|
|
|
|
|
|
|
class BatchedData(PreparedData):
|
|
|
@@ -31,11 +33,13 @@ def batched_data_skeleton(data: PreparedData) -> BatchedData: |
|
|
|
|
|
|
|
|
|
|
|
class DataBatcher(object):
|
|
|
|
def __init__(self, data: PreparedData, batch_size: int) -> None:
|
|
|
|
def __init__(self, data: PreparedData, batch_size: int,
|
|
|
|
shuffle: bool = True) -> None:
|
|
|
|
self._check_params(data, batch_size)
|
|
|
|
|
|
|
|
self.data = data
|
|
|
|
self.batch_size = batch_size
|
|
|
|
self.shuffle = shuffle
|
|
|
|
|
|
|
|
# def batched_data_iter(self, fam_idx: int, rel_idx: int,
|
|
|
|
# part_type: str) -> BatchedData:
|
|
|
@@ -71,17 +75,34 @@ class DataBatcher(object): |
|
|
|
# yield batched_data
|
|
|
|
|
|
|
|
def __iter__(self) -> BatchedData:
|
|
|
|
gen = self.shuffle_iter() \
|
|
|
|
if self.shuffle \
|
|
|
|
else self.iter_base()
|
|
|
|
|
|
|
|
for batched_data in gen:
|
|
|
|
yield batched_data
|
|
|
|
|
|
|
|
def iter_base(self) -> BatchedData:
|
|
|
|
for i, fam in enumerate(self.data.relation_families):
|
|
|
|
for k, rel in enumerate(fam.relation_types):
|
|
|
|
for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']:
|
|
|
|
for part_type in ['train', 'val', 'test']:
|
|
|
|
edges = getattr(getattr(rel, edge_type), part_type)
|
|
|
|
if self.shuffle:
|
|
|
|
perm = torch.randperm(len(edges))
|
|
|
|
edges = edges[perm]
|
|
|
|
for m in range(0, len(edges), self.batch_size):
|
|
|
|
batched_data = batched_data_skeleton(self.data)
|
|
|
|
setattr(getattr(batched_data.relation_families[i].relation_types[k],
|
|
|
|
edge_type), part_type, edges[m : m + self.batch_size])
|
|
|
|
yield batched_data
|
|
|
|
|
|
|
|
def shuffle_iter(self) -> BatchedData:
|
|
|
|
res = list(self.iter_base())
|
|
|
|
random.shuffle(res)
|
|
|
|
for batched_data in res:
|
|
|
|
yield batched_data
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def _check_params(data, batch_size):
|
|
|
|
if not isinstance(data, PreparedData):
|
|
|
|