diff --git a/src/triacontagon/batch.py b/src/triacontagon/batch.py index ec66f5f..9688d03 100644 --- a/src/triacontagon/batch.py +++ b/src/triacontagon/batch.py @@ -38,9 +38,9 @@ def _same_data_org(pos_data: Data, neg_data: Data): test = [ [ len(pos_data.edge_types[i].adjacency_matrices[k].values()) == \ len(neg_data.edge_types[i].adjacency_matrices[k].values()) \ - for k in range(len(pos_data.edge_types[i])) ] \ + for k in range(len(pos_data.edge_types[i].adjacency_matrices)) ] \ for i in pos_data.edge_types.keys() ] - test = reduce(list.__add__, test) + test = reduce(list.__add__, test, []) if not all(test): return False @@ -112,7 +112,7 @@ class DualBatcher(object): offsets[edge_idx][rel_idx] += self.batch_size res = TrainingBatch(et.vertex_type_row, et.vertex_type_column, - rel_idx, lst, torch.full(len(lst), target_value, + rel_idx, lst, torch.full(( len(lst), ), target_value, dtype=torch.float32)) return res diff --git a/tests/triacontagon/test_batch.py b/tests/triacontagon/test_batch.py index f145380..46ce6a3 100644 --- a/tests/triacontagon/test_batch.py +++ b/tests/triacontagon/test_batch.py @@ -1,9 +1,59 @@ -from triacontagon.batch import Batcher +from triacontagon.batch import _same_data_org, \ + DualBatcher, \ + Batcher from triacontagon.data import Data from triacontagon.decode import dedicom_decoder import torch +def test_same_data_org_01(): + data = Data() + assert _same_data_org(data, data) + + data.add_vertex_type('Foo', 10) + assert _same_data_org(data, data) + + data.add_vertex_type('Bar', 10) + assert _same_data_org(data, data) + + data_1 = Data() + assert not _same_data_org(data, data_1) + + data_1.add_vertex_type('Foo', 10) + assert not _same_data_org(data, data_1) + + data_1.add_vertex_type('Bar', 10) + assert _same_data_org(data, data_1) + + +def test_same_data_org_02(): + data = Data() + data.add_vertex_type('Foo', 4) + data.add_edge_type('Foo-Foo', 0, 0, [ + torch.tensor([ + [0, 0, 0, 1], + [1, 0, 0, 0], + [0, 1, 1, 0], + [1, 0, 1, 0] + ]).to_sparse() + ], dedicom_decoder) + + assert _same_data_org(data, data) + + data_1 = Data() + data_1.add_vertex_type('Foo', 4) + data_1.add_edge_type('Foo-Foo', 0, 0, [ + torch.tensor([ + [0, 0, 0, 1], + [1, 0, 0, 0], + [0, 1, 1, 0], + [1, 0, 0, 0] + ]).to_sparse() + ], dedicom_decoder) + + assert not _same_data_org(data, data_1) + + def test_batcher_01(): d = Data() d.add_vertex_type('Gene', 5) @@ -197,3 +247,70 @@ def test_batcher_05(): (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), (0, 1, 0, 4, 2) } + + +def test_dual_batcher_01(): + d = Data() + d.add_vertex_type('Gene', 5) + d.add_vertex_type('Drug', 4) + + d.add_edge_type('Gene-Gene', 0, 0, [ + torch.tensor([ + [0, 1, 0, 1, 0], + [0, 0, 0, 0, 1], + [1, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0] + ]).to_sparse(), + + torch.tensor([ + [1, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0] + ]).to_sparse() + ], dedicom_decoder) + + d.add_edge_type('Gene-Drug', 0, 1, [ + torch.tensor([ + [0, 1, 0, 0], + [1, 0, 0, 1], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 1, 1, 0] + ]).to_sparse() + ], dedicom_decoder) + + b = DualBatcher(d, d, batch_size=5) + + visited_pos = set() + visited_neg = set() + for t_pos, t_neg in b: + assert t_pos.vertex_type_row == t_neg.vertex_type_row + assert t_pos.vertex_type_column == t_neg.vertex_type_column + assert t_pos.relation_type_index == t_neg.relation_type_index + assert len(t_pos.edges) == len(t_neg.edges) + + for e in t_pos.edges: + k = (t_pos.vertex_type_row, t_pos.vertex_type_column, + t_pos.relation_type_index,) + \ + tuple(e.tolist()) + visited_pos.add(k) + + for e in t_neg.edges: + k = (t_neg.vertex_type_row, t_neg.vertex_type_column, + t_neg.relation_type_index,) + \ + tuple(e.tolist()) + visited_neg.add(k) + + expected = { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3), + (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3), + (0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4), + (0, 0, 1, 3, 1), (0, 0, 1, 4, 2), + (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3), + (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1), + (0, 1, 0, 4, 2) } + + assert visited_pos == expected + assert visited_neg == expected