IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add module nesting tests.

master
Stanislaw Adaszewski 3 years ago
parent
commit
fe6c6598f8
1 changed files with 45 additions and 0 deletions
  1. +45
    -0
      tests/icosagon/test_convlayer.py

+ 45
- 0
tests/icosagon/test_convlayer.py View File

@@ -183,3 +183,48 @@ def test_decagon_layer_05():
assert len(out_d_layer) == 1
assert torch.all(out_d_layer[0] == out_multi_dgca)
class Dummy1(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.whatever = torch.nn.Parameter(torch.rand((10, 10)))
class Dummy2(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dummy_1 = Dummy1()
class Dummy3(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dummy_1 = [ Dummy1() ]
class Dummy4(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dummy_1 = torch.nn.ModuleList([ Dummy1() ])
def test_module_nesting_01():
device = torch.device('cuda:0')
dummy_2 = Dummy2()
dummy_2 = dummy_2.to(device)
assert dummy_2.dummy_1.whatever.device == device
def test_module_nesting_02():
device = torch.device('cuda:0')
dummy_3 = Dummy3()
dummy_3 = dummy_3.to(device)
assert dummy_3.dummy_1[0].whatever.device != device
def test_module_nesting_03():
device = torch.device('cuda:0')
dummy_4 = Dummy4()
dummy_4 = dummy_4.to(device)
assert dummy_4.dummy_1[0].whatever.device == device

Loading…
Cancel
Save