IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

Finish DualBatcher.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
6b15ec8c10
共有 1 個文件被更改,包括 71 次插入1 次删除
  1. +71
    -1
      src/triacontagon/batch.py

+ 71
- 1
src/triacontagon/batch.py 查看文件

@@ -21,7 +21,7 @@ def _same_data_org(pos_data: Data, neg_data: Data):
if not set(pos_data.edge_types.keys()) == \
set(neg_data.edge_types.keys()):
return False
test = [ pos_data.edge_types[i].name == \
@@ -65,8 +65,78 @@ class DualBatcher(object):
self.batch_size = int(batch_size)
self.shuffle = bool(shuffle)
def get_edge_lists(self, data: Data):
edge_types = list(data.edge_types.items())
edge_keys = [ a[0] for a in edge_types ]
edge_types = [ a[1] for a in edge_types ]
edge_lists = [ [ adj_mat.indices().transpose(0, 1) \
for adj_mat in et.adjacency_matrices ] \
for et in edge_types ]
if self.shuffle:
edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \
for edge_lst in edge_lists ]
offsets = [ [ 0 ] * len(et.adjacency_matrices) \
for et in edge_types ]
return (edge_keys, edge_types, edge_lists, offsets)
def get_candidates(self, edge_lists, offsets):
candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \
if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \
if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ]
if len(candidates) == 0:
return None, None
edge_idx = torch.randint(0, len(candidates), (1,)).item()
edge_idx = candidates[edge_idx]
candidates = [ rel_idx \
for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \
if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]
rel_idx = torch.randint(0, len(candidates), (1,)).item()
rel_idx = candidates[rel_idx]
return edge_idx, rel_idx
def take_edges(self, edge_idx, rel_idx, edge_lists, offsets,
edge_types, target_value):
lst = edge_lists[edge_idx][rel_idx]
et = edge_types[edge_idx]
ofs = offsets[edge_idx][rel_idx]
lst = lst[ofs:ofs+self.batch_size]
offsets[edge_idx][rel_idx] += self.batch_size
res = TrainingBatch(et.vertex_type_row, et.vertex_type_column,
rel_idx, lst, torch.full(len(lst), target_value,
dtype=torch.float32))
return res
def __iter__(self):
pos_edge_keys, pos_edge_types, pos_edge_lists, pos_offsets = \
self.get_edge_lists(self.pos_data)
neg_edge_keys, neg_edge_types, neg_edge_lists, neg_offsets = \
self.get_edge_lists(self.neg_data)
while True:
edge_idx, rel_idx = self.get_candidates(pos_edge_lists, pos_offsets)
if edge_idx is None:
return
pos_batch = self.take_edges(edge_idx, rel_idx, pos_edge_lists,
pos_offsets, pos_edge_types, 1)
neg_batch = self.take_edges(edge_idx, rel_idx, neg_edge_lists,
neg_offsets, neg_edge_types, 0)
yield (pos_batch, neg_batch)
class Batcher(object):


Loading…
取消
儲存