IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Added DistMultDecoder with test.

master
Stanislaw Adaszewski 4 years ago
parent
commit
c4aefc0150
2 changed files with 61 additions and 15 deletions
  1. +30
    -1
      src/decagon_pytorch/decode.py
  2. +31
    -14
      tests/decagon_pytorch/test_decode.py

+ 30
- 1
src/decagon_pytorch/decode.py View File

@@ -14,7 +14,6 @@ class DEDICOMDecoder(torch.nn.Module):
self.drop_prob = drop_prob
self.activation = activation
self.global_interaction = init_glorot(input_dim, input_dim)
self.local_variation = [
torch.flatten(init_glorot(input_dim, 1)) \
@@ -35,3 +34,33 @@ class DEDICOMDecoder(torch.nn.Module):
rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1))
outputs.append(self.activation(rec))
return outputs
class DistMultDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.drop_prob = drop_prob
self.activation = activation
self.relation = [
torch.flatten(init_glorot(input_dim, 1)) \
for _ in range(num_relation_types)
]
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
relation = torch.diag(self.relation[k])
intermediate_product = torch.mm(inputs_row, relation)
rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
outputs.append(self.activation(rec))
return outputs

+ 31
- 14
tests/decagon_pytorch/test_decode.py View File

@@ -5,25 +5,14 @@ import tensorflow as tf
import torch
def test_dedicom():
dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10,
num_relation_types=7)
dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
edge_type=(0, 0))
dedicom_tf.vars['global_interaction'] = \
tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
for i in range(dedicom_tf.num_types):
dedicom_tf.vars['local_variation_%d' % i] = \
tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
def _common(decoder_torch, decoder_tf):
inputs = np.random.rand(20, 10).astype(np.float32)
inputs_torch = torch.tensor(inputs)
inputs_tf = {
0: tf.convert_to_tensor(inputs)
}
out_torch = dedicom_torch(inputs_torch, inputs_torch)
out_tf = dedicom_tf(inputs_tf)
out_torch = decoder_torch(inputs_torch, inputs_torch)
out_tf = decoder_tf(inputs_tf)
assert len(out_torch) == len(out_tf)
assert len(out_tf) == 7
@@ -39,3 +28,31 @@ def test_dedicom():
print('item_tf:', item_tf)
assert np.all(item_torch - item_tf < .000001)
sess.close()
def test_dedicom_decoder():
dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10,
num_relation_types=7)
dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
edge_type=(0, 0))
dedicom_tf.vars['global_interaction'] = \
tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
for i in range(dedicom_tf.num_types):
dedicom_tf.vars['local_variation_%d' % i] = \
tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
_common(dedicom_torch, dedicom_tf)
def test_dist_mult_decoder():
distmult_torch = decagon_pytorch.decode.DistMultDecoder(input_dim=10,
num_relation_types=7)
distmult_tf = decagon.deep.layers.DistMultDecoder(input_dim=10, num_types=7,
edge_type=(0, 0))
for i in range(distmult_tf.num_types):
distmult_tf.vars['relation_%d' % i] = \
tf.convert_to_tensor(distmult_torch.relation[i].detach().numpy())
_common(distmult_torch, distmult_tf)

Loading…
Cancel
Save