IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Преглед на файлове

Added InnerProductDecoder and test.

master
Stanislaw Adaszewski преди 4 години
родител
ревизия
e26ccd4222
променени са 2 файла, в които са добавени 32 реда и са изтрити 0 реда
  1. +23
    -0
      src/decagon_pytorch/decode.py
  2. +9
    -0
      tests/decagon_pytorch/test_decode.py

+ 23
- 0
src/decagon_pytorch/decode.py Целия файл

@@ -92,3 +92,26 @@ class BilinearDecoder(torch.nn.Module):
rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
outputs.append(self.activation(rec))
return outputs
class InnerProductDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.drop_prob = drop_prob
self.activation = activation
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
rec = torch.mm(inputs_row, torch.transpose(inputs_col, 0, 1))
outputs.append(self.activation(rec))
return outputs

+ 9
- 0
tests/decagon_pytorch/test_decode.py Целия файл

@@ -69,3 +69,12 @@ def test_bilinear_decoder():
tf.convert_to_tensor(bilinear_torch.relation[i].detach().numpy())
_common(bilinear_torch, bilinear_tf)
def test_inner_product_decoder():
inner_torch = decagon_pytorch.decode.InnerProductDecoder(input_dim=10,
num_relation_types=7)
inner_tf = decagon.deep.layers.InnerProductDecoder(input_dim=10, num_types=7,
edge_type=(0, 0))
_common(inner_torch, inner_tf)

Loading…
Отказ
Запис