diff --git a/src/icosagon/dropout.py b/src/icosagon/dropout.py index 74bdd57..63cfb58 100644 --- a/src/icosagon/dropout.py +++ b/src/icosagon/dropout.py @@ -24,7 +24,8 @@ def dropout_sparse(x, keep_prob): def dropout_dense(x, keep_prob): - x = x.clone().detach() + # print('dropout_dense()') + x = x.clone() i = torch.nonzero(x) n = keep_prob + torch.rand(len(i)) diff --git a/tests/icosagon/test_trainloop.py b/tests/icosagon/test_trainloop.py index ca77edd..192cdf9 100644 --- a/tests/icosagon/test_trainloop.py +++ b/tests/icosagon/test_trainloop.py @@ -1,4 +1,5 @@ -from icosagon.data import Data +from icosagon.data import Data, \ + _equal from icosagon.trainprep import prepare_training, \ TrainValTest from icosagon.model import Model @@ -78,6 +79,56 @@ def test_train_loop_03(): loop.run_epoch() +def test_train_loop_04(): + adj_mat = torch.rand(10, 10).round() + + d = Data() + d.add_node_type('Dummy', 10) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Rel', adj_mat) + + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) + + m = Model(prep_d) + + old_values = [] + for prm in m.parameters(): + old_values.append(prm.clone().detach()) + + loop = TrainLoop(m) + + loop.run_epoch() + + for i, prm in enumerate(m.parameters()): + assert not prm.requires_grad or \ + not torch.all(_equal(prm, old_values[i])) + + +def test_train_loop_05(): + adj_mat = torch.rand(10, 10).round().to_sparse() + + d = Data() + d.add_node_type('Dummy', 10) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Rel', adj_mat) + + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) + + m = Model(prep_d) + + old_values = [] + for prm in m.parameters(): + old_values.append(prm.clone().detach()) + + loop = TrainLoop(m) + + loop.run_epoch() + + for i, prm in enumerate(m.parameters()): + assert not prm.requires_grad or \ + not torch.all(_equal(prm, old_values[i])) + + def test_timing_01(): adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse() rep = torch.eye(2000).requires_grad_(True)