@@ -0,0 +1,62 @@ | |||||
from typing import Iterator

import torch

from .data import Data
from .model import TrainingBatch
def _shuffle(x: torch.Tensor) -> torch.Tensor:
    """Return the entries of *x* in random order along its first dimension.

    A random permutation of ``range(len(x))`` is drawn with
    ``torch.randperm`` and used to index ``x``. The result is a new tensor;
    ``x`` itself is left untouched.
    """
    permutation = torch.randperm(len(x))
    shuffled = x[permutation]
    return shuffled
class Batcher(object):
    """Iterate over the edges stored in a ``Data`` object in mini-batches.

    Every batch contains edges of a single relation (one adjacency matrix)
    of a single edge type. At each step an edge type that still has
    unvisited edges is drawn uniformly at random, then one of its relations
    that still has unvisited edges is drawn uniformly at random, and the
    next ``batch_size`` edges of that relation are yielded. Iteration stops
    once every edge of every relation has been yielded.
    """

    def __init__(self, data: Data, batch_size: int = 512,
                 shuffle: bool = True) -> None:
        """Create a batcher over *data*.

        Args:
            data: The ``Data`` instance whose edge types are batched.
            batch_size: Maximum number of edges per batch; must be >= 1.
            shuffle: If true, the edges of every relation are randomly
                permuted at the start of each iteration.

        Raises:
            TypeError: If *data* is not a ``Data`` instance.
            ValueError: If *batch_size* is smaller than 1.
        """
        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        batch_size = int(batch_size)
        # A non-positive batch size would never advance the per-relation
        # offsets, so iteration would spin forever yielding empty batches.
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')

        self.data = data
        self.batch_size = batch_size
        self.shuffle = bool(shuffle)

    def __iter__(self) -> Iterator[TrainingBatch]:
        """Yield ``TrainingBatch`` objects until all edges are consumed.

        Each batch is built from the edge type's ``vertex_type_row`` and
        ``vertex_type_column``, the index of the relation inside that edge
        type, the slice of edges, and a float32 vector of the same length
        filled with ``self.data.target_value``.
        """
        edge_types = list(self.data.edge_types.values())

        # indices() is transposed so that each row describes one stored
        # entry of the adjacency matrix.
        # NOTE(review): assumes the adjacency matrices are coalesced sparse
        # tensors, since .indices() is called on them -- confirm in Data.
        edge_lists = [
            [adj_mat.indices().transpose(0, 1)
             for adj_mat in et.adjacency_matrices]
            for et in edge_types
        ]

        if self.shuffle:
            edge_lists = [
                [_shuffle(lst) for lst in edge_lst]
                for edge_lst in edge_lists
            ]

        # offsets[i][j] is the position of the first edge of relation j of
        # edge type i that has not been yielded yet.
        offsets = [[0] * len(et.adjacency_matrices) for et in edge_types]

        while True:
            # For every edge type, the relations that still have edges left.
            remaining = [
                [rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs)
                 if rel_ofs < len(edge_lists[edge_idx][rel_idx])]
                for edge_idx, edge_ofs in enumerate(offsets)
            ]

            candidates = [edge_idx for edge_idx, rels in enumerate(remaining)
                          if rels]
            if not candidates:
                break

            # Pick an edge type, then one of its unfinished relations.
            choice = torch.randint(0, len(candidates), (1,)).item()
            edge_idx = candidates[choice]

            rel_candidates = remaining[edge_idx]
            choice = torch.randint(0, len(rel_candidates), (1,)).item()
            rel_idx = rel_candidates[choice]

            et = edge_types[edge_idx]
            ofs = offsets[edge_idx][rel_idx]
            lst = edge_lists[edge_idx][rel_idx][ofs:ofs + self.batch_size]
            offsets[edge_idx][rel_idx] += self.batch_size

            targets = torch.full((len(lst),), self.data.target_value,
                                 dtype=torch.float32)

            yield TrainingBatch(et.vertex_type_row, et.vertex_type_column,
                                rel_idx, lst, targets)
@@ -39,9 +39,10 @@ class Data(object): | |||||
vertex_types: List[VertexType] | vertex_types: List[VertexType] | ||||
edge_types: List[EdgeType] | edge_types: List[EdgeType] | ||||
def __init__(self) -> None: | |||||
def __init__(self, target_value: int = 1) -> None: | |||||
self.vertex_types = [] | self.vertex_types = [] | ||||
self.edge_types = {} | self.edge_types = {} | ||||
self.target_value = int(target_value) | |||||
def add_vertex_type(self, name: str, count: int) -> None: | def add_vertex_type(self, name: str, count: int) -> None: | ||||
name = str(name) | name = str(name) | ||||
@@ -147,7 +147,7 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
def negative_sample_data(data: Data) -> Data: | def negative_sample_data(data: Data) -> Data: | ||||
new_edge_types = {} | new_edge_types = {} | ||||
res = Data() | |||||
res = Data(target_value=0) | |||||
for vt in data.vertex_types: | for vt in data.vertex_types: | ||||
res.add_vertex_type(vt.name, vt.count) | res.add_vertex_type(vt.name, vt.count) | ||||
for key, et in data.edge_types.items(): | for key, et in data.edge_types.items(): | ||||
@@ -0,0 +1,119 @@ | |||||
from triacontagon.batch import Batcher | |||||
from triacontagon.data import Data | |||||
from triacontagon.decode import dedicom_decoder | |||||
import torch | |||||
def test_batcher_01():
    """One edge type, one relation: every edge is yielded exactly once."""
    d = Data()
    d.add_vertex_type('Gene', 5)
    d.add_edge_type('Gene-Gene', 0, 0, [
        torch.tensor([
            [0, 1, 0, 1, 0],
            [0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 1, 0]
        ]).to_sparse()
    ], dedicom_decoder)

    b = Batcher(d, batch_size=1)

    # Collected in a list (not a set) so that an edge yielded twice is
    # detected by the length check below.
    visited = []
    for t in b:
        # batch_size=1, so every batch must carry exactly one edge.
        assert len(t.edges) == 1
        visited.append(tuple(t.edges[0].tolist()))

    expected = {(0, 1), (0, 3),
                (1, 4), (2, 0), (3, 2), (4, 3)}

    # Batches come out in random order, so compare order-independently.
    assert len(visited) == len(expected)
    assert set(visited) == expected
def test_batcher_02():
    """One edge type, two relations: every edge is yielded exactly once."""
    d = Data()
    d.add_vertex_type('Gene', 5)
    d.add_edge_type('Gene-Gene', 0, 0, [
        torch.tensor([
            [0, 1, 0, 1, 0],
            [0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 1, 0]
        ]).to_sparse(),
        torch.tensor([
            [1, 0, 1, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 0, 0, 0, 1],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0]
        ]).to_sparse()
    ], dedicom_decoder)

    b = Batcher(d, batch_size=1)

    # Collected in a list (not a set) so that an edge yielded twice is
    # detected by the length check below.
    visited = []
    for t in b:
        # batch_size=1, so every batch must carry exactly one edge.
        assert len(t.edges) == 1
        # Key: (relation index, row, column).
        visited.append((t.relation_type_index,) +
                       tuple(t.edges[0].tolist()))

    expected = {(0, 0, 1), (0, 0, 3),
                (0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3),
                (1, 0, 0), (1, 0, 2), (1, 1, 3), (1, 2, 4),
                (1, 3, 1), (1, 4, 2)}

    # Batches come out in random order, so compare order-independently.
    assert len(visited) == len(expected)
    assert set(visited) == expected
def test_batcher_03():
    """Two edge types, three relations: every edge is yielded exactly once."""
    d = Data()
    d.add_vertex_type('Gene', 5)
    d.add_vertex_type('Drug', 4)

    d.add_edge_type('Gene-Gene', 0, 0, [
        torch.tensor([
            [0, 1, 0, 1, 0],
            [0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 1, 0]
        ]).to_sparse(),
        torch.tensor([
            [1, 0, 1, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 0, 0, 0, 1],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0]
        ]).to_sparse()
    ], dedicom_decoder)

    d.add_edge_type('Gene-Drug', 0, 1, [
        torch.tensor([
            [0, 1, 0, 0],
            [1, 0, 0, 1],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 1, 0]
        ]).to_sparse()
    ], dedicom_decoder)

    b = Batcher(d, batch_size=1)

    # Collected in a list (not a set) so that an edge yielded twice is
    # detected by the length check below.
    visited = []
    for t in b:
        # batch_size=1, so every batch must carry exactly one edge.
        assert len(t.edges) == 1
        # Key: (row vertex type, column vertex type, relation index,
        # row, column).
        visited.append((t.vertex_type_row, t.vertex_type_column,
                        t.relation_type_index) +
                       tuple(t.edges[0].tolist()))

    expected = {(0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
                (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2),
                (0, 0, 0, 4, 3),
                (0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3),
                (0, 0, 1, 2, 4), (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
                (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
                (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
                (0, 1, 0, 4, 2)}

    # Batches come out in random order, so compare order-independently.
    assert len(visited) == len(expected)
    assert set(visited) == expected