"""Tests for ``icosagon.databatch.DataBatcher`` on a small random graph."""

from icosagon.databatch import (
    DataBatcher,
    BatchedData,
)
from icosagon.data import Data
from icosagon.trainprep import (
    prepare_training,
    TrainValTest,
)
import torch


# Attribute names read from every relation type of a batch, and the
# partition attributes read from each of those edge containers.
_EDGE_TYPES = ('edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg')
_PART_TYPES = ('train', 'val', 'test')

# Batch size handed to DataBatcher by every test in this module.
_BATCH_SIZE = 512


def _some_data():
    """Build a ``Data`` object with one random 'Foo-Bar' relation.

    Two node types are registered ('Foo' with 100 nodes, 'Bar' with 500)
    together with a single relation family linking node type 0 to node
    type 1. Its only relation type uses a random 100x500 adjacency matrix
    whose entries are rounded to 0/1 and stored in sparse form.
    """
    data = Data()
    data.add_node_type('Foo', 100)
    data.add_node_type('Bar', 500)
    # NOTE(review): the meaning of the trailing ``True`` is defined by
    # Data.add_relation_family, which is not visible here.
    fam = data.add_relation_family('Foo-Bar', 0, 1, True)
    adj_mat = torch.rand(100, 500).round().to_sparse()
    fam.add_relation_type('Foo-Bar', adj_mat)
    return data


def _make_batcher():
    """Return ``(data, batcher)`` for a fresh random data set.

    The data is split with ``TrainValTest(.8, .1, .1)`` and wrapped in a
    ``DataBatcher`` using ``_BATCH_SIZE``.
    """
    data = _some_data()
    prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
    return data, DataBatcher(prep_d, _BATCH_SIZE)


def _collect_edges(batch_d):
    """Return every edge container found in ``batch_d`` as a flat list.

    For each relation type of each relation family, one entry is produced
    per (edge type, partition) pair, in the order given by ``_EDGE_TYPES``
    and ``_PART_TYPES``.
    """
    collected = []
    for fam in batch_d.relation_families:
        for rel in fam.relation_types:
            for edge_type in _EDGE_TYPES:
                edge_group = getattr(rel, edge_type)
                for part_type in _PART_TYPES:
                    collected.append(getattr(edge_group, part_type))
    return collected


def test_data_batcher_01():
    """Constructing a DataBatcher must not raise."""
    _make_batcher()


def test_data_batcher_02():
    """Iterating over all batches must not raise."""
    _, batcher = _make_batcher()
    for batch_d in batcher:
        pass


def test_data_batcher_03():
    """Each batch has exactly one non-empty edge container."""
    _, batcher = _make_batcher()
    for batch_d in batcher:
        edges_list = _collect_edges(batch_d)
        non_empty = sum(1 for edges in edges_list if len(edges) > 0)
        assert non_empty == 1


def test_data_batcher_04():
    """Total batched edges equal twice the sum of the adjacency values."""
    data, batcher = _make_batcher()
    edges_list = []
    for batch_d in batcher:
        edges_list.extend(_collect_edges(batch_d))

    adj_mat = data.relation_families[0].relation_types[0].adjacency_matrix
    total_edges = sum(len(edges) for edges in edges_list)
    assert total_edges == torch.sum(adj_mat._values()) * 2


def test_data_batcher_05():
    """No edge container exceeds the batch size and no batch is empty."""
    _, batcher = _make_batcher()
    for batch_d in batcher:
        edges_list = _collect_edges(batch_d)
        assert all(len(edges) <= 512 for edges in edges_list)
        assert any(len(edges) > 0 for edges in edges_list)
        print(sum(map(len, edges_list)))