From fe6c6598f8466e6074d3ccc0bc218c04b32a3a69 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 12 Jun 2020 21:09:04 +0200 Subject: [PATCH] Add module nesting tests. --- tests/icosagon/test_convlayer.py | 45 ++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/icosagon/test_convlayer.py b/tests/icosagon/test_convlayer.py index 145ad95..1b6fdba 100644 --- a/tests/icosagon/test_convlayer.py +++ b/tests/icosagon/test_convlayer.py @@ -183,3 +183,48 @@ def test_decagon_layer_05(): assert len(out_d_layer) == 1 assert torch.all(out_d_layer[0] == out_multi_dgca) + + +class Dummy1(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.whatever = torch.nn.Parameter(torch.rand((10, 10))) + + +class Dummy2(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dummy_1 = Dummy1() + + +class Dummy3(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dummy_1 = [ Dummy1() ] + + +class Dummy4(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dummy_1 = torch.nn.ModuleList([ Dummy1() ]) + + +def test_module_nesting_01(): + device = torch.device('cuda:0') + dummy_2 = Dummy2() + dummy_2 = dummy_2.to(device) + assert dummy_2.dummy_1.whatever.device == device + + +def test_module_nesting_02(): + device = torch.device('cuda:0') + dummy_3 = Dummy3() + dummy_3 = dummy_3.to(device) + assert dummy_3.dummy_1[0].whatever.device != device + + +def test_module_nesting_03(): + device = torch.device('cuda:0') + dummy_4 = Dummy4() + dummy_4 = dummy_4.to(device) + assert dummy_4.dummy_1[0].whatever.device == device