@@ -26,22 +26,20 @@ class DEDICOMDecoder(torch.nn.Module):
                 for _ in range(num_relation_types)
         ]
 
-    def forward(self, inputs_row, inputs_col):
-        outputs = []
-        for k in range(self.num_relation_types):
-            inputs_row = dropout(inputs_row, self.keep_prob)
-            inputs_col = dropout(inputs_col, self.keep_prob)
+    def forward(self, inputs_row, inputs_col, relation_index):
+        inputs_row = dropout(inputs_row, self.keep_prob)
+        inputs_col = dropout(inputs_col, self.keep_prob)
 
-            relation = torch.diag(self.local_variation[k])
+        relation = torch.diag(self.local_variation[relation_index])
 
-            product1 = torch.mm(inputs_row, relation)
-            product2 = torch.mm(product1, self.global_interaction)
-            product3 = torch.mm(product2, relation)
-            rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
-                inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
-            rec = torch.flatten(rec)
-            outputs.append(self.activation(rec))
-        return outputs
+        product1 = torch.mm(inputs_row, relation)
+        product2 = torch.mm(product1, self.global_interaction)
+        product3 = torch.mm(product2, relation)
+        rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
+            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+        rec = torch.flatten(rec)
+
+        return self.activation(rec)
 
 
 class DistMultDecoder(torch.nn.Module):
@@ -60,20 +58,18 @@ class DistMultDecoder(torch.nn.Module):
                 for _ in range(num_relation_types)
         ]
 
-    def forward(self, inputs_row, inputs_col):
-        outputs = []
-        for k in range(self.num_relation_types):
-            inputs_row = dropout(inputs_row, self.keep_prob)
-            inputs_col = dropout(inputs_col, self.keep_prob)
+    def forward(self, inputs_row, inputs_col, relation_index):
+        inputs_row = dropout(inputs_row, self.keep_prob)
+        inputs_col = dropout(inputs_col, self.keep_prob)
 
-            relation = torch.diag(self.relation[k])
+        relation = torch.diag(self.relation[relation_index])
 
-            intermediate_product = torch.mm(inputs_row, relation)
-            rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
-                inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
-            rec = torch.flatten(rec)
-            outputs.append(self.activation(rec))
-        return outputs
+        intermediate_product = torch.mm(inputs_row, relation)
+        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
+            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+        rec = torch.flatten(rec)
+
+        return self.activation(rec)
 
 
 class BilinearDecoder(torch.nn.Module):
@@ -92,18 +88,16 @@ class BilinearDecoder(torch.nn.Module):
                 for _ in range(num_relation_types)
         ]
 
-    def forward(self, inputs_row, inputs_col):
-        outputs = []
-        for k in range(self.num_relation_types):
-            inputs_row = dropout(inputs_row, self.keep_prob)
-            inputs_col = dropout(inputs_col, self.keep_prob)
+    def forward(self, inputs_row, inputs_col, relation_index):
+        inputs_row = dropout(inputs_row, self.keep_prob)
+        inputs_col = dropout(inputs_col, self.keep_prob)
 
-            intermediate_product = torch.mm(inputs_row, self.relation[k])
-            rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
-                inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
-            rec = torch.flatten(rec)
-            outputs.append(self.activation(rec))
-        return outputs
+        intermediate_product = torch.mm(inputs_row, self.relation[relation_index])
+        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
+            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+        rec = torch.flatten(rec)
+
+        return self.activation(rec)
 
 
 class InnerProductDecoder(torch.nn.Module):
@@ -118,14 +112,12 @@ class InnerProductDecoder(torch.nn.Module):
         self.activation = activation
 
-    def forward(self, inputs_row, inputs_col):
-        outputs = []
-        for k in range(self.num_relation_types):
-            inputs_row = dropout(inputs_row, self.keep_prob)
-            inputs_col = dropout(inputs_col, self.keep_prob)
+    def forward(self, inputs_row, inputs_col, _):
+        inputs_row = dropout(inputs_row, self.keep_prob)
+        inputs_col = dropout(inputs_col, self.keep_prob)
 
-            rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
-                inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
-            rec = torch.flatten(rec)
-            outputs.append(self.activation(rec))
-        return outputs
+        rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
+            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+        rec = torch.flatten(rec)
+
+        return self.activation(rec)
 
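Each decoder now scores one relation type per forward() call, selected by relation_index, instead of looping over every relation and returning a list. Below is a minimal standalone sketch of the per-relation DEDICOM scoring path that the new forward() computes; the tensor sizes, the relation index, and the sigmoid activation are illustrative assumptions, not taken from this diff.

    import torch

    # Illustrative sizes (assumptions, not from the diff).
    num_nodes, input_dim, num_relation_types = 16, 32, 3

    # Shared interaction matrix and one local-variation vector per relation type.
    global_interaction = torch.rand(input_dim, input_dim)
    local_variation = [torch.rand(input_dim) for _ in range(num_relation_types)]

    inputs_row = torch.rand(num_nodes, input_dim)   # embeddings of edge sources
    inputs_col = torch.rand(num_nodes, input_dim)   # embeddings of edge targets

    relation_index = 1                              # one relation type per call
    relation = torch.diag(local_variation[relation_index])

    # x_row @ D_k @ R @ D_k @ x_col, evaluated row-wise via batched matmul.
    product1 = torch.mm(inputs_row, relation)
    product2 = torch.mm(product1, global_interaction)
    product3 = torch.mm(product2, relation)
    rec = torch.bmm(product3.view(num_nodes, 1, input_dim),
        inputs_col.view(num_nodes, input_dim, 1))
    scores = torch.sigmoid(torch.flatten(rec))      # one score per (row, col) pair
    print(scores.shape)                             # torch.Size([16])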