IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Add shuffle to PredictionsBatch.

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
346cc747a6
共有 2 个文件被更改,包括 10 次插入2 次删除
  1. +9
    -1
      src/icosagon/batch.py
  2. +1
    -1
      src/icosagon/model.py

+ 9
- 1
src/icosagon/batch.py 查看文件

@@ -4,7 +4,7 @@ import torch
class PredictionsBatch(object):
def __init__(self, pred: Predictions, part_type: str = 'train',
batch_size: int = 100) -> None:
batch_size: int = 100, shuffle: bool = False) -> None:
if not isinstance(pred, Predictions):
raise TypeError('pred must be an instance of Predictions')
@@ -14,9 +14,12 @@ class PredictionsBatch(object):
batch_size = int(batch_size)
shuffle = bool(shuffle)
self.predictions = pred
self.part_type = part_type
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
edge_types = [('edges_pos', 1), ('edges_neg', 0),
@@ -35,5 +38,10 @@ class PredictionsBatch(object):
input = torch.cat(input)
target = torch.cat(target)
if self.shuffle:
perm = torch.randperm(len(input))
input = input[perm]
target = target[perm]
for i in range(0, len(input), self.batch_size):
yield (input[i:i+self.batch_size], target[i:i+self.batch_size])

+ 1
- 1
src/icosagon/model.py 查看文件

@@ -100,7 +100,7 @@ class Model(object):
for i in range(n - 1):
self.opt.zero_grad()
pred = self.seq(None)
batch = PredictionsBatch(pred, self.batch_size)
batch = PredictionsBatch(pred, self.batch_size, shuffle=True)
seed = torch.rand(1).item()
rng_state = torch.get_rng_state()
torch.manual_seed(seed)


正在加载...
取消
保存