From 353fc8913e6ab3706d1298931ddb82ff2ed93fc7 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Wed, 13 May 2020 17:59:05 +0200 Subject: [PATCH] Added DEDICOMDecoder with test. --- src/decagon_pytorch/decode.py | 37 +++++++++++++++++++++++++ tests/decagon_pytorch/test_decode.py | 41 ++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 src/decagon_pytorch/decode.py create mode 100644 tests/decagon_pytorch/test_decode.py diff --git a/src/decagon_pytorch/decode.py b/src/decagon_pytorch/decode.py new file mode 100644 index 0000000..8ff6096 --- /dev/null +++ b/src/decagon_pytorch/decode.py @@ -0,0 +1,37 @@ +import torch +from .weights import init_glorot +from .dropout import dropout + + +class DEDICOMDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, drop_prob=0., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = num_relation_types + self.drop_prob = drop_prob + self.activation = activation + + + self.global_interaction = init_glorot(input_dim, input_dim) + self.local_variation = [ + torch.flatten(init_glorot(input_dim, 1)) \ + for _ in range(num_relation_types) + ] + + def forward(self, inputs_row, inputs_col): + outputs = [] + for k in range(self.num_relation_types): + inputs_row = dropout(inputs_row, 1.-self.drop_prob) + inputs_col = dropout(inputs_col, 1.-self.drop_prob) + + relation = torch.diag(self.local_variation[k]) + + product1 = torch.mm(inputs_row, relation) + product2 = torch.mm(product1, self.global_interaction) + product3 = torch.mm(product2, relation) + rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1)) + outputs.append(self.activation(rec)) + return outputs diff --git a/tests/decagon_pytorch/test_decode.py b/tests/decagon_pytorch/test_decode.py new file mode 100644 index 0000000..7e8c4ed --- /dev/null +++ b/tests/decagon_pytorch/test_decode.py @@ -0,0 +1,41 @@ +import decagon_pytorch.decode +import decagon.deep.layers +import numpy as np +import tensorflow as tf +import torch + + +def test_dedicom(): + dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10, + num_relation_types=7) + dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7, + edge_type=(0, 0)) + + dedicom_tf.vars['global_interaction'] = \ + tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy()) + for i in range(dedicom_tf.num_types): + dedicom_tf.vars['local_variation_%d' % i] = \ + tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy()) + + inputs = np.random.rand(20, 10).astype(np.float32) + inputs_torch = torch.tensor(inputs) + inputs_tf = { + 0: tf.convert_to_tensor(inputs) + } + out_torch = dedicom_torch(inputs_torch, inputs_torch) + out_tf = dedicom_tf(inputs_tf) + + assert len(out_torch) == len(out_tf) + assert len(out_tf) == 7 + + for i in range(len(out_torch)): + assert out_torch[i].shape == out_tf[i].shape + + sess = tf.Session() + for i in range(len(out_torch)): + item_torch = out_torch[i].detach().numpy() + item_tf = out_tf[i].eval(session=sess) + print('item_torch:', item_torch) + print('item_tf:', item_tf) + assert np.all(item_torch - item_tf < .000001) + sess.close()