|
- from icosagon.databatch import DataBatcher, \
- BatchedData
- from icosagon.data import Data
- from icosagon.trainprep import prepare_training, \
- TrainValTest
- import torch
-
-
def _some_data():
    """Return a small ``Data`` fixture with a single random 'Foo-Bar' relation.

    Registers two node types, 'Foo' with 100 nodes and 'Bar' with 500 nodes,
    then one relation family and one relation type, both named 'Foo-Bar'.
    The relation's adjacency matrix is a 100x500 matrix of rounded
    ``torch.rand`` values (so every entry is 0 or 1, about half of them 1),
    converted to a sparse tensor. The edge set is random and differs between
    calls.
    """
    fixture = Data()
    for node_type_name, node_count in (('Foo', 100), ('Bar', 500)):
        fixture.add_node_type(node_type_name, node_count)
    # NOTE(review): 0 and 1 presumably index 'Foo' and 'Bar' in registration
    # order; the meaning of the trailing True flag is defined by icosagon's
    # add_relation_family -- confirm there.
    family = fixture.add_relation_family('Foo-Bar', 0, 1, True)
    random_adjacency = torch.rand(100, 500).round()
    family.add_relation_type('Foo-Bar', random_adjacency.to_sparse())
    return fixture
-
-
def test_data_batcher_01():
    """Smoke test: a DataBatcher can be built from prepared training data.

    NOTE(review): TrainValTest(.8, .1, .1) presumably requests an 80/10/10
    train/val/test split -- confirm against icosagon.trainprep.
    """
    prepared = prepare_training(_some_data(), TrainValTest(.8, .1, .1))
    # Only construction is exercised; the resulting batcher is not inspected.
    DataBatcher(prepared, 512)
-
-
def test_data_batcher_02():
    """Smoke test: iterating a DataBatcher to exhaustion raises no error."""
    prepared = prepare_training(_some_data(), TrainValTest(.8, .1, .1))
    for _ in DataBatcher(prepared, 512):
        pass  # only the iteration itself is under test; batches are discarded
-
-
def test_data_batcher_03():
    """Each batch carries edges from exactly one (relation, kind, part) slot.

    For every batch yielded by the batcher, the edge tensors of every
    relation type are gathered across the four edge kinds
    ('edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg') and the
    three parts ('train', 'val', 'test'). Exactly one of those tensors may be
    non-empty. The batcher must also yield at least one batch, so the
    per-batch check cannot pass vacuously.
    """
    data = _some_data()
    prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
    batcher = DataBatcher(prep_d, 512)
    edge_types = ('edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg')
    part_types = ('train', 'val', 'test')
    n_batches = 0
    for batch_d in batcher:
        n_batches += 1
        edges_list = []
        for fam in batch_d.relation_families:
            for rel in fam.relation_types:
                for edge_type in edge_types:
                    # Object exposing one attribute per part (train/val/test).
                    edge_group = getattr(rel, edge_type)
                    for part_type in part_types:
                        edges_list.append(getattr(edge_group, part_type))
        n_non_empty = sum(1 for edges in edges_list if len(edges) > 0)
        assert n_non_empty == 1
    # Without this, an empty batcher would make the loop above a no-op and
    # the test would pass without checking anything.
    assert n_batches > 0
-
-
def test_data_batcher_04():
    """All batches together hold twice as many edges as the adjacency matrix.

    Counts the rows of every edge tensor (all four edge kinds, all
    train/val/test parts) across every batch. The total is compared with
    twice the sum of the relation's sparse adjacency values. The fixture's
    adjacency holds only 0/1 entries, so that sum is the number of stored
    edges.

    NOTE(review): the factor of 2 presumably reflects positives plus an equal
    number of sampled negatives -- confirm against icosagon.trainprep.
    """
    data = _some_data()
    batcher = DataBatcher(prepare_training(data, TrainValTest(.8, .1, .1)), 512)
    edge_kinds = ('edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg')
    parts = ('train', 'val', 'test')

    total_edges = 0
    for batch in batcher:
        for family in batch.relation_families:
            for relation in family.relation_types:
                for edge_kind in edge_kinds:
                    for part in parts:
                        part_edges = getattr(getattr(relation, edge_kind), part)
                        total_edges += len(part_edges)

    adjacency = data.relation_families[0].relation_types[0].adjacency_matrix
    assert total_edges == torch.sum(adjacency._values()) * 2
-
-
def test_data_batcher_05():
    """Every batch is non-empty and no edge tensor exceeds the batch size.

    For each batch, the edge tensors of every relation type are gathered
    across all edge kinds and train/val/test parts. The test then checks two
    things:

    * no single tensor has more rows than the value passed to the batcher;
    * at least one tensor is non-empty.

    The batcher must also yield at least one batch, so the checks cannot
    pass vacuously.
    """
    # Passed to DataBatcher and used as the upper bound below, so the two
    # can never drift apart (presumably this argument is the batch size).
    batch_size = 512
    data = _some_data()
    prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
    batcher = DataBatcher(prep_d, batch_size)
    edge_types = ('edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg')
    part_types = ('train', 'val', 'test')
    n_batches = 0
    for batch_d in batcher:
        n_batches += 1
        edges_list = []
        for fam in batch_d.relation_families:
            for rel in fam.relation_types:
                for edge_type in edge_types:
                    edge_group = getattr(rel, edge_type)
                    for part_type in part_types:
                        edges_list.append(getattr(edge_group, part_type))
        assert all(len(edges) <= batch_size for edges in edges_list)
        assert any(len(edges) > 0 for edges in edges_list)
    # Guard against a vacuous pass when the batcher yields nothing.
    assert n_batches > 0
|