| @@ -0,0 +1,37 @@ | |||||
| import torch | |||||
| from .weights import init_glorot | |||||
| from .dropout import dropout | |||||
| class DEDICOMDecoder(torch.nn.Module): | |||||
| """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
| def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
| activation=torch.sigmoid, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.input_dim = input_dim | |||||
| self.num_relation_types = num_relation_types | |||||
| self.drop_prob = drop_prob | |||||
| self.activation = activation | |||||
| self.global_interaction = init_glorot(input_dim, input_dim) | |||||
| self.local_variation = [ | |||||
| torch.flatten(init_glorot(input_dim, 1)) \ | |||||
| for _ in range(num_relation_types) | |||||
| ] | |||||
| def forward(self, inputs_row, inputs_col): | |||||
| outputs = [] | |||||
| for k in range(self.num_relation_types): | |||||
| inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
| inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
| relation = torch.diag(self.local_variation[k]) | |||||
| product1 = torch.mm(inputs_row, relation) | |||||
| product2 = torch.mm(product1, self.global_interaction) | |||||
| product3 = torch.mm(product2, relation) | |||||
| rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1)) | |||||
| outputs.append(self.activation(rec)) | |||||
| return outputs | |||||
| @@ -0,0 +1,41 @@ | |||||
| import decagon_pytorch.decode | |||||
| import decagon.deep.layers | |||||
| import numpy as np | |||||
| import tensorflow as tf | |||||
| import torch | |||||
| def test_dedicom(): | |||||
| dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10, | |||||
| num_relation_types=7) | |||||
| dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7, | |||||
| edge_type=(0, 0)) | |||||
| dedicom_tf.vars['global_interaction'] = \ | |||||
| tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy()) | |||||
| for i in range(dedicom_tf.num_types): | |||||
| dedicom_tf.vars['local_variation_%d' % i] = \ | |||||
| tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy()) | |||||
| inputs = np.random.rand(20, 10).astype(np.float32) | |||||
| inputs_torch = torch.tensor(inputs) | |||||
| inputs_tf = { | |||||
| 0: tf.convert_to_tensor(inputs) | |||||
| } | |||||
| out_torch = dedicom_torch(inputs_torch, inputs_torch) | |||||
| out_tf = dedicom_tf(inputs_tf) | |||||
| assert len(out_torch) == len(out_tf) | |||||
| assert len(out_tf) == 7 | |||||
| for i in range(len(out_torch)): | |||||
| assert out_torch[i].shape == out_tf[i].shape | |||||
| sess = tf.Session() | |||||
| for i in range(len(out_torch)): | |||||
| item_torch = out_torch[i].detach().numpy() | |||||
| item_tf = out_tf[i].eval(session=sess) | |||||
| print('item_torch:', item_torch) | |||||
| print('item_tf:', item_tf) | |||||
| assert np.all(item_torch - item_tf < .000001) | |||||
| sess.close() | |||||