| @@ -20,11 +20,11 @@ class DEDICOMDecoder(torch.nn.Module): | |||||
| self.keep_prob = keep_prob | self.keep_prob = keep_prob | ||||
| self.activation = activation | self.activation = activation | ||||
| self.global_interaction = init_glorot(input_dim, input_dim) | |||||
| self.local_variation = [ | |||||
| torch.flatten(init_glorot(input_dim, 1)) \ | |||||
| self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim)) | |||||
| self.local_variation = torch.nn.ParameterList([ | |||||
| torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ | |||||
| for _ in range(num_relation_types) | for _ in range(num_relation_types) | ||||
| ] | |||||
| ]) | |||||
| def forward(self, inputs_row, inputs_col, relation_index): | def forward(self, inputs_row, inputs_col, relation_index): | ||||
| inputs_row = dropout(inputs_row, self.keep_prob) | inputs_row = dropout(inputs_row, self.keep_prob) | ||||
| @@ -53,10 +53,10 @@ class DistMultDecoder(torch.nn.Module): | |||||
| self.keep_prob = keep_prob | self.keep_prob = keep_prob | ||||
| self.activation = activation | self.activation = activation | ||||
| self.relation = [ | |||||
| torch.flatten(init_glorot(input_dim, 1)) \ | |||||
| self.relation = torch.nn.ParameterList([ | |||||
| torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ | |||||
| for _ in range(num_relation_types) | for _ in range(num_relation_types) | ||||
| ] | |||||
| ]) | |||||
| def forward(self, inputs_row, inputs_col, relation_index): | def forward(self, inputs_row, inputs_col, relation_index): | ||||
| inputs_row = dropout(inputs_row, self.keep_prob) | inputs_row = dropout(inputs_row, self.keep_prob) | ||||
| @@ -83,10 +83,10 @@ class BilinearDecoder(torch.nn.Module): | |||||
| self.keep_prob = keep_prob | self.keep_prob = keep_prob | ||||
| self.activation = activation | self.activation = activation | ||||
| self.relation = [ | |||||
| init_glorot(input_dim, input_dim) \ | |||||
| self.relation = torch.nn.ParameterList([ | |||||
| torch.nn.Parameter(init_glorot(input_dim, input_dim)) \ | |||||
| for _ in range(num_relation_types) | for _ in range(num_relation_types) | ||||
| ] | |||||
| ]) | |||||
| def forward(self, inputs_row, inputs_col, relation_index): | def forward(self, inputs_row, inputs_col, relation_index): | ||||
| inputs_row = dropout(inputs_row, self.keep_prob) | inputs_row = dropout(inputs_row, self.keep_prob) | ||||