| 
                        123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 | 
                        - from icosagon.databatch import DataBatcher, \
 -     BatchedData, \
 -     BatchedDataPointer, \
 -     batched_data_skeleton
 - from icosagon.data import Data
 - from icosagon.trainprep import prepare_training, \
 -     TrainValTest
 - from icosagon.declayer import DecodeLayer
 - from icosagon.input import OneHotInputLayer
 - import torch
 - import time
 - 
 - 
 def _some_data():
     """Build a small ``Data`` fixture with two node types and one relation.

     Registers node types 'Foo' and 'Bar' (passing 100 and 500, presumably
     the node counts -- they match the adjacency shape used below), one
     'Foo-Bar' relation family (arguments 0 and 1 are presumably the indices
     of the two node types), and a single 'Foo-Bar' relation type whose
     adjacency matrix is a random 0/1 matrix converted to sparse form.

     The matrix is drawn from the global torch RNG, so the fixture differs
     from run to run unless the caller seeds torch.
     """
     n_foo, n_bar = 100, 500
     fixture = Data()
     fixture.add_node_type('Foo', n_foo)
     fixture.add_node_type('Bar', n_bar)
     # NOTE(review): the trailing True is presumably an "is symmetric" flag --
     # confirm against Data.add_relation_family.
     family = fixture.add_relation_family('Foo-Bar', 0, 1, True)
     # Uniform samples in [0, 1) rounded to 0.0 / 1.0.
     dense = torch.rand(n_foo, n_bar).round()
     family.add_relation_type('Foo-Bar', dense.to_sparse())
     return fixture
 - 
 - 
 def _some_data_big():
     """Build a larger ``Data`` fixture with two node types and one relation.

     Same layout as the small fixture in this file, but 'Foo' is registered
     with 2000 and 'Bar' with 2100 (presumably node counts -- they match the
     adjacency shape used below). The single 'Foo-Bar' relation type gets a
     random 0/1 adjacency matrix converted to sparse form.

     The matrix is drawn from the global torch RNG, so the fixture differs
     from run to run unless the caller seeds torch.
     """
     n_foo, n_bar = 2000, 2100
     fixture = Data()
     fixture.add_node_type('Foo', n_foo)
     fixture.add_node_type('Bar', n_bar)
     # NOTE(review): the trailing True is presumably an "is symmetric" flag --
     # confirm against Data.add_relation_family.
     family = fixture.add_relation_family('Foo-Bar', 0, 1, True)
     # Uniform samples in [0, 1) rounded to 0.0 / 1.0.
     dense = torch.rand(n_foo, n_bar).round()
     family.add_relation_type('Foo-Bar', dense.to_sparse())
     return fixture
 - 
 - 
 def test_data_batcher_01():
     """Smoke test: a ``DataBatcher`` can be constructed from prepared data.

     There are no assertions; the test passes as long as construction does
     not raise.
     """
     prepared = prepare_training(_some_data(), TrainValTest(.8, .1, .1))
     # 512 is presumably the batch size (see the <= 512 check in
     # test_data_batcher_05).
     _ = DataBatcher(prepared, 512)
 - 
 - 
 def test_data_batcher_02():
     """Smoke test: a ``DataBatcher`` can be iterated to exhaustion.

     There are no assertions; the test passes as long as iterating over all
     batches does not raise.
     """
     prepared = prepare_training(_some_data(), TrainValTest(.8, .1, .1))
     for _ in DataBatcher(prepared, 512):
         pass
 - 
 - 
 def test_data_batcher_03():
     """Check that every batch carries exactly one non-empty edge set.

     For each batch, every combination of relation family, relation type,
     edge kind (pos / neg / back_pos / back_neg) and partition
     (train / val / test) is inspected; exactly one of them must have a
     non-zero length.
     """
     edge_kinds = ('edges_pos', 'edges_neg',
                   'edges_back_pos', 'edges_back_neg')
     partitions = ('train', 'val', 'test')
     prepared = prepare_training(_some_data(), TrainValTest(.8, .1, .1))
     for batch in DataBatcher(prepared, 512):
         non_empty = 0
         for family in batch.relation_families:
             for rel_type in family.relation_types:
                 for kind in edge_kinds:
                     for part in partitions:
                         split = getattr(rel_type, kind)
                         # len() rather than truthiness: the edge containers
                         # are presumably tensors.
                         if len(getattr(split, part)) > 0:
                             non_empty += 1
         assert non_empty == 1
 - 
 - 
 def test_data_batcher_04():
     """Check that the batches together cover the expected number of edges.

     Sums the lengths of every edge set (all families, relation types, edge
     kinds and partitions) over all batches and compares the total with
     twice the sum of the stored values of the original adjacency matrix.
     """
     edge_kinds = ('edges_pos', 'edges_neg',
                   'edges_back_pos', 'edges_back_neg')
     partitions = ('train', 'val', 'test')
     data = _some_data()
     prepared = prepare_training(data, TrainValTest(.8, .1, .1))
     total_edges = 0
     for batch in DataBatcher(prepared, 512):
         for family in batch.relation_families:
             for rel_type in family.relation_types:
                 for kind in edge_kinds:
                     for part in partitions:
                         split = getattr(rel_type, kind)
                         total_edges += len(getattr(split, part))
     adjacency = data.relation_families[0].relation_types[0].adjacency_matrix
     # NOTE(review): the factor 2 is not explained by anything visible here
     # (possibly one negative per positive edge) -- confirm.
     expected = torch.sum(adjacency._values()) * 2
     assert total_edges == expected
 - 
 - 
 def test_data_batcher_05():
     """Check per-batch size limits and print each batch's edge count.

     For each batch, no single edge set (family / relation type / edge kind /
     partition) may hold more than 512 entries, and at least one edge set
     must be non-empty. The total number of edges in the batch is printed.
     """
     edge_kinds = ('edges_pos', 'edges_neg',
                   'edges_back_pos', 'edges_back_neg')
     partitions = ('train', 'val', 'test')
     prepared = prepare_training(_some_data(), TrainValTest(.8, .1, .1))
     for batch in DataBatcher(prepared, 512):
         sizes = []
         for family in batch.relation_families:
             for rel_type in family.relation_types:
                 for kind in edge_kinds:
                     for part in partitions:
                         split = getattr(rel_type, kind)
                         sizes.append(len(getattr(split, part)))
         # 512 is the value passed to DataBatcher above.
         assert all(size <= 512 for size in sizes)
         # A batch must not be completely empty.
         assert any(size != 0 for size in sizes)
         print(sum(sizes))
 - 
 - 
 def test_batch_decode_01():
     """Run ``DecodeLayer`` on every batch of the small fixture and time it.

     Smoke test with no assertions: it passes as long as decoding each batch
     does not raise. The elapsed time for the whole loop is printed.
     """
     data = _some_data()
     prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
     batcher = DataBatcher(prep_d, 512)
     ptr = BatchedDataPointer(batched_data_skeleton(prep_d))
     # One random 32-column representation per node type; the row counts
     # match the 100 / 500 passed to add_node_type in _some_data().
     in_repr = [torch.rand(100, 32), torch.rand(500, 32)]
     dec_layer = DecodeLayer([32, 32], prep_d, batched_data_pointer=ptr)
     # perf_counter() is monotonic, so the measured interval cannot be
     # distorted by system clock adjustments the way time.time() can.
     start = time.perf_counter()
     for batched_data in batcher:
         # Presumably the layer picks up the current batch through the
         # pointer it was constructed with -- confirm in DecodeLayer.
         ptr.batched_data = batched_data
         _ = dec_layer(in_repr)
     print('Elapsed:', time.perf_counter() - start)
 - 
 - 
 def test_batch_decode_02():
     """Run ``DecodeLayer`` on every batch of the big fixture and time it.

     Smoke test with no assertions: it passes as long as decoding each batch
     does not raise. The elapsed time for the whole loop is printed.
     """
     data = _some_data_big()
     prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
     batcher = DataBatcher(prep_d, 512)
     ptr = BatchedDataPointer(batched_data_skeleton(prep_d))
     # One random 32-column representation per node type; the row counts
     # match the 2000 / 2100 passed to add_node_type in _some_data_big().
     in_repr = [torch.rand(2000, 32), torch.rand(2100, 32)]
     dec_layer = DecodeLayer([32, 32], prep_d, batched_data_pointer=ptr)
     # perf_counter() is monotonic, so the measured interval cannot be
     # distorted by system clock adjustments the way time.time() can.
     start = time.perf_counter()
     for batched_data in batcher:
         # Presumably the layer picks up the current batch through the
         # pointer it was constructed with -- confirm in DecodeLayer.
         ptr.batched_data = batched_data
         _ = dec_layer(in_repr)
     print('Elapsed:', time.perf_counter() - start)
 
 
  |