diff --git a/docs/cumcount.svg b/docs/cumcount.svg new file mode 100644 index 0000000..4da5428 --- /dev/null +++ b/docs/cumcount.svg @@ -0,0 +1,2074 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 + 1 + 2 + 3 + 3 + 4 + 1 + 1 + 5 + 5 + 2 + 3 + 1 + 4 + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + s=argsort: + i=unargsort: + + + + + + + + + + + + + + + 0 + 1 + 5 + 7 + 8 + 10 + 2 + 3 + 12 + 13 + 6 + 9 + 4 + 11 + b=a[s]: + + + + + + + + + + + + + + + 1 + 1 + 2 + 3 + 3 + 4 + 1 + 1 + 5 + 5 + 2 + 3 + 1 + 4 + dfill(b): + + + + + + + + + + + + + + + 1 + 1 + 1 + 1 + 5 + 7 + 10 + 12 + 0 + ...diff: + 5,2,3,2,2 + ...repeat: + 0,0,0,0,0,5,5,7,7,7,10,10,12,12 + ...where: + arange(n)-dfill(b): + 0,1,2,3,4,0,1,0,1,2,0,1,0,1 + (arange(n)-dfill(b))[i]: + + + + + + + + + + + + + + + 0 + 1 + 0 + 0 + 1 + 0 + 2 + 3 + 0 + 1 + 1 + 2 + 4 + 1 + + + + + + + + + + + + + + + 1 + 1 + 2 + 3 + 3 + 4 + 1 + 1 + 5 + 5 + 2 + 3 + 1 + 4 + in: + in: + + diff --git a/src/triacontagon/batch.py b/src/triacontagon/batch.py new file mode 100644 index 0000000..cfb367e --- /dev/null +++ b/src/triacontagon/batch.py @@ -0,0 +1,62 @@ +from .data import Data +from .model import TrainingBatch +import torch + + +def _shuffle(x: torch.Tensor) -> torch.Tensor: + order = torch.randperm(len(x)) + return x[order] + + +class Batcher(object): + def __init__(self, data: Data, batch_size: int=512, + shuffle: bool=True) -> None: + + if not isinstance(data, Data): + raise TypeError('data must be an instance of Data') + + self.data = data + self.batch_size = int(batch_size) + self.shuffle = bool(shuffle) + + def __iter__(self) -> TrainingBatch: + edge_types = list(self.data.edge_types.values()) + + edge_lists = [ [ adj_mat.indices().transpose(0, 1) \ + for adj_mat in et.adjacency_matrices ] \ + for et in edge_types ] + + if self.shuffle: + edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \ + for edge_lst in edge_lists ] + + offsets = [ [ 0 ] * len(et.adjacency_matrices) \ + for et in edge_types ] + + while True: + candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \ + if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \ + if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ] + if len(candidates) == 0: + break + + edge_idx = torch.randint(0, len(candidates), (1,)).item() + edge_idx = candidates[edge_idx] + candidates = [ rel_idx \ + for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \ + if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ] + + rel_idx = torch.randint(0, len(candidates), (1,)).item() + rel_idx = candidates[rel_idx] + + lst = edge_lists[edge_idx][rel_idx] + et = edge_types[edge_idx] + ofs = offsets[edge_idx][rel_idx] + lst = lst[ofs:ofs+self.batch_size] + offsets[edge_idx][rel_idx] += self.batch_size + + b = TrainingBatch(et.vertex_type_row, et.vertex_type_column, + rel_idx, lst, torch.full((len(lst),), self.data.target_value, + dtype=torch.float32)) + + yield b diff --git a/src/triacontagon/data.py b/src/triacontagon/data.py index ba2b7f8..3c13a50 100644 --- a/src/triacontagon/data.py +++ b/src/triacontagon/data.py @@ -39,9 +39,10 @@ class Data(object): vertex_types: List[VertexType] edge_types: List[EdgeType] - def __init__(self) -> None: + def __init__(self, target_value: int = 1) -> None: self.vertex_types = [] self.edge_types = {} + self.target_value = int(target_value) def add_vertex_type(self, name: str, count: int) -> None: name = str(name) diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index 58b7ba0..c85ee1d 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -147,7 +147,7 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: def negative_sample_data(data: Data) -> Data: new_edge_types = {} - res = Data() + res = Data(target_value=0) for vt in data.vertex_types: res.add_vertex_type(vt.name, vt.count) for key, et in data.edge_types.items(): diff --git a/src/triacontagon/trainprep.py b/src/triacontagon/split.py similarity index 100% rename from src/triacontagon/trainprep.py rename to src/triacontagon/split.py diff --git a/tests/triacontagon/test_batch.py b/tests/triacontagon/test_batch.py new file mode 100644 index 0000000..3717832 --- /dev/null +++ b/tests/triacontagon/test_batch.py @@ -0,0 +1,119 @@ +from triacontagon.batch import Batcher +from triacontagon.data import Data +from triacontagon.decode import dedicom_decoder +import torch + + +def test_batcher_01(): + d = Data() + d.add_vertex_type('Gene', 5) + + d.add_edge_type('Gene-Gene', 0, 0, [ + torch.tensor([ + [0, 1, 0, 1, 0], + [0, 0, 0, 0, 1], + [1, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0] + ]).to_sparse() + ], dedicom_decoder) + + b = Batcher(d, batch_size=1) + + visited = set() + for t in b: + print(t) + k = tuple(t.edges[0].tolist()) + visited.add(k) + + assert visited == { (0, 1), (0, 3), + (1, 4), (2, 0), (3, 2), (4, 3) } + + +def test_batcher_02(): + d = Data() + d.add_vertex_type('Gene', 5) + + d.add_edge_type('Gene-Gene', 0, 0, [ + torch.tensor([ + [0, 1, 0, 1, 0], + [0, 0, 0, 0, 1], + [1, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0] + ]).to_sparse(), + + torch.tensor([ + [1, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0] + ]).to_sparse() + ], dedicom_decoder) + + b = Batcher(d, batch_size=1) + + visited = set() + for t in b: + print(t) + k = (t.relation_type_index,) + \ + tuple(t.edges[0].tolist()) + visited.add(k) + + assert visited == { (0, 0, 1), (0, 0, 3), + (0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3), + (1, 0, 0), (1, 0, 2), (1, 1, 3), (1, 2, 4), + (1, 3, 1), (1, 4, 2) } + + +def test_batcher_03(): + d = Data() + d.add_vertex_type('Gene', 5) + d.add_vertex_type('Drug', 4) + + d.add_edge_type('Gene-Gene', 0, 0, [ + torch.tensor([ + [0, 1, 0, 1, 0], + [0, 0, 0, 0, 1], + [1, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0] + ]).to_sparse(), + + torch.tensor([ + [1, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0] + ]).to_sparse() + ], dedicom_decoder) + + d.add_edge_type('Gene-Drug', 0, 1, [ + torch.tensor([ + [0, 1, 0, 0], + [1, 0, 0, 1], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 1, 1, 0] + ]).to_sparse() + ], dedicom_decoder) + + b = Batcher(d, batch_size=1) + + visited = set() + for t in b: + print(t) + k = (t.vertex_type_row, t.vertex_type_column, + t.relation_type_index,) + \ + tuple(t.edges[0].tolist()) + visited.add(k) + + assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), + (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), + (0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), + (0, 0, 1, 3, 1), (0, 0, 1, 4, 2), + (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), + (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), + (0, 1, 0, 4, 2) }