From cd4c34ea7d39c0f3969c7894545f73cc9d178961 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 17 Jul 2020 13:20:06 +0200 Subject: [PATCH] Fix for non-leaf .grad warning. --- src/icosagon/input.py | 6 +++--- tests/icosagon/test_trainloop.py | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/icosagon/input.py b/src/icosagon/input.py index 8c686b1..3bf5824 100644 --- a/src/icosagon/input.py +++ b/src/icosagon/input.py @@ -60,11 +60,11 @@ class OneHotInputLayer(torch.nn.Module): self.build() def build(self) -> None: - self.node_reps = [] + self.node_reps = torch.nn.ParameterList() for i, nt in enumerate(self.data.node_types): reps = torch.eye(nt.count).to_sparse() - reps = torch.nn.Parameter(reps) - self.register_parameter('node_reps[%d]' % i, reps) + reps = torch.nn.Parameter(reps, requires_grad=False) + # self.register_parameter('node_reps[%d]' % i, reps) self.node_reps.append(reps) def forward(self, x) -> List[torch.nn.Parameter]: diff --git a/tests/icosagon/test_trainloop.py b/tests/icosagon/test_trainloop.py index c052dba..ca77edd 100644 --- a/tests/icosagon/test_trainloop.py +++ b/tests/icosagon/test_trainloop.py @@ -37,10 +37,16 @@ def test_train_loop_02(): m = Model(prep_d) + for prm in m.parameters(): + print(prm.shape, prm.is_leaf, prm.requires_grad) + loop = TrainLoop(m) loop.run_epoch() + for prm in m.parameters(): + print(prm.shape, prm.is_leaf, prm.requires_grad) + def test_train_loop_03(): # pdb.set_trace()