@@ -0,0 +1,62 @@ | |||
from .data import Data | |||
from .model import TrainingBatch | |||
import torch | |||
def _shuffle(x: torch.Tensor) -> torch.Tensor: | |||
order = torch.randperm(len(x)) | |||
return x[order] | |||
class Batcher(object): | |||
def __init__(self, data: Data, batch_size: int=512, | |||
shuffle: bool=True) -> None: | |||
if not isinstance(data, Data): | |||
raise TypeError('data must be an instance of Data') | |||
self.data = data | |||
self.batch_size = int(batch_size) | |||
self.shuffle = bool(shuffle) | |||
def __iter__(self) -> TrainingBatch: | |||
edge_types = list(self.data.edge_types.values()) | |||
edge_lists = [ [ adj_mat.indices().transpose(0, 1) \ | |||
for adj_mat in et.adjacency_matrices ] \ | |||
for et in edge_types ] | |||
if self.shuffle: | |||
edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \ | |||
for edge_lst in edge_lists ] | |||
offsets = [ [ 0 ] * len(et.adjacency_matrices) \ | |||
for et in edge_types ] | |||
while True: | |||
candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \ | |||
if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \ | |||
if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ] | |||
if len(candidates) == 0: | |||
break | |||
edge_idx = torch.randint(0, len(candidates), (1,)).item() | |||
edge_idx = candidates[edge_idx] | |||
candidates = [ rel_idx \ | |||
for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \ | |||
if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ] | |||
rel_idx = torch.randint(0, len(candidates), (1,)).item() | |||
rel_idx = candidates[rel_idx] | |||
lst = edge_lists[edge_idx][rel_idx] | |||
et = edge_types[edge_idx] | |||
ofs = offsets[edge_idx][rel_idx] | |||
lst = lst[ofs:ofs+self.batch_size] | |||
offsets[edge_idx][rel_idx] += self.batch_size | |||
b = TrainingBatch(et.vertex_type_row, et.vertex_type_column, | |||
rel_idx, lst, torch.full((len(lst),), self.data.target_value, | |||
dtype=torch.float32)) | |||
yield b |
@@ -39,9 +39,10 @@ class Data(object): | |||
vertex_types: List[VertexType] | |||
edge_types: List[EdgeType] | |||
def __init__(self) -> None: | |||
def __init__(self, target_value: int = 1) -> None: | |||
self.vertex_types = [] | |||
self.edge_types = {} | |||
self.target_value = int(target_value) | |||
def add_vertex_type(self, name: str, count: int) -> None: | |||
name = str(name) | |||
@@ -147,7 +147,7 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | |||
def negative_sample_data(data: Data) -> Data: | |||
new_edge_types = {} | |||
res = Data() | |||
res = Data(target_value=0) | |||
for vt in data.vertex_types: | |||
res.add_vertex_type(vt.name, vt.count) | |||
for key, et in data.edge_types.items(): | |||
@@ -0,0 +1,119 @@ | |||
from triacontagon.batch import Batcher | |||
from triacontagon.data import Data | |||
from triacontagon.decode import dedicom_decoder | |||
import torch | |||
def test_batcher_01(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 5) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ | |||
torch.tensor([ | |||
[0, 1, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[1, 0, 0, 0, 0], | |||
[0, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0] | |||
]).to_sparse() | |||
], dedicom_decoder) | |||
b = Batcher(d, batch_size=1) | |||
visited = set() | |||
for t in b: | |||
print(t) | |||
k = tuple(t.edges[0].tolist()) | |||
visited.add(k) | |||
assert visited == { (0, 1), (0, 3), | |||
(1, 4), (2, 0), (3, 2), (4, 3) } | |||
def test_batcher_02(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 5) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ | |||
torch.tensor([ | |||
[0, 1, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[1, 0, 0, 0, 0], | |||
[0, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0] | |||
]).to_sparse(), | |||
torch.tensor([ | |||
[1, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[0, 1, 0, 0, 0], | |||
[0, 0, 1, 0, 0] | |||
]).to_sparse() | |||
], dedicom_decoder) | |||
b = Batcher(d, batch_size=1) | |||
visited = set() | |||
for t in b: | |||
print(t) | |||
k = (t.relation_type_index,) + \ | |||
tuple(t.edges[0].tolist()) | |||
visited.add(k) | |||
assert visited == { (0, 0, 1), (0, 0, 3), | |||
(0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3), | |||
(1, 0, 0), (1, 0, 2), (1, 1, 3), (1, 2, 4), | |||
(1, 3, 1), (1, 4, 2) } | |||
def test_batcher_03(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 5) | |||
d.add_vertex_type('Drug', 4) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ | |||
torch.tensor([ | |||
[0, 1, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[1, 0, 0, 0, 0], | |||
[0, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0] | |||
]).to_sparse(), | |||
torch.tensor([ | |||
[1, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1], | |||
[0, 1, 0, 0, 0], | |||
[0, 0, 1, 0, 0] | |||
]).to_sparse() | |||
], dedicom_decoder) | |||
d.add_edge_type('Gene-Drug', 0, 1, [ | |||
torch.tensor([ | |||
[0, 1, 0, 0], | |||
[1, 0, 0, 1], | |||
[0, 1, 0, 0], | |||
[0, 0, 1, 0], | |||
[0, 1, 1, 0] | |||
]).to_sparse() | |||
], dedicom_decoder) | |||
b = Batcher(d, batch_size=1) | |||
visited = set() | |||
for t in b: | |||
print(t) | |||
k = (t.vertex_type_row, t.vertex_type_column, | |||
t.relation_type_index,) + \ | |||
tuple(t.edges[0].tolist()) | |||
visited.add(k) | |||
assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), | |||
(0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), | |||
(0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), | |||
(0, 0, 1, 3, 1), (0, 0, 1, 4, 2), | |||
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), | |||
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), | |||
(0, 1, 0, 4, 2) } |