IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add BilinearDecoder and test.

master
Stanislaw Adaszewski 4 years ago
parent
commit
438ba67565
2 changed files with 41 additions and 0 deletions
  1. +28
    -0
      src/decagon_pytorch/decode.py
  2. +13
    -0
      tests/decagon_pytorch/test_decode.py

+ 28
- 0
src/decagon_pytorch/decode.py View File

@@ -64,3 +64,31 @@ class DistMultDecoder(torch.nn.Module):
rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
outputs.append(self.activation(rec))
return outputs
class BilinearDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.drop_prob = drop_prob
self.activation = activation
self.relation = [
init_glorot(input_dim, input_dim) \
for _ in range(num_relation_types)
]
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
intermediate_product = torch.mm(inputs_row, self.relation[k])
rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
outputs.append(self.activation(rec))
return outputs

+ 13
- 0
tests/decagon_pytorch/test_decode.py View File

@@ -56,3 +56,16 @@ def test_dist_mult_decoder():
tf.convert_to_tensor(distmult_torch.relation[i].detach().numpy())
_common(distmult_torch, distmult_tf)
def test_bilinear_decoder():
bilinear_torch = decagon_pytorch.decode.BilinearDecoder(input_dim=10,
num_relation_types=7)
bilinear_tf = decagon.deep.layers.BilinearDecoder(input_dim=10, num_types=7,
edge_type=(0, 0))
for i in range(bilinear_tf.num_types):
bilinear_tf.vars['relation_%d' % i] = \
tf.convert_to_tensor(bilinear_torch.relation[i].detach().numpy())
_common(bilinear_torch, bilinear_tf)

Loading…
Cancel
Save