|
|
@@ -11,13 +11,13 @@ from .dropout import dropout |
|
|
|
|
|
|
|
class DEDICOMDecoder(torch.nn.Module):
|
|
|
|
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
|
|
|
|
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
|
|
|
|
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid, **kwargs):
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.num_relation_types = num_relation_types
|
|
|
|
self.drop_prob = drop_prob
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
self.global_interaction = init_glorot(input_dim, input_dim)
|
|
|
@@ -29,8 +29,8 @@ class DEDICOMDecoder(torch.nn.Module): |
|
|
|
def forward(self, inputs_row, inputs_col):
|
|
|
|
outputs = []
|
|
|
|
for k in range(self.num_relation_types):
|
|
|
|
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
|
|
|
|
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
|
|
|
|
inputs_row = dropout(inputs_row, self.keep_prob)
|
|
|
|
inputs_col = dropout(inputs_col, self.keep_prob)
|
|
|
|
|
|
|
|
relation = torch.diag(self.local_variation[k])
|
|
|
|
|
|
|
@@ -46,13 +46,13 @@ class DEDICOMDecoder(torch.nn.Module): |
|
|
|
|
|
|
|
class DistMultDecoder(torch.nn.Module):
|
|
|
|
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
|
|
|
|
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
|
|
|
|
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid, **kwargs):
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.num_relation_types = num_relation_types
|
|
|
|
self.drop_prob = drop_prob
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
self.relation = [
|
|
|
@@ -63,8 +63,8 @@ class DistMultDecoder(torch.nn.Module): |
|
|
|
def forward(self, inputs_row, inputs_col):
|
|
|
|
outputs = []
|
|
|
|
for k in range(self.num_relation_types):
|
|
|
|
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
|
|
|
|
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
|
|
|
|
inputs_row = dropout(inputs_row, self.keep_prob)
|
|
|
|
inputs_col = dropout(inputs_col, self.keep_prob)
|
|
|
|
|
|
|
|
relation = torch.diag(self.relation[k])
|
|
|
|
|
|
|
@@ -78,13 +78,13 @@ class DistMultDecoder(torch.nn.Module): |
|
|
|
|
|
|
|
class BilinearDecoder(torch.nn.Module):
|
|
|
|
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
|
|
|
|
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
|
|
|
|
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid, **kwargs):
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.num_relation_types = num_relation_types
|
|
|
|
self.drop_prob = drop_prob
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
self.relation = [
|
|
|
@@ -95,8 +95,8 @@ class BilinearDecoder(torch.nn.Module): |
|
|
|
def forward(self, inputs_row, inputs_col):
|
|
|
|
outputs = []
|
|
|
|
for k in range(self.num_relation_types):
|
|
|
|
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
|
|
|
|
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
|
|
|
|
inputs_row = dropout(inputs_row, self.keep_prob)
|
|
|
|
inputs_col = dropout(inputs_col, self.keep_prob)
|
|
|
|
|
|
|
|
intermediate_product = torch.mm(inputs_row, self.relation[k])
|
|
|
|
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
|
|
|
@@ -108,21 +108,21 @@ class BilinearDecoder(torch.nn.Module): |
|
|
|
|
|
|
|
class InnerProductDecoder(torch.nn.Module):
|
|
|
|
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
|
|
|
|
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
|
|
|
|
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid, **kwargs):
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.num_relation_types = num_relation_types
|
|
|
|
self.drop_prob = drop_prob
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, inputs_row, inputs_col):
|
|
|
|
outputs = []
|
|
|
|
for k in range(self.num_relation_types):
|
|
|
|
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
|
|
|
|
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
|
|
|
|
inputs_row = dropout(inputs_row, self.keep_prob)
|
|
|
|
inputs_col = dropout(inputs_col, self.keep_prob)
|
|
|
|
|
|
|
|
rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
|
|
|
|
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
|
|
|
|