|
|
@@ -1,6 +1,7 @@ |
|
|
|
from .data import Data
|
|
|
|
from .model import TrainingBatch
|
|
|
|
import torch
|
|
|
|
from functools import reduce
|
|
|
|
|
|
|
|
|
|
|
|
def _shuffle(x: torch.Tensor) -> torch.Tensor:
|
|
|
@@ -8,6 +9,66 @@ def _shuffle(x: torch.Tensor) -> torch.Tensor: |
|
|
|
return x[order]
|
|
|
|
|
|
|
|
|
|
|
|
def _same_data_org(pos_data: Data, neg_data: Data):
    """Return True if two data sets share the same organization.

    The two objects are considered to have the same organization when:

    * they declare the same number of vertex types, and vertex types at
      the same position agree on ``name`` and ``count``;
    * their ``edge_types`` mappings have the same set of keys;
    * for every key, the edge types agree on ``name``,
      ``vertex_type_row``, ``vertex_type_column`` and on the number of
      adjacency matrices;
    * corresponding adjacency matrices report the same number of entries
      through ``.values()``.

    Args:
        pos_data: First data set (positive examples, judging by the name).
        neg_data: Second data set (negative examples, judging by the name).

    Returns:
        bool: True when every check above passes, False otherwise.
    """
    pos_vertex_types = pos_data.vertex_types
    neg_vertex_types = neg_data.vertex_types

    if len(pos_vertex_types) != len(neg_vertex_types):
        return False

    for pos_vt, neg_vt in zip(pos_vertex_types, neg_vertex_types):
        if pos_vt.name != neg_vt.name or pos_vt.count != neg_vt.count:
            return False

    pos_edge_types = pos_data.edge_types
    neg_edge_types = neg_data.edge_types

    if set(pos_edge_types.keys()) != set(neg_edge_types.keys()):
        return False

    for key, pos_et in pos_edge_types.items():
        neg_et = neg_edge_types[key]

        if pos_et.name != neg_et.name \
                or pos_et.vertex_type_row != neg_et.vertex_type_row \
                or pos_et.vertex_type_column != neg_et.vertex_type_column:
            return False

        pos_matrices = pos_et.adjacency_matrices
        neg_matrices = neg_et.adjacency_matrices

        # Iterate over the adjacency matrices themselves: the edge type
        # object is not what has to be measured here, and looping directly
        # also copes with data sets that have no edge types at all.
        if len(pos_matrices) != len(neg_matrices):
            return False

        for pos_adj, neg_adj in zip(pos_matrices, neg_matrices):
            # NOTE(review): presumably these are sparse torch tensors, so
            # this compares the number of stored entries - confirm.
            if len(pos_adj.values()) != len(neg_adj.values()):
                return False

    return True
|
|
|
|
|
|
|
|
|
|
|
|
class DualBatcher(object):
    """Hold a pair of identically organized data sets plus batch settings.

    The constructor validates that both arguments are ``Data`` instances
    and that ``_same_data_org`` accepts them, then stores them together
    with the batch size and the shuffle flag.
    """

    def __init__(self, pos_data: Data, neg_data: Data,
        batch_size: int=512, shuffle: bool=True) -> None:
        """Validate and store the data sets and batching options.

        Args:
            pos_data: First data set; must be an instance of ``Data``.
            neg_data: Second data set; must be an instance of ``Data``
                with the same organization as ``pos_data``.
            batch_size: Stored after conversion with ``int()``.
            shuffle: Stored after conversion with ``bool()``.

        Raises:
            TypeError: If either data argument is not a ``Data`` instance.
            ValueError: If the two data sets are organized differently.
        """
        if not isinstance(pos_data, Data):
            raise TypeError('pos_data must be an instance of Data')

        if not isinstance(neg_data, Data):
            raise TypeError('neg_data must be an instance of Data')

        if not _same_data_org(pos_data, neg_data):
            raise ValueError('pos_data and neg_data must have the same organization')

        self.pos_data = pos_data
        self.neg_data = neg_data
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)

    def __iter__(self):
        """Iterate over batches (not implemented yet).

        Raises:
            NotImplementedError: Always.
        """
        # NOTE(review): no body for this method is present in the source,
        # which leaves the module unparseable. This explicit placeholder
        # keeps the module importable until the iteration logic is written.
        raise NotImplementedError('DualBatcher.__iter__ is not implemented')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Batcher(object):
|
|
|
|
def __init__(self, data: Data, batch_size: int=512,
|
|
|
|
shuffle: bool=True) -> None:
|
|
|
|