| @@ -183,3 +183,48 @@ def test_decagon_layer_05(): | |||||
| assert len(out_d_layer) == 1 | assert len(out_d_layer) == 1 | ||||
| assert torch.all(out_d_layer[0] == out_multi_dgca) | assert torch.all(out_d_layer[0] == out_multi_dgca) | ||||
| class Dummy1(torch.nn.Module): | |||||
| def __init__(self, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.whatever = torch.nn.Parameter(torch.rand((10, 10))) | |||||
| class Dummy2(torch.nn.Module): | |||||
| def __init__(self, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.dummy_1 = Dummy1() | |||||
| class Dummy3(torch.nn.Module): | |||||
| def __init__(self, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.dummy_1 = [ Dummy1() ] | |||||
| class Dummy4(torch.nn.Module): | |||||
| def __init__(self, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.dummy_1 = torch.nn.ModuleList([ Dummy1() ]) | |||||
| def test_module_nesting_01(): | |||||
| device = torch.device('cuda:0') | |||||
| dummy_2 = Dummy2() | |||||
| dummy_2 = dummy_2.to(device) | |||||
| assert dummy_2.dummy_1.whatever.device == device | |||||
| def test_module_nesting_02(): | |||||
| device = torch.device('cuda:0') | |||||
| dummy_3 = Dummy3() | |||||
| dummy_3 = dummy_3.to(device) | |||||
| assert dummy_3.dummy_1[0].whatever.device != device | |||||
| def test_module_nesting_03(): | |||||
| device = torch.device('cuda:0') | |||||
| dummy_4 = Dummy4() | |||||
| dummy_4 = dummy_4.to(device) | |||||
| assert dummy_4.dummy_1[0].whatever.device == device | |||||