|
|
@@ -0,0 +1,37 @@ |
|
|
|
import torch
|
|
|
|
from .weights import init_glorot
|
|
|
|
from .dropout import dropout
|
|
|
|
|
|
|
|
|
|
|
|
class DEDICOMDecoder(torch.nn.Module):
|
|
|
|
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
|
|
|
|
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
|
|
|
|
activation=torch.sigmoid, **kwargs):
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.num_relation_types = num_relation_types
|
|
|
|
self.drop_prob = drop_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
|
|
|
|
self.global_interaction = init_glorot(input_dim, input_dim)
|
|
|
|
self.local_variation = [
|
|
|
|
torch.flatten(init_glorot(input_dim, 1)) \
|
|
|
|
for _ in range(num_relation_types)
|
|
|
|
]
|
|
|
|
|
|
|
|
def forward(self, inputs_row, inputs_col):
|
|
|
|
outputs = []
|
|
|
|
for k in range(self.num_relation_types):
|
|
|
|
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
|
|
|
|
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
|
|
|
|
|
|
|
|
relation = torch.diag(self.local_variation[k])
|
|
|
|
|
|
|
|
product1 = torch.mm(inputs_row, relation)
|
|
|
|
product2 = torch.mm(product1, self.global_interaction)
|
|
|
|
product3 = torch.mm(product2, relation)
|
|
|
|
rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1))
|
|
|
|
outputs.append(self.activation(rec))
|
|
|
|
return outputs
|