|
@@ -4,7 +4,7 @@ import torch |
|
|
|
|
|
|
|
|
class PredictionsBatch(object):
|
|
|
class PredictionsBatch(object):
|
|
|
def __init__(self, pred: Predictions, part_type: str = 'train',
|
|
|
def __init__(self, pred: Predictions, part_type: str = 'train',
|
|
|
batch_size: int = 100) -> None:
|
|
|
|
|
|
|
|
|
batch_size: int = 100, shuffle: bool = False) -> None:
|
|
|
|
|
|
|
|
|
if not isinstance(pred, Predictions):
|
|
|
if not isinstance(pred, Predictions):
|
|
|
raise TypeError('pred must be an instance of Predictions')
|
|
|
raise TypeError('pred must be an instance of Predictions')
|
|
@@ -14,9 +14,12 @@ class PredictionsBatch(object): |
|
|
|
|
|
|
|
|
batch_size = int(batch_size)
|
|
|
batch_size = int(batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
shuffle = bool(shuffle)
|
|
|
|
|
|
|
|
|
self.predictions = pred
|
|
|
self.predictions = pred
|
|
|
self.part_type = part_type
|
|
|
self.part_type = part_type
|
|
|
self.batch_size = batch_size
|
|
|
self.batch_size = batch_size
|
|
|
|
|
|
self.shuffle = shuffle
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
def __iter__(self):
|
|
|
edge_types = [('edges_pos', 1), ('edges_neg', 0),
|
|
|
edge_types = [('edges_pos', 1), ('edges_neg', 0),
|
|
@@ -35,5 +38,10 @@ class PredictionsBatch(object): |
|
|
input = torch.cat(input)
|
|
|
input = torch.cat(input)
|
|
|
target = torch.cat(target)
|
|
|
target = torch.cat(target)
|
|
|
|
|
|
|
|
|
|
|
|
if self.shuffle:
|
|
|
|
|
|
perm = torch.randperm(len(input))
|
|
|
|
|
|
input = input[perm]
|
|
|
|
|
|
target = target[perm]
|
|
|
|
|
|
|
|
|
for i in range(0, len(input), self.batch_size):
|
|
|
for i in range(0, len(input), self.batch_size):
|
|
|
yield (input[i:i+self.batch_size], target[i:i+self.batch_size])
|
|
|
yield (input[i:i+self.batch_size], target[i:i+self.batch_size])
|