|
|
@@ -21,7 +21,7 @@ def _same_data_org(pos_data: Data, neg_data: Data): |
|
|
|
|
|
|
|
if not set(pos_data.edge_types.keys()) == \
|
|
|
|
set(neg_data.edge_types.keys()):
|
|
|
|
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
test = [ pos_data.edge_types[i].name == \
|
|
|
@@ -65,8 +65,78 @@ class DualBatcher(object): |
|
|
|
self.batch_size = int(batch_size)
|
|
|
|
self.shuffle = bool(shuffle)
|
|
|
|
|
|
|
|
def get_edge_lists(self, data: Data):
|
|
|
|
edge_types = list(data.edge_types.items())
|
|
|
|
edge_keys = [ a[0] for a in edge_types ]
|
|
|
|
edge_types = [ a[1] for a in edge_types ]
|
|
|
|
|
|
|
|
edge_lists = [ [ adj_mat.indices().transpose(0, 1) \
|
|
|
|
for adj_mat in et.adjacency_matrices ] \
|
|
|
|
for et in edge_types ]
|
|
|
|
|
|
|
|
if self.shuffle:
|
|
|
|
edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \
|
|
|
|
for edge_lst in edge_lists ]
|
|
|
|
|
|
|
|
offsets = [ [ 0 ] * len(et.adjacency_matrices) \
|
|
|
|
for et in edge_types ]
|
|
|
|
|
|
|
|
return (edge_keys, edge_types, edge_lists, offsets)
|
|
|
|
|
|
|
|
def get_candidates(self, edge_lists, offsets):
|
|
|
|
candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \
|
|
|
|
if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \
|
|
|
|
if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ]
|
|
|
|
|
|
|
|
if len(candidates) == 0:
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
edge_idx = torch.randint(0, len(candidates), (1,)).item()
|
|
|
|
edge_idx = candidates[edge_idx]
|
|
|
|
candidates = [ rel_idx \
|
|
|
|
for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \
|
|
|
|
if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]
|
|
|
|
|
|
|
|
rel_idx = torch.randint(0, len(candidates), (1,)).item()
|
|
|
|
rel_idx = candidates[rel_idx]
|
|
|
|
|
|
|
|
return edge_idx, rel_idx
|
|
|
|
|
|
|
|
def take_edges(self, edge_idx, rel_idx, edge_lists, offsets,
|
|
|
|
edge_types, target_value):
|
|
|
|
|
|
|
|
lst = edge_lists[edge_idx][rel_idx]
|
|
|
|
et = edge_types[edge_idx]
|
|
|
|
ofs = offsets[edge_idx][rel_idx]
|
|
|
|
lst = lst[ofs:ofs+self.batch_size]
|
|
|
|
offsets[edge_idx][rel_idx] += self.batch_size
|
|
|
|
|
|
|
|
res = TrainingBatch(et.vertex_type_row, et.vertex_type_column,
|
|
|
|
rel_idx, lst, torch.full(len(lst), target_value,
|
|
|
|
dtype=torch.float32))
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
pos_edge_keys, pos_edge_types, pos_edge_lists, pos_offsets = \
|
|
|
|
self.get_edge_lists(self.pos_data)
|
|
|
|
|
|
|
|
neg_edge_keys, neg_edge_types, neg_edge_lists, neg_offsets = \
|
|
|
|
self.get_edge_lists(self.neg_data)
|
|
|
|
|
|
|
|
while True:
|
|
|
|
edge_idx, rel_idx = self.get_candidates(pos_edge_lists, pos_offsets)
|
|
|
|
|
|
|
|
if edge_idx is None:
|
|
|
|
return
|
|
|
|
|
|
|
|
pos_batch = self.take_edges(edge_idx, rel_idx, pos_edge_lists,
|
|
|
|
pos_offsets, pos_edge_types, 1)
|
|
|
|
|
|
|
|
neg_batch = self.take_edges(edge_idx, rel_idx, neg_edge_lists,
|
|
|
|
neg_offsets, neg_edge_types, 0)
|
|
|
|
|
|
|
|
yield (pos_batch, neg_batch)
|
|
|
|
|
|
|
|
|
|
|
|
class Batcher(object):
|
|
|
|