From 25e05cf1c2853672ddfc8eaf2a1a4c3123d88283 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 12 Jun 2020 21:49:44 +0200 Subject: [PATCH] Use torch.nn.Parameter(List) in decode. --- src/icosagon/decode.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/icosagon/decode.py b/src/icosagon/decode.py index b69ec97..00df8b2 100644 --- a/src/icosagon/decode.py +++ b/src/icosagon/decode.py @@ -20,11 +20,11 @@ class DEDICOMDecoder(torch.nn.Module): self.keep_prob = keep_prob self.activation = activation - self.global_interaction = init_glorot(input_dim, input_dim) - self.local_variation = [ - torch.flatten(init_glorot(input_dim, 1)) \ + self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim)) + self.local_variation = torch.nn.ParameterList([ + torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ for _ in range(num_relation_types) - ] + ]) def forward(self, inputs_row, inputs_col, relation_index): inputs_row = dropout(inputs_row, self.keep_prob) @@ -53,10 +53,10 @@ class DistMultDecoder(torch.nn.Module): self.keep_prob = keep_prob self.activation = activation - self.relation = [ - torch.flatten(init_glorot(input_dim, 1)) \ + self.relation = torch.nn.ParameterList([ + torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ for _ in range(num_relation_types) - ] + ]) def forward(self, inputs_row, inputs_col, relation_index): inputs_row = dropout(inputs_row, self.keep_prob) @@ -83,10 +83,10 @@ class BilinearDecoder(torch.nn.Module): self.keep_prob = keep_prob self.activation = activation - self.relation = [ - init_glorot(input_dim, input_dim) \ + self.relation = torch.nn.ParameterList([ + torch.nn.Parameter(init_glorot(input_dim, input_dim)) \ for _ in range(num_relation_types) - ] + ]) def forward(self, inputs_row, inputs_col, relation_index): inputs_row = dropout(inputs_row, self.keep_prob)