IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Add Batcher.

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
14e3d14b36
共有 6 个文件被更改,包括 2258 次插入2 次删除
  1. +2074
    -0
      docs/cumcount.svg
  2. +62
    -0
      src/triacontagon/batch.py
  3. +2
    -1
      src/triacontagon/data.py
  4. +1
    -1
      src/triacontagon/sampling.py
  5. +0
    -0
      src/triacontagon/split.py
  6. +119
    -0
      tests/triacontagon/test_batch.py

+ 2074
- 0
docs/cumcount.svg
文件差异内容过多而无法显示
查看文件


+ 62
- 0
src/triacontagon/batch.py 查看文件

@@ -0,0 +1,62 @@
from .data import Data
from .model import TrainingBatch
import torch
def _shuffle(x: torch.Tensor) -> torch.Tensor:
order = torch.randperm(len(x))
return x[order]
class Batcher(object):
def __init__(self, data: Data, batch_size: int=512,
shuffle: bool=True) -> None:
if not isinstance(data, Data):
raise TypeError('data must be an instance of Data')
self.data = data
self.batch_size = int(batch_size)
self.shuffle = bool(shuffle)
def __iter__(self) -> TrainingBatch:
edge_types = list(self.data.edge_types.values())
edge_lists = [ [ adj_mat.indices().transpose(0, 1) \
for adj_mat in et.adjacency_matrices ] \
for et in edge_types ]
if self.shuffle:
edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \
for edge_lst in edge_lists ]
offsets = [ [ 0 ] * len(et.adjacency_matrices) \
for et in edge_types ]
while True:
candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \
if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \
if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ]
if len(candidates) == 0:
break
edge_idx = torch.randint(0, len(candidates), (1,)).item()
edge_idx = candidates[edge_idx]
candidates = [ rel_idx \
for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \
if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]
rel_idx = torch.randint(0, len(candidates), (1,)).item()
rel_idx = candidates[rel_idx]
lst = edge_lists[edge_idx][rel_idx]
et = edge_types[edge_idx]
ofs = offsets[edge_idx][rel_idx]
lst = lst[ofs:ofs+self.batch_size]
offsets[edge_idx][rel_idx] += self.batch_size
b = TrainingBatch(et.vertex_type_row, et.vertex_type_column,
rel_idx, lst, torch.full((len(lst),), self.data.target_value,
dtype=torch.float32))
yield b

+ 2
- 1
src/triacontagon/data.py 查看文件

@@ -39,9 +39,10 @@ class Data(object):
vertex_types: List[VertexType]
edge_types: List[EdgeType]
def __init__(self) -> None:
def __init__(self, target_value: int = 1) -> None:
self.vertex_types = []
self.edge_types = {}
self.target_value = int(target_value)
def add_vertex_type(self, name: str, count: int) -> None:
name = str(name)


+ 1
- 1
src/triacontagon/sampling.py 查看文件

@@ -147,7 +147,7 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
def negative_sample_data(data: Data) -> Data:
new_edge_types = {}
res = Data()
res = Data(target_value=0)
for vt in data.vertex_types:
res.add_vertex_type(vt.name, vt.count)
for key, et in data.edge_types.items():


src/triacontagon/trainprep.py → src/triacontagon/split.py 查看文件


+ 119
- 0
tests/triacontagon/test_batch.py 查看文件

@@ -0,0 +1,119 @@
from triacontagon.batch import Batcher
from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder
import torch
def test_batcher_01():
d = Data()
d.add_vertex_type('Gene', 5)
d.add_edge_type('Gene-Gene', 0, 0, [
torch.tensor([
[0, 1, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0]
]).to_sparse()
], dedicom_decoder)
b = Batcher(d, batch_size=1)
visited = set()
for t in b:
print(t)
k = tuple(t.edges[0].tolist())
visited.add(k)
assert visited == { (0, 1), (0, 3),
(1, 4), (2, 0), (3, 2), (4, 3) }
def test_batcher_02():
d = Data()
d.add_vertex_type('Gene', 5)
d.add_edge_type('Gene-Gene', 0, 0, [
torch.tensor([
[0, 1, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0]
]).to_sparse(),
torch.tensor([
[1, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0]
]).to_sparse()
], dedicom_decoder)
b = Batcher(d, batch_size=1)
visited = set()
for t in b:
print(t)
k = (t.relation_type_index,) + \
tuple(t.edges[0].tolist())
visited.add(k)
assert visited == { (0, 0, 1), (0, 0, 3),
(0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3),
(1, 0, 0), (1, 0, 2), (1, 1, 3), (1, 2, 4),
(1, 3, 1), (1, 4, 2) }
def test_batcher_03():
d = Data()
d.add_vertex_type('Gene', 5)
d.add_vertex_type('Drug', 4)
d.add_edge_type('Gene-Gene', 0, 0, [
torch.tensor([
[0, 1, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0]
]).to_sparse(),
torch.tensor([
[1, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0]
]).to_sparse()
], dedicom_decoder)
d.add_edge_type('Gene-Drug', 0, 1, [
torch.tensor([
[0, 1, 0, 0],
[1, 0, 0, 1],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 1, 1, 0]
]).to_sparse()
], dedicom_decoder)
b = Batcher(d, batch_size=1)
visited = set()
for t in b:
print(t)
k = (t.vertex_type_row, t.vertex_type_column,
t.relation_type_index,) + \
tuple(t.edges[0].tolist())
visited.add(k)
assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
(0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
(0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
(0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
(0, 1, 0, 4, 2) }

正在加载...
取消
保存