| @@ -0,0 +1,62 @@ | |||
from typing import Iterator

import torch

from .data import Data
from .model import TrainingBatch
| def _shuffle(x: torch.Tensor) -> torch.Tensor: | |||
| order = torch.randperm(len(x)) | |||
| return x[order] | |||
class Batcher(object):
    """Iterate over the edges stored in a ``Data`` object in mini-batches.

    Each yielded ``TrainingBatch`` contains edges from exactly one adjacency
    matrix (relation type) of exactly one edge type. At every step an edge
    type that still has unserved edges is picked at random, then one of its
    relation types that still has unserved edges is picked at random, and
    the next ``batch_size`` edges of that relation are emitted.
    """

    def __init__(self, data: Data, batch_size: int=512,
        shuffle: bool=True) -> None:
        """Create a batcher over ``data``.

        Args:
            data: The ``Data`` instance whose edge types are iterated.
            batch_size: Maximum number of edges per batch; must be >= 1.
            shuffle: If true, the edges of every relation type are shuffled
                once at the start of each iteration.

        Raises:
            TypeError: If ``data`` is not an instance of ``Data``.
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        batch_size = int(batch_size)
        if batch_size < 1:
            # A non-positive step never advances the per-relation offsets
            # in __iter__, so iteration would never terminate.
            raise ValueError('batch_size must be a positive integer')

        self.data = data
        self.batch_size = batch_size
        self.shuffle = bool(shuffle)

    def __iter__(self) -> Iterator[TrainingBatch]:
        """Yield ``TrainingBatch`` objects until every edge has been served.

        Every batch carries a float32 target vector with one entry per edge,
        filled with ``self.data.target_value``.
        """
        edge_types = list(self.data.edge_types.values())

        # One edge-list tensor per (edge type, relation type); indices() is
        # transposed so that each row describes one edge.
        # NOTE(review): assumes the adjacency matrices are coalesced sparse
        # COO tensors, as indices() requires - confirm against callers.
        edge_lists = [
            [adj_mat.indices().transpose(0, 1)
                for adj_mat in et.adjacency_matrices]
            for et in edge_types
        ]

        if self.shuffle:
            edge_lists = [
                [_shuffle(lst) for lst in rel_lists]
                for rel_lists in edge_lists
            ]

        # offsets[edge_idx][rel_idx] is the position of the next unserved
        # edge within edge_lists[edge_idx][rel_idx].
        offsets = [[0] * len(rel_lists) for rel_lists in edge_lists]

        while True:
            # For every edge type, the relation indices that still have
            # edges left to serve.
            remaining = [
                [rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs)
                    if rel_ofs < len(edge_lists[edge_idx][rel_idx])]
                for edge_idx, edge_ofs in enumerate(offsets)
            ]

            candidates = [edge_idx
                for edge_idx, rel_candidates in enumerate(remaining)
                if rel_candidates]
            if not candidates:
                break

            # Pick an edge type first, then one of its relation types.
            pick = torch.randint(0, len(candidates), (1,)).item()
            edge_idx = candidates[pick]

            rel_candidates = remaining[edge_idx]
            pick = torch.randint(0, len(rel_candidates), (1,)).item()
            rel_idx = rel_candidates[pick]

            et = edge_types[edge_idx]
            ofs = offsets[edge_idx][rel_idx]
            # Slicing clamps at the end, so the last batch may be shorter.
            lst = edge_lists[edge_idx][rel_idx][ofs:ofs + self.batch_size]
            offsets[edge_idx][rel_idx] += self.batch_size

            target = torch.full((len(lst),), self.data.target_value,
                dtype=torch.float32)

            yield TrainingBatch(et.vertex_type_row, et.vertex_type_column,
                rel_idx, lst, target)
| @@ -39,9 +39,10 @@ class Data(object): | |||
| vertex_types: List[VertexType] | |||
| edge_types: List[EdgeType] | |||
| def __init__(self) -> None: | |||
| def __init__(self, target_value: int = 1) -> None: | |||
| self.vertex_types = [] | |||
| self.edge_types = {} | |||
| self.target_value = int(target_value) | |||
| def add_vertex_type(self, name: str, count: int) -> None: | |||
| name = str(name) | |||
| @@ -147,7 +147,7 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | |||
| def negative_sample_data(data: Data) -> Data: | |||
| new_edge_types = {} | |||
| res = Data() | |||
| res = Data(target_value=0) | |||
| for vt in data.vertex_types: | |||
| res.add_vertex_type(vt.name, vt.count) | |||
| for key, et in data.edge_types.items(): | |||
| @@ -0,0 +1,119 @@ | |||
| from triacontagon.batch import Batcher | |||
| from triacontagon.data import Data | |||
| from triacontagon.decode import dedicom_decoder | |||
| import torch | |||
def test_batcher_01():
    """With batch_size=1, one relation type: every edge is visited."""
    gene_gene = torch.tensor([
        [0, 1, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
    ]).to_sparse()

    data = Data()
    data.add_vertex_type('Gene', 5)
    data.add_edge_type('Gene-Gene', 0, 0, [gene_gene], dedicom_decoder)

    batcher = Batcher(data, batch_size=1)

    seen = set()
    for batch in batcher:
        print(batch)
        # Each batch holds a single edge, stored as a (row, column) pair.
        seen.add(tuple(batch.edges[0].tolist()))

    expected = {
        (0, 1), (0, 3),
        (1, 4),
        (2, 0),
        (3, 2),
        (4, 3),
    }
    assert seen == expected
def test_batcher_02():
    """With batch_size=1, two relation types: every edge of both is visited."""
    relation_0 = torch.tensor([
        [0, 1, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
    ]).to_sparse()
    relation_1 = torch.tensor([
        [1, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
    ]).to_sparse()

    data = Data()
    data.add_vertex_type('Gene', 5)
    data.add_edge_type('Gene-Gene', 0, 0, [relation_0, relation_1],
        dedicom_decoder)

    batcher = Batcher(data, batch_size=1)

    seen = set()
    for batch in batcher:
        print(batch)
        # Key: (relation type index, row, column) of the batch's single edge.
        seen.add((batch.relation_type_index, *batch.edges[0].tolist()))

    expected_relation_0 = {
        (0, 0, 1), (0, 0, 3), (0, 1, 4),
        (0, 2, 0), (0, 3, 2), (0, 4, 3),
    }
    expected_relation_1 = {
        (1, 0, 0), (1, 0, 2), (1, 1, 3),
        (1, 2, 4), (1, 3, 1), (1, 4, 2),
    }
    assert seen == expected_relation_0 | expected_relation_1
def test_batcher_03():
    """With batch_size=1, two edge types: every edge of each is visited."""
    gene_gene_0 = torch.tensor([
        [0, 1, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
    ]).to_sparse()
    gene_gene_1 = torch.tensor([
        [1, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
    ]).to_sparse()
    gene_drug_0 = torch.tensor([
        [0, 1, 0, 0],
        [1, 0, 0, 1],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 1, 0],
    ]).to_sparse()

    data = Data()
    data.add_vertex_type('Gene', 5)
    data.add_vertex_type('Drug', 4)
    data.add_edge_type('Gene-Gene', 0, 0, [gene_gene_0, gene_gene_1],
        dedicom_decoder)
    data.add_edge_type('Gene-Drug', 0, 1, [gene_drug_0], dedicom_decoder)

    batcher = Batcher(data, batch_size=1)

    seen = set()
    for batch in batcher:
        print(batch)
        # Key: (row vertex type, column vertex type, relation type index,
        # row, column) of the batch's single edge.
        prefix = (batch.vertex_type_row, batch.vertex_type_column,
            batch.relation_type_index)
        seen.add(prefix + tuple(batch.edges[0].tolist()))

    # Expected (row, column) pairs, grouped by the key prefix they belong to.
    expected_edges = {
        (0, 0, 0): [(0, 1), (0, 3), (1, 4), (2, 0), (3, 2), (4, 3)],
        (0, 0, 1): [(0, 0), (0, 2), (1, 3), (2, 4), (3, 1), (4, 2)],
        (0, 1, 0): [(0, 1), (1, 0), (1, 3), (2, 1), (3, 2), (4, 1),
            (4, 2)],
    }
    expected = set()
    for prefix, edges in expected_edges.items():
        for edge in edges:
            expected.add(prefix + edge)

    assert seen == expected