IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

Add shuffle to DataBatcher.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
45a18a46aa
共有 1 個文件被更改,包括 22 次插入1 次删除
  1. +22
    -1
      src/icosagon/databatch.py

+ 22
- 1
src/icosagon/databatch.py 查看文件

@@ -2,6 +2,8 @@ from icosagon.trainprep import PreparedData, \
PreparedRelationFamily, \
PreparedRelationType, \
_empty_edge_list_tvt
import torch
import random
class BatchedData(PreparedData):
@@ -31,11 +33,13 @@ def batched_data_skeleton(data: PreparedData) -> BatchedData:
class DataBatcher(object):
def __init__(self, data: PreparedData, batch_size: int) -> None:
def __init__(self, data: PreparedData, batch_size: int,
shuffle: bool = True) -> None:
self._check_params(data, batch_size)
self.data = data
self.batch_size = batch_size
self.shuffle = shuffle
# def batched_data_iter(self, fam_idx: int, rel_idx: int,
# part_type: str) -> BatchedData:
@@ -71,17 +75,34 @@ class DataBatcher(object):
# yield batched_data
def __iter__(self) -> BatchedData:
gen = self.shuffle_iter() \
if self.shuffle \
else self.iter_base()
for batched_data in gen:
yield batched_data
def iter_base(self) -> BatchedData:
for i, fam in enumerate(self.data.relation_families):
for k, rel in enumerate(fam.relation_types):
for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']:
for part_type in ['train', 'val', 'test']:
edges = getattr(getattr(rel, edge_type), part_type)
if self.shuffle:
perm = torch.randperm(len(edges))
edges = edges[perm]
for m in range(0, len(edges), self.batch_size):
batched_data = batched_data_skeleton(self.data)
setattr(getattr(batched_data.relation_families[i].relation_types[k],
edge_type), part_type, edges[m : m + self.batch_size])
yield batched_data
def shuffle_iter(self) -> BatchedData:
res = list(self.iter_base())
random.shuffle(res)
for batched_data in res:
yield batched_data
@staticmethod
def _check_params(data, batch_size):
if not isinstance(data, PreparedData):


Loading…
取消
儲存