IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Use torch.nn.ModuleList in convlayer.

master
Stanislaw Adaszewski 3 years ago
parent
commit
55a5f3a2d2
2 changed files with 68 additions and 9 deletions
  1. +14
    -7
      src/icosagon/convlayer.py
  2. +54
    -2
      tests/icosagon/test_convlayer.py

+ 14
- 7
src/icosagon/convlayer.py View File

@@ -9,10 +9,16 @@ from collections import defaultdict
from dataclasses import dataclass
@dataclass
class Convolutions(object):
class Convolutions(torch.nn.Module):
node_type_column: int
convolutions: List[DropoutGraphConvActivation]
convolutions: torch.nn.ModuleList # [DropoutGraphConvActivation]
def __init__(self, node_type_column: int,
convolutions: torch.nn.ModuleList, **kwargs):
super().__init__(**kwargs)
self.node_type_column = node_type_column
self.convolutions = convolutions
class DecagonLayer(torch.nn.Module):
@@ -51,7 +57,7 @@ class DecagonLayer(torch.nn.Module):
self.build()
def build_fam_one_node_type(self, fam):
convolutions = []
convolutions = torch.nn.ModuleList()
for r in fam.relation_types:
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
@@ -63,8 +69,8 @@ class DecagonLayer(torch.nn.Module):
Convolutions(fam.node_type_column, convolutions))
def build_fam_two_node_types(self, fam) -> None:
convolutions_row = []
convolutions_column = []
convolutions_row = torch.nn.ModuleList()
convolutions_column = torch.nn.ModuleList()
for r in fam.relation_types:
if r.adjacency_matrix is not None:
@@ -92,7 +98,8 @@ class DecagonLayer(torch.nn.Module):
self.build_fam_two_node_types(fam)
def build(self):
self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
self.next_layer_repr = torch.nn.ModuleList([
torch.nn.ModuleList() for _ in range(len(self.data.node_types)) ])
for fam in self.data.relation_families:
self.build_family(fam)


+ 54
- 2
tests/icosagon/test_convlayer.py View File

@@ -77,10 +77,10 @@ def test_decagon_layer_03():
for i in range(2):
assert len(d_layer.next_layer_repr[i]) == 2
assert isinstance(d_layer.next_layer_repr[i], list)
assert isinstance(d_layer.next_layer_repr[i], torch.nn.ModuleList)
assert isinstance(d_layer.next_layer_repr[i][0], Convolutions)
assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int)
assert isinstance(d_layer.next_layer_repr[i][0].convolutions, list)
assert isinstance(d_layer.next_layer_repr[i][0].convolutions, torch.nn.ModuleList)
assert all([
isinstance(dgca, DropoutGraphConvActivation) \
for dgca in d_layer.next_layer_repr[i][0].convolutions
@@ -209,7 +209,28 @@ class Dummy4(torch.nn.Module):
self.dummy_1 = torch.nn.ModuleList([ Dummy1() ])
class Dummy5(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dummy_1 = [ torch.nn.ModuleList([ Dummy1() ]) ]
class Dummy6(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList([ Dummy1() ]) ])
class Dummy7(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList() ])
self.dummy_1[0].append(Dummy1())
def test_module_nesting_01():
if torch.cuda.device_count() == 0:
pytest.skip('No CUDA support on this host')
device = torch.device('cuda:0')
dummy_2 = Dummy2()
dummy_2 = dummy_2.to(device)
@@ -217,6 +238,8 @@ def test_module_nesting_01():
def test_module_nesting_02():
if torch.cuda.device_count() == 0:
pytest.skip('No CUDA support on this host')
device = torch.device('cuda:0')
dummy_3 = Dummy3()
dummy_3 = dummy_3.to(device)
@@ -224,7 +247,36 @@ def test_module_nesting_02():
def test_module_nesting_03():
if torch.cuda.device_count() == 0:
pytest.skip('No CUDA support on this host')
device = torch.device('cuda:0')
dummy_4 = Dummy4()
dummy_4 = dummy_4.to(device)
assert dummy_4.dummy_1[0].whatever.device == device
def test_module_nesting_04():
if torch.cuda.device_count() == 0:
pytest.skip('No CUDA support on this host')
device = torch.device('cuda:0')
dummy_5 = Dummy5()
dummy_5 = dummy_5.to(device)
assert dummy_5.dummy_1[0][0].whatever.device != device
def test_module_nesting_05():
if torch.cuda.device_count() == 0:
pytest.skip('No CUDA support on this host')
device = torch.device('cuda:0')
dummy_6 = Dummy6()
dummy_6 = dummy_6.to(device)
assert dummy_6.dummy_1[0][0].whatever.device == device
def test_module_nesting_06():
if torch.cuda.device_count() == 0:
pytest.skip('No CUDA support on this host')
device = torch.device('cuda:0')
dummy_7 = Dummy7()
dummy_7 = dummy_7.to(device)
assert dummy_7.dummy_1[0][0].whatever.device == device

Loading…
Cancel
Save