|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- from icosagon.data import Data
- from icosagon.trainprep import prepare_training, \
- TrainValTest
- from icosagon.model import Model
- from icosagon.trainloop import TrainLoop
- import torch
- import pytest
- import pdb
- import time
-
-
def test_train_loop_01():
    """Check the attributes of a TrainLoop constructed from only a model.

    Builds a one-node-type graph (10 nodes) with a single relation whose
    adjacency matrix is random 0/1 values, wraps it in a Model and a
    TrainLoop, and verifies the loop's model, learning rate, loss function
    and batch size.
    """
    data = Data()
    data.add_node_type('Dummy', 10)
    family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
    # torch.rand(...).round() yields a 10x10 matrix of 0.0 / 1.0 entries.
    family.add_relation_type('Dummy Rel', torch.rand(10, 10).round())

    model = Model(prepare_training(data, TrainValTest(.8, .1, .1)))
    train_loop = TrainLoop(model)

    assert train_loop.model == model
    assert train_loop.lr == 0.001
    assert train_loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
    assert train_loop.batch_size == 100
-
-
def test_train_loop_02():
    """Smoke test: one ``run_epoch`` call on a tiny random graph must not raise.

    No return value is checked; the test passes as long as building the
    model and running a single epoch completes without an exception.
    """
    data = Data()
    data.add_node_type('Dummy', 10)
    # Random 0/1 adjacency matrix for the single relation type.
    data.add_relation_family('Dummy-Dummy', 0, 0, False).add_relation_type(
        'Dummy Rel', torch.rand(10, 10).round())

    split = TrainValTest(.8, .1, .1)
    TrainLoop(Model(prepare_training(data, split))).run_epoch()
-
-
def test_train_loop_03():
    """Run one training epoch with both data and model placed on ``cuda:0``.

    Skipped when no CUDA device is available. Before training, checks that
    moving the model with ``.to(dev)`` left every parameter on that device.
    """
    if not torch.cuda.is_available():
        pytest.skip('CUDA required for this test')

    dev = torch.device('cuda:0')
    # Random 0/1 adjacency matrix, moved to the GPU before it enters Data.
    adj_mat = torch.rand(10, 10).round().to(dev)

    d = Data()
    d.add_node_type('Dummy', 10)
    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
    fam.add_relation_type('Dummy Rel', adj_mat)

    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))

    m = Model(prep_d).to(dev)

    for prm in m.parameters():
        assert prm.device == dev, f'parameter left on {prm.device}, expected {dev}'

    loop = TrainLoop(m)
    loop.run_epoch()
-
-
def test_timing_01(*, size=2000, density=.001, iterations=1300):
    """Benchmark repeated sparse-by-dense products with ``torch.sparse.mm``.

    This is a timing probe, not a correctness test. It builds a random
    ``size x size`` sparse 0/1 matrix with roughly ``density`` of its
    entries set, multiplies it ``iterations`` times by a dense identity
    matrix that requires grad, and prints the elapsed time.

    The parameters are keyword-only and defaulted, so pytest does not treat
    them as fixtures. The defaults reproduce the original benchmark; smaller
    values allow a quick run.

    :param size: side length of both square matrices.
    :param density: probability that an entry of the sparse matrix is 1.
    :param iterations: number of ``torch.sparse.mm`` calls to time.
    """
    adj_mat = (torch.rand(size, size) < density).to(torch.float32).to_sparse()
    rep = torch.eye(size).requires_grad_(True)
    # perf_counter is monotonic and high-resolution; time.time() is a wall
    # clock that can jump and is not meant for measuring intervals.
    start = time.perf_counter()
    for _ in range(iterations):
        torch.sparse.mm(adj_mat, rep)
    print('Elapsed:', time.perf_counter() - start)
|