|
|
@@ -20,11 +20,11 @@ class DEDICOMDecoder(torch.nn.Module): |
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
self.global_interaction = init_glorot(input_dim, input_dim)
|
|
|
|
self.local_variation = [
|
|
|
|
torch.flatten(init_glorot(input_dim, 1)) \
|
|
|
|
self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
|
|
|
|
self.local_variation = torch.nn.ParameterList([
|
|
|
|
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
|
|
|
|
for _ in range(num_relation_types)
|
|
|
|
]
|
|
|
|
])
|
|
|
|
|
|
|
|
def forward(self, inputs_row, inputs_col, relation_index):
|
|
|
|
inputs_row = dropout(inputs_row, self.keep_prob)
|
|
|
@@ -53,10 +53,10 @@ class DistMultDecoder(torch.nn.Module): |
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
self.relation = [
|
|
|
|
torch.flatten(init_glorot(input_dim, 1)) \
|
|
|
|
self.relation = torch.nn.ParameterList([
|
|
|
|
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
|
|
|
|
for _ in range(num_relation_types)
|
|
|
|
]
|
|
|
|
])
|
|
|
|
|
|
|
|
def forward(self, inputs_row, inputs_col, relation_index):
|
|
|
|
inputs_row = dropout(inputs_row, self.keep_prob)
|
|
|
@@ -83,10 +83,10 @@ class BilinearDecoder(torch.nn.Module): |
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
self.relation = [
|
|
|
|
init_glorot(input_dim, input_dim) \
|
|
|
|
self.relation = torch.nn.ParameterList([
|
|
|
|
torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
|
|
|
|
for _ in range(num_relation_types)
|
|
|
|
]
|
|
|
|
])
|
|
|
|
|
|
|
|
def forward(self, inputs_row, inputs_col, relation_index):
|
|
|
|
inputs_row = dropout(inputs_row, self.keep_prob)
|
|
|
|