"""Tests for ``icosagon.trainloop.TrainLoop``."""

import torch

from icosagon.data import Data
from icosagon.model import Model
from icosagon.trainloop import TrainLoop
from icosagon.trainprep import TrainValTest, prepare_training


def _make_model() -> Model:
    """Build a small ``Model`` over a single random relation, for the tests below.

    The data has one node type ``'Dummy'`` with 10 nodes and one relation
    family ``'Dummy-Dummy'`` (arguments ``0, 0, False`` passed through as-is;
    presumably the 0s are node-type indices — confirm against ``Data``)
    containing one relation type ``'Dummy Rel'`` with a 10x10 matrix.
    The prepared data uses ``TrainValTest(.8, .1, .1)``, presumably
    train/validation/test fractions.
    """
    d = Data()
    d.add_node_type('Dummy', 10)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    # Rounding uniform [0, 1) samples yields a random 0/1 matrix.
    # NOTE(review): torch's RNG is not seeded here, so the matrix differs
    # between runs; seed it if these tests ever become flaky.
    fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
    return Model(prep_d)


def test_train_loop_01():
    """A ``TrainLoop`` built with only a model exposes the expected defaults."""
    m = _make_model()
    loop = TrainLoop(m)
    assert loop.model == m
    assert loop.lr == 0.001
    # Identity, not equality: the default loss should be this exact function.
    assert loop.loss is torch.nn.functional.binary_cross_entropy_with_logits
    assert loop.batch_size == 100


def test_train_loop_02():
    """Smoke test: ``run_epoch`` completes on a tiny model without raising."""
    m = _make_model()
    loop = TrainLoop(m)
    # NOTE(review): nothing about run_epoch's result or effect on the model
    # is asserted; add checks once its return value/contract is pinned down.
    loop.run_epoch()