IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Added DEDICOMDecoder with test.

master
Stanislaw Adaszewski 4 years ago
parent
commit
353fc8913e
2 changed files with 78 additions and 0 deletions
  1. +37
    -0
      src/decagon_pytorch/decode.py
  2. +41
    -0
      tests/decagon_pytorch/test_decode.py

+ 37
- 0
src/decagon_pytorch/decode.py View File

@@ -0,0 +1,37 @@
import torch
from .weights import init_glorot
from .dropout import dropout
class DEDICOMDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.drop_prob = drop_prob
self.activation = activation
self.global_interaction = init_glorot(input_dim, input_dim)
self.local_variation = [
torch.flatten(init_glorot(input_dim, 1)) \
for _ in range(num_relation_types)
]
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
relation = torch.diag(self.local_variation[k])
product1 = torch.mm(inputs_row, relation)
product2 = torch.mm(product1, self.global_interaction)
product3 = torch.mm(product2, relation)
rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1))
outputs.append(self.activation(rec))
return outputs

+ 41
- 0
tests/decagon_pytorch/test_decode.py View File

@@ -0,0 +1,41 @@
import decagon_pytorch.decode
import decagon.deep.layers
import numpy as np
import tensorflow as tf
import torch
def test_dedicom():
dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10,
num_relation_types=7)
dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
edge_type=(0, 0))
dedicom_tf.vars['global_interaction'] = \
tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
for i in range(dedicom_tf.num_types):
dedicom_tf.vars['local_variation_%d' % i] = \
tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
inputs = np.random.rand(20, 10).astype(np.float32)
inputs_torch = torch.tensor(inputs)
inputs_tf = {
0: tf.convert_to_tensor(inputs)
}
out_torch = dedicom_torch(inputs_torch, inputs_torch)
out_tf = dedicom_tf(inputs_tf)
assert len(out_torch) == len(out_tf)
assert len(out_tf) == 7
for i in range(len(out_torch)):
assert out_torch[i].shape == out_tf[i].shape
sess = tf.Session()
for i in range(len(out_torch)):
item_torch = out_torch[i].detach().numpy()
item_tf = out_tf[i].eval(session=sess)
print('item_torch:', item_torch)
print('item_tf:', item_tf)
assert np.all(item_torch - item_tf < .000001)
sess.close()

Loading…
Cancel
Save