diff --git a/src/icosagon/decode.py b/src/icosagon/decode.py index b69ec97..00df8b2 100644 --- a/src/icosagon/decode.py +++ b/src/icosagon/decode.py @@ -20,11 +20,11 @@ class DEDICOMDecoder(torch.nn.Module): self.keep_prob = keep_prob self.activation = activation - self.global_interaction = init_glorot(input_dim, input_dim) - self.local_variation = [ - torch.flatten(init_glorot(input_dim, 1)) \ + self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim)) + self.local_variation = torch.nn.ParameterList([ + torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ for _ in range(num_relation_types) - ] + ]) def forward(self, inputs_row, inputs_col, relation_index): inputs_row = dropout(inputs_row, self.keep_prob) @@ -53,10 +53,10 @@ class DistMultDecoder(torch.nn.Module): self.keep_prob = keep_prob self.activation = activation - self.relation = [ - torch.flatten(init_glorot(input_dim, 1)) \ + self.relation = torch.nn.ParameterList([ + torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ for _ in range(num_relation_types) - ] + ]) def forward(self, inputs_row, inputs_col, relation_index): inputs_row = dropout(inputs_row, self.keep_prob) @@ -83,10 +83,10 @@ class BilinearDecoder(torch.nn.Module): self.keep_prob = keep_prob self.activation = activation - self.relation = [ - init_glorot(input_dim, input_dim) \ + self.relation = torch.nn.ParameterList([ + torch.nn.Parameter(init_glorot(input_dim, input_dim)) \ for _ in range(num_relation_types) - ] + ]) def forward(self, inputs_row, inputs_col, relation_index): inputs_row = dropout(inputs_row, self.keep_prob)