IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Преглед на файлове

Remove unnecessary .detach() call in dropout_dense().

master
Stanislaw Adaszewski преди 4 години
родител
ревизия
fc8f9726af
променени са 2 файла, в които са добавени 54 реда и са изтрити 2 реда
  1. +2
    -1
      src/icosagon/dropout.py
  2. +52
    -1
      tests/icosagon/test_trainloop.py

+ 2
- 1
src/icosagon/dropout.py Целия файл

@@ -24,7 +24,8 @@ def dropout_sparse(x, keep_prob):
def dropout_dense(x, keep_prob):
x = x.clone().detach()
# print('dropout_dense()')
x = x.clone()
i = torch.nonzero(x)
n = keep_prob + torch.rand(len(i))


+ 52
- 1
tests/icosagon/test_trainloop.py Целия файл

@@ -1,4 +1,5 @@
from icosagon.data import Data
from icosagon.data import Data, \
_equal
from icosagon.trainprep import prepare_training, \
TrainValTest
from icosagon.model import Model
@@ -78,6 +79,56 @@ def test_train_loop_03():
loop.run_epoch()
def test_train_loop_04():
adj_mat = torch.rand(10, 10).round()
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Rel', adj_mat)
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
m = Model(prep_d)
old_values = []
for prm in m.parameters():
old_values.append(prm.clone().detach())
loop = TrainLoop(m)
loop.run_epoch()
for i, prm in enumerate(m.parameters()):
assert not prm.requires_grad or \
not torch.all(_equal(prm, old_values[i]))
def test_train_loop_05():
adj_mat = torch.rand(10, 10).round().to_sparse()
d = Data()
d.add_node_type('Dummy', 10)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Rel', adj_mat)
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
m = Model(prep_d)
old_values = []
for prm in m.parameters():
old_values.append(prm.clone().detach())
loop = TrainLoop(m)
loop.run_epoch()
for i, prm in enumerate(m.parameters()):
assert not prm.requires_grad or \
not torch.all(_equal(prm, old_values[i]))
def test_timing_01():
adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
rep = torch.eye(2000).requires_grad_(True)


Loading…
Отказ
Запис