IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
소스 검색

Fix for non-leaf .grad warning.

master
Stanislaw Adaszewski 4 년 전
부모
커밋
cd4c34ea7d
2개의 변경된 파일9개의 추가작업 그리고 3개의 파일을 삭제
  1. +3
    -3
      src/icosagon/input.py
  2. +6
    -0
      tests/icosagon/test_trainloop.py

+ 3
- 3
src/icosagon/input.py 파일 보기

@@ -60,11 +60,11 @@ class OneHotInputLayer(torch.nn.Module):
self.build()
def build(self) -> None:
self.node_reps = []
self.node_reps = torch.nn.ParameterList()
for i, nt in enumerate(self.data.node_types):
reps = torch.eye(nt.count).to_sparse()
reps = torch.nn.Parameter(reps)
self.register_parameter('node_reps[%d]' % i, reps)
reps = torch.nn.Parameter(reps, requires_grad=False)
# self.register_parameter('node_reps[%d]' % i, reps)
self.node_reps.append(reps)
def forward(self, x) -> List[torch.nn.Parameter]:


+ 6
- 0
tests/icosagon/test_trainloop.py 파일 보기

@@ -37,10 +37,16 @@ def test_train_loop_02():
m = Model(prep_d)
for prm in m.parameters():
print(prm.shape, prm.is_leaf, prm.requires_grad)
loop = TrainLoop(m)
loop.run_epoch()
for prm in m.parameters():
print(prm.shape, prm.is_leaf, prm.requires_grad)
def test_train_loop_03():
# pdb.set_trace()


불러오는 중...
취소
저장