From cfb3d308a0346c3cb75f8d893a877bb17d3c8a87 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Wed, 17 Jun 2020 12:28:46 +0200 Subject: [PATCH] Add test_model_03(). --- src/icosagon/convolve.py | 2 +- src/icosagon/declayer.py | 2 +- src/icosagon/trainprep.py | 2 +- tests/icosagon/test_model.py | 6 +++++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/icosagon/convolve.py b/src/icosagon/convolve.py index 118a58e..364f61a 100644 --- a/src/icosagon/convolve.py +++ b/src/icosagon/convolve.py @@ -16,7 +16,7 @@ class GraphConv(torch.nn.Module): super().__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels - self.weight = init_glorot(in_channels, out_channels) + self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels)) self.adjacency_matrix = adjacency_matrix diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index db78f11..8f6bff4 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -68,7 +68,7 @@ class DecodeLayer(torch.nn.Module): self.build() def build(self) -> None: - self.decoders = [] + self.decoders = torch.nn.ModuleList() for fam in self.data.relation_families: dec = fam.decoder_class(self.input_dim, len(fam.relation_types), self.keep_prob, self.activation) diff --git a/src/icosagon/trainprep.py b/src/icosagon/trainprep.py index 6e8211b..b7074e1 100644 --- a/src/icosagon/trainprep.py +++ b/src/icosagon/trainprep.py @@ -111,7 +111,7 @@ def prepare_adj_mat(adj_mat: torch.Tensor, edges_neg = train_val_test_split_edges(edges_neg, ratios) adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1), - values=torch.ones(len(edges_pos.train), shape=adj_mat.shape, dtype=adj_mat.dtype)) + values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype) return adj_mat_train, edges_pos, edges_neg diff --git a/tests/icosagon/test_model.py b/tests/icosagon/test_model.py index c645d32..5ee16ef 100644 --- a/tests/icosagon/test_model.py +++ b/tests/icosagon/test_model.py @@ -81,4 +81,8 @@ def test_model_03(): assert isinstance(state_dict, dict) # print(state_dict['param_groups']) # print(list(m.seq.parameters())) - print(list(m.seq[1].parameters())) + assert len(list(m.seq[0].parameters())) == 1 + assert len(list(m.seq[1].parameters())) == 1 + assert len(list(m.seq[2].parameters())) == 1 + assert len(list(m.seq[3].parameters())) == 2 + # print(list(m.seq[1].parameters()))