diff --git a/src/icosagon/convlayer.py b/src/icosagon/convlayer.py
index cdaced6..e98b55e 100644
--- a/src/icosagon/convlayer.py
+++ b/src/icosagon/convlayer.py
@@ -9,10 +9,16 @@
 from collections import defaultdict
 from dataclasses import dataclass
 
-@dataclass
-class Convolutions(object):
+class Convolutions(torch.nn.Module):
     node_type_column: int
-    convolutions: List[DropoutGraphConvActivation]
+    convolutions: torch.nn.ModuleList # [DropoutGraphConvActivation]
+
+    def __init__(self, node_type_column: int,
+        convolutions: torch.nn.ModuleList, **kwargs):
+
+        super().__init__(**kwargs)
+        self.node_type_column = node_type_column
+        self.convolutions = convolutions
 
 
 class DecagonLayer(torch.nn.Module):
@@ -51,7 +57,7 @@ class DecagonLayer(torch.nn.Module):
         self.build()
 
     def build_fam_one_node_type(self, fam):
-        convolutions = []
+        convolutions = torch.nn.ModuleList()
 
         for r in fam.relation_types:
             conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
@@ -63,8 +69,8 @@
             Convolutions(fam.node_type_column, convolutions))
 
     def build_fam_two_node_types(self, fam) -> None:
-        convolutions_row = []
-        convolutions_column = []
+        convolutions_row = torch.nn.ModuleList()
+        convolutions_column = torch.nn.ModuleList()
 
         for r in fam.relation_types:
             if r.adjacency_matrix is not None:
@@ -92,7 +98,8 @@
             self.build_fam_two_node_types(fam)
 
     def build(self):
-        self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
+        self.next_layer_repr = torch.nn.ModuleList([
+            torch.nn.ModuleList() for _ in range(len(self.data.node_types)) ])
         for fam in self.data.relation_families:
             self.build_family(fam)
 
diff --git a/tests/icosagon/test_convlayer.py b/tests/icosagon/test_convlayer.py
index 1b6fdba..96fb225 100644
--- a/tests/icosagon/test_convlayer.py
+++ b/tests/icosagon/test_convlayer.py
@@ -77,10 +77,10 @@ def test_decagon_layer_03():
     for i in range(2):
         assert len(d_layer.next_layer_repr[i]) == 2
-        assert isinstance(d_layer.next_layer_repr[i], list)
+        assert isinstance(d_layer.next_layer_repr[i], torch.nn.ModuleList)
         assert isinstance(d_layer.next_layer_repr[i][0], Convolutions)
         assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int)
-        assert isinstance(d_layer.next_layer_repr[i][0].convolutions, list)
+        assert isinstance(d_layer.next_layer_repr[i][0].convolutions, torch.nn.ModuleList)
         assert all([
             isinstance(dgca, DropoutGraphConvActivation) \
                 for dgca in d_layer.next_layer_repr[i][0].convolutions
         ])
@@ -209,7 +209,28 @@ class Dummy4(torch.nn.Module):
         self.dummy_1 = torch.nn.ModuleList([ Dummy1() ])
 
 
+class Dummy5(torch.nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.dummy_1 = [ torch.nn.ModuleList([ Dummy1() ]) ]
+
+
+class Dummy6(torch.nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList([ Dummy1() ]) ])
+
+
+class Dummy7(torch.nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList() ])
+        self.dummy_1[0].append(Dummy1())
+
+
 def test_module_nesting_01():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('No CUDA support on this host')
     device = torch.device('cuda:0')
     dummy_2 = Dummy2()
     dummy_2 = dummy_2.to(device)
@@ -217,6 +238,8 @@
 
 
 def test_module_nesting_02():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('No CUDA support on this host')
     device = torch.device('cuda:0')
     dummy_3 = Dummy3()
     dummy_3 = dummy_3.to(device)
@@ -224,7 +247,36 @@
 
 
 def test_module_nesting_03():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('No CUDA support on this host')
     device = torch.device('cuda:0')
     dummy_4 = Dummy4()
     dummy_4 = dummy_4.to(device)
     assert dummy_4.dummy_1[0].whatever.device == device
+
+
+def test_module_nesting_04():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('No CUDA support on this host')
+    device = torch.device('cuda:0')
+    dummy_5 = Dummy5()
+    dummy_5 = dummy_5.to(device)
+    assert dummy_5.dummy_1[0][0].whatever.device != device
+
+
+def test_module_nesting_05():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('No CUDA support on this host')
+    device = torch.device('cuda:0')
+    dummy_6 = Dummy6()
+    dummy_6 = dummy_6.to(device)
+    assert dummy_6.dummy_1[0][0].whatever.device == device
+
+
+def test_module_nesting_06():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('No CUDA support on this host')
+    device = torch.device('cuda:0')
+    dummy_7 = Dummy7()
+    dummy_7 = dummy_7.to(device)
+    assert dummy_7.dummy_1[0][0].whatever.device == device