|
- from .data import Data
- from .model import TrainingBatch
- import torch
- from functools import reduce
-
-
- def _shuffle(x: torch.Tensor) -> torch.Tensor:
- order = torch.randperm(len(x))
- return x[order]
-
-
- def _same_data_org(pos_data: Data, neg_data: Data):
- if len(pos_data.vertex_types) != len(neg_data.vertex_types):
- return False
-
- test = [ pos_data.vertex_types[i].name == neg_data.vertex_types[i].name \
- and pos_data.vertex_types[i].count == neg_data.vertex_types[i].count \
- for i in range(len(pos_data.vertex_types)) ]
- if not all(test):
- return False
-
- if not set(pos_data.edge_types.keys()) == \
- set(neg_data.edge_types.keys()):
-
- return False
-
- test = [ pos_data.edge_types[i].name == \
- neg_data.edge_types[i].name \
- and pos_data.edge_types[i].vertex_type_row == \
- neg_data.edge_types[i].vertex_type_row \
- and pos_data.edge_types[i].vertex_type_column == \
- neg_data.edge_types[i].vertex_type_column \
- and len(pos_data.edge_types[i].adjacency_matrices) == \
- len(neg_data.edge_types[i].adjacency_matrices) \
- for i in pos_data.edge_types.keys() ]
- if not all(test):
- return False
-
- test = [ [ len(pos_data.edge_types[i].adjacency_matrices[k].values()) == \
- len(neg_data.edge_types[i].adjacency_matrices[k].values()) \
- for k in range(len(pos_data.edge_types[i].adjacency_matrices)) ] \
- for i in pos_data.edge_types.keys() ]
- test = reduce(list.__add__, test, [])
- if not all(test):
- return False
-
- return True
-
-
class DualBatcher(object):
    """Iterate over paired positive/negative edge mini-batches.

    ``pos_data`` and ``neg_data`` must have the same organization (checked by
    ``_same_data_org``). Each step picks a random edge type and relation that
    still has unconsumed positive edges and yields ``(pos_batch, neg_batch)``:
    two ``TrainingBatch`` objects taken from the same edge type (by key) and
    relation index of both datasets, with targets 1 and 0 respectively.
    """

    def __init__(self, pos_data: Data, neg_data: Data,
                 batch_size: int = 512, shuffle: bool = True) -> None:
        """
        Args:
            pos_data: dataset providing positive edges.
            neg_data: dataset providing negative edges; same organization.
            batch_size: maximum number of edges per batch.
            shuffle: whether to shuffle each edge list before batching.

        Raises:
            TypeError: if either argument is not a ``Data`` instance.
            ValueError: if the two datasets are organized differently.
        """
        if not isinstance(pos_data, Data):
            raise TypeError('pos_data must be an instance of Data')

        if not isinstance(neg_data, Data):
            raise TypeError('neg_data must be an instance of Data')

        if not _same_data_org(pos_data, neg_data):
            raise ValueError('pos_data and neg_data must have the same organization')

        self.pos_data = pos_data
        self.neg_data = neg_data
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)

    def get_edge_lists(self, data: Data, edge_keys=None):
        """Extract per-relation edge lists from ``data``.

        Args:
            data: source dataset.
            edge_keys: optional iterable of keys into ``data.edge_types``
                fixing the order of the per-edge-type outputs. Defaults to the
                dict's own iteration order. Passing another dataset's key
                order lets two datasets be indexed in lockstep.

        Returns:
            ``(edge_keys, edge_types, edge_lists, offsets)``. ``edge_lists[i][j]``
            is ``edge_types[i].adjacency_matrices[j].indices()`` transposed so
            that each row is one edge (assumes sparse COO matrices), shuffled
            along rows when ``self.shuffle`` is set. ``offsets`` is a matching
            nested list of zeros used as read cursors.
        """
        if edge_keys is None:
            edge_keys = list(data.edge_types.keys())
        else:
            edge_keys = list(edge_keys)
        edge_types = [data.edge_types[k] for k in edge_keys]

        edge_lists = [[adj_mat.indices().transpose(0, 1)
                       for adj_mat in et.adjacency_matrices]
                      for et in edge_types]

        if self.shuffle:
            edge_lists = [[_shuffle(lst) for lst in rel_lists]
                          for rel_lists in edge_lists]

        offsets = [[0] * len(et.adjacency_matrices) for et in edge_types]

        return (edge_keys, edge_types, edge_lists, offsets)

    def get_candidates(self, edge_lists, offsets):
        """Randomly pick an (edge type, relation) pair that has edges left.

        First an edge type with at least one non-exhausted relation is drawn
        uniformly, then one of its non-exhausted relations is drawn uniformly.

        Returns:
            ``(edge_idx, rel_idx)``, or ``(None, None)`` once every list is
            exhausted.
        """
        def remaining(edge_idx):
            # Relation indices whose cursor has not yet passed the list end.
            return [rel_idx
                    for rel_idx, rel_ofs in enumerate(offsets[edge_idx])
                    if rel_ofs < len(edge_lists[edge_idx][rel_idx])]

        candidates = [edge_idx for edge_idx in range(len(offsets))
                      if remaining(edge_idx)]
        if not candidates:
            return None, None

        edge_idx = candidates[torch.randint(0, len(candidates), (1,)).item()]
        rel_candidates = remaining(edge_idx)
        rel_idx = rel_candidates[
            torch.randint(0, len(rel_candidates), (1,)).item()]

        return edge_idx, rel_idx

    def take_edges(self, edge_idx, rel_idx, edge_lists, offsets,
                   edge_types, target_value):
        """Consume the next batch of edges from one relation.

        Slices up to ``self.batch_size`` rows starting at the relation's
        cursor, advances that cursor in place (mutating ``offsets``) and wraps
        the slice in a ``TrainingBatch`` whose float32 targets all equal
        ``target_value``.
        """
        et = edge_types[edge_idx]
        ofs = offsets[edge_idx][rel_idx]
        lst = edge_lists[edge_idx][rel_idx][ofs:ofs + self.batch_size]
        offsets[edge_idx][rel_idx] += self.batch_size

        return TrainingBatch(et.vertex_type_row, et.vertex_type_column,
                             rel_idx, lst,
                             torch.full((len(lst),), target_value,
                                        dtype=torch.float32))

    def __iter__(self):
        """Yield ``(pos_batch, neg_batch)`` pairs until positive edges run out.

        Iteration is driven by the positive data; each negative batch is taken
        from the same edge-type key and relation index as its positive
        counterpart.
        """
        pos_edge_keys, pos_edge_types, pos_edge_lists, pos_offsets = \
            self.get_edge_lists(self.pos_data)

        # Lay the negative data out in the positive data's key order.
        # _same_data_org only verifies that the key *sets* match, so the two
        # edge_types dicts may iterate in different orders; sharing edge_idx
        # between them is only meaningful once they are aligned.
        _, neg_edge_types, neg_edge_lists, neg_offsets = \
            self.get_edge_lists(self.neg_data, pos_edge_keys)

        while True:
            edge_idx, rel_idx = self.get_candidates(pos_edge_lists, pos_offsets)

            if edge_idx is None:
                return

            pos_batch = self.take_edges(edge_idx, rel_idx, pos_edge_lists,
                                        pos_offsets, pos_edge_types, 1)

            neg_batch = self.take_edges(edge_idx, rel_idx, neg_edge_lists,
                                        neg_offsets, neg_edge_types, 0)

            yield (pos_batch, neg_batch)
-
-
class Batcher(object):
    """Iterate over mini-batches of edges from a single Data object.

    Each step picks, uniformly at random, an edge type that still has
    unconsumed edges, then one of its non-exhausted relations, and yields a
    ``TrainingBatch`` of up to ``batch_size`` edges whose float32 targets all
    equal ``data.target_value``.
    """

    def __init__(self, data: Data, batch_size: int = 512,
                 shuffle: bool = True) -> None:
        """
        Args:
            data: dataset to draw edges from.
            batch_size: maximum number of edges per batch.
            shuffle: whether to shuffle each edge list before batching.

        Raises:
            TypeError: if ``data`` is not a ``Data`` instance.
        """
        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        self.data = data
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)

    def _edge_lists(self, edge_types):
        """Return ``[edge type][relation]`` tensors with one edge per row.

        Each tensor is ``adj_mat.indices()`` transposed (assumes sparse COO
        matrices), shuffled along rows when ``self.shuffle`` is set.
        """
        edge_lists = [[adj_mat.indices().transpose(0, 1)
                       for adj_mat in et.adjacency_matrices]
                      for et in edge_types]
        if self.shuffle:
            edge_lists = [[_shuffle(lst) for lst in rel_lists]
                          for rel_lists in edge_lists]
        return edge_lists

    @staticmethod
    def _pick(edge_lists, offsets):
        """Randomly choose ``(edge_idx, rel_idx)`` with edges left.

        Returns ``(None, None)`` once every list is exhausted.
        """
        def remaining(edge_idx):
            # Relation indices whose cursor has not yet passed the list end.
            return [rel_idx
                    for rel_idx, rel_ofs in enumerate(offsets[edge_idx])
                    if rel_ofs < len(edge_lists[edge_idx][rel_idx])]

        candidates = [edge_idx for edge_idx in range(len(offsets))
                      if remaining(edge_idx)]
        if not candidates:
            return None, None

        edge_idx = candidates[torch.randint(0, len(candidates), (1,)).item()]
        rel_candidates = remaining(edge_idx)
        rel_idx = rel_candidates[
            torch.randint(0, len(rel_candidates), (1,)).item()]
        return edge_idx, rel_idx

    def _take(self, et, edge_idx, rel_idx, edge_lists, offsets):
        """Slice the next batch of one relation and advance its cursor in place."""
        ofs = offsets[edge_idx][rel_idx]
        lst = edge_lists[edge_idx][rel_idx][ofs:ofs + self.batch_size]
        offsets[edge_idx][rel_idx] += self.batch_size
        return TrainingBatch(et.vertex_type_row, et.vertex_type_column,
                             rel_idx, lst,
                             torch.full((len(lst),), self.data.target_value,
                                        dtype=torch.float32))

    def __iter__(self) -> TrainingBatch:
        """Yield ``TrainingBatch`` objects until every edge has been consumed.

        NOTE: this is a generator; strictly, the return type is
        ``Iterator[TrainingBatch]``.
        """
        edge_types = list(self.data.edge_types.values())
        edge_lists = self._edge_lists(edge_types)
        # One read cursor per (edge type, relation).
        offsets = [[0] * len(et.adjacency_matrices) for et in edge_types]

        while True:
            edge_idx, rel_idx = self._pick(edge_lists, offsets)
            if edge_idx is None:
                break

            yield self._take(edge_types[edge_idx], edge_idx, rel_idx,
                             edge_lists, offsets)
|