|
|
@@ -1,4 +1,5 @@ |
|
|
|
from icosagon.data import Data
|
|
|
|
from icosagon.data import Data, \
|
|
|
|
_equal
|
|
|
|
from icosagon.trainprep import prepare_training, \
|
|
|
|
TrainValTest
|
|
|
|
from icosagon.model import Model
|
|
|
@@ -78,6 +79,56 @@ def test_train_loop_03(): |
|
|
|
loop.run_epoch()
|
|
|
|
|
|
|
|
|
|
|
|
def test_train_loop_04():
|
|
|
|
adj_mat = torch.rand(10, 10).round()
|
|
|
|
|
|
|
|
d = Data()
|
|
|
|
d.add_node_type('Dummy', 10)
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
|
|
|
|
fam.add_relation_type('Dummy Rel', adj_mat)
|
|
|
|
|
|
|
|
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
|
|
|
|
|
|
|
|
m = Model(prep_d)
|
|
|
|
|
|
|
|
old_values = []
|
|
|
|
for prm in m.parameters():
|
|
|
|
old_values.append(prm.clone().detach())
|
|
|
|
|
|
|
|
loop = TrainLoop(m)
|
|
|
|
|
|
|
|
loop.run_epoch()
|
|
|
|
|
|
|
|
for i, prm in enumerate(m.parameters()):
|
|
|
|
assert not prm.requires_grad or \
|
|
|
|
not torch.all(_equal(prm, old_values[i]))
|
|
|
|
|
|
|
|
|
|
|
|
def test_train_loop_05():
|
|
|
|
adj_mat = torch.rand(10, 10).round().to_sparse()
|
|
|
|
|
|
|
|
d = Data()
|
|
|
|
d.add_node_type('Dummy', 10)
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
|
|
|
|
fam.add_relation_type('Dummy Rel', adj_mat)
|
|
|
|
|
|
|
|
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
|
|
|
|
|
|
|
|
m = Model(prep_d)
|
|
|
|
|
|
|
|
old_values = []
|
|
|
|
for prm in m.parameters():
|
|
|
|
old_values.append(prm.clone().detach())
|
|
|
|
|
|
|
|
loop = TrainLoop(m)
|
|
|
|
|
|
|
|
loop.run_epoch()
|
|
|
|
|
|
|
|
for i, prm in enumerate(m.parameters()):
|
|
|
|
assert not prm.requires_grad or \
|
|
|
|
not torch.all(_equal(prm, old_values[i]))
|
|
|
|
|
|
|
|
|
|
|
|
def test_timing_01():
|
|
|
|
adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
|
|
|
|
rep = torch.eye(2000).requires_grad_(True)
|
|
|
|