IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Kaynağa Gözat

Fix for non-leaf .grad warning.

master
Stanislaw Adaszewski 4 yıl önce
ebeveyn
işleme
cd4c34ea7d
2 değiştirilmiş dosya ile 9 ekleme ve 3 silme
  1. +3
    -3
      src/icosagon/input.py
  2. +6
    -0
      tests/icosagon/test_trainloop.py

+ 3
- 3
src/icosagon/input.py Dosyayı Görüntüle

@@ -60,11 +60,11 @@ class OneHotInputLayer(torch.nn.Module):
self.build()
def build(self) -> None:
self.node_reps = []
self.node_reps = torch.nn.ParameterList()
for i, nt in enumerate(self.data.node_types):
reps = torch.eye(nt.count).to_sparse()
reps = torch.nn.Parameter(reps)
self.register_parameter('node_reps[%d]' % i, reps)
reps = torch.nn.Parameter(reps, requires_grad=False)
# self.register_parameter('node_reps[%d]' % i, reps)
self.node_reps.append(reps)
def forward(self, x) -> List[torch.nn.Parameter]:


+ 6
- 0
tests/icosagon/test_trainloop.py Dosyayı Görüntüle

@@ -37,10 +37,16 @@ def test_train_loop_02():
m = Model(prep_d)
for prm in m.parameters():
print(prm.shape, prm.is_leaf, prm.requires_grad)
loop = TrainLoop(m)
loop.run_epoch()
for prm in m.parameters():
print(prm.shape, prm.is_leaf, prm.requires_grad)
def test_train_loop_03():
# pdb.set_trace()


Yükleniyor…
İptal
Kaydet