@@ -0,0 +1,37 @@ | |||||
import torch | |||||
from .weights import init_glorot | |||||
from .dropout import dropout | |||||
class DEDICOMDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.drop_prob = drop_prob | |||||
self.activation = activation | |||||
self.global_interaction = init_glorot(input_dim, input_dim) | |||||
self.local_variation = [ | |||||
torch.flatten(init_glorot(input_dim, 1)) \ | |||||
for _ in range(num_relation_types) | |||||
] | |||||
def forward(self, inputs_row, inputs_col): | |||||
outputs = [] | |||||
for k in range(self.num_relation_types): | |||||
inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
relation = torch.diag(self.local_variation[k]) | |||||
product1 = torch.mm(inputs_row, relation) | |||||
product2 = torch.mm(product1, self.global_interaction) | |||||
product3 = torch.mm(product2, relation) | |||||
rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1)) | |||||
outputs.append(self.activation(rec)) | |||||
return outputs |
@@ -0,0 +1,41 @@ | |||||
import decagon_pytorch.decode | |||||
import decagon.deep.layers | |||||
import numpy as np | |||||
import tensorflow as tf | |||||
import torch | |||||
def test_dedicom(): | |||||
dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10, | |||||
num_relation_types=7) | |||||
dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7, | |||||
edge_type=(0, 0)) | |||||
dedicom_tf.vars['global_interaction'] = \ | |||||
tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy()) | |||||
for i in range(dedicom_tf.num_types): | |||||
dedicom_tf.vars['local_variation_%d' % i] = \ | |||||
tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy()) | |||||
inputs = np.random.rand(20, 10).astype(np.float32) | |||||
inputs_torch = torch.tensor(inputs) | |||||
inputs_tf = { | |||||
0: tf.convert_to_tensor(inputs) | |||||
} | |||||
out_torch = dedicom_torch(inputs_torch, inputs_torch) | |||||
out_tf = dedicom_tf(inputs_tf) | |||||
assert len(out_torch) == len(out_tf) | |||||
assert len(out_tf) == 7 | |||||
for i in range(len(out_torch)): | |||||
assert out_torch[i].shape == out_tf[i].shape | |||||
sess = tf.Session() | |||||
for i in range(len(out_torch)): | |||||
item_torch = out_torch[i].detach().numpy() | |||||
item_tf = out_tf[i].eval(session=sess) | |||||
print('item_torch:', item_torch) | |||||
print('item_tf:', item_tf) | |||||
assert np.all(item_torch - item_tf < .000001) | |||||
sess.close() |