From 438ba67565ccf4ff3223119f966e047820d9adf5 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 14 May 2020 19:02:08 +0200 Subject: [PATCH] Add BilinearDecoder and test. --- src/decagon_pytorch/decode.py | 28 ++++++++++++++++++++++++++++ tests/decagon_pytorch/test_decode.py | 13 +++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/decagon_pytorch/decode.py b/src/decagon_pytorch/decode.py index 3643652..8538c52 100644 --- a/src/decagon_pytorch/decode.py +++ b/src/decagon_pytorch/decode.py @@ -64,3 +64,31 @@ class DistMultDecoder(torch.nn.Module): rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1)) outputs.append(self.activation(rec)) return outputs + + +class BilinearDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, drop_prob=0., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = num_relation_types + self.drop_prob = drop_prob + self.activation = activation + + self.relation = [ + init_glorot(input_dim, input_dim) \ + for _ in range(num_relation_types) + ] + + def forward(self, inputs_row, inputs_col): + outputs = [] + for k in range(self.num_relation_types): + inputs_row = dropout(inputs_row, 1.-self.drop_prob) + inputs_col = dropout(inputs_col, 1.-self.drop_prob) + + intermediate_product = torch.mm(inputs_row, self.relation[k]) + rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1)) + outputs.append(self.activation(rec)) + return outputs diff --git a/tests/decagon_pytorch/test_decode.py b/tests/decagon_pytorch/test_decode.py index a374480..198c6fe 100644 --- a/tests/decagon_pytorch/test_decode.py +++ b/tests/decagon_pytorch/test_decode.py @@ -56,3 +56,16 @@ def test_dist_mult_decoder(): tf.convert_to_tensor(distmult_torch.relation[i].detach().numpy()) _common(distmult_torch, distmult_tf) + + +def test_bilinear_decoder(): + bilinear_torch = decagon_pytorch.decode.BilinearDecoder(input_dim=10, + num_relation_types=7) + bilinear_tf = decagon.deep.layers.BilinearDecoder(input_dim=10, num_types=7, + edge_type=(0, 0)) + + for i in range(bilinear_tf.num_types): + bilinear_tf.vars['relation_%d' % i] = \ + tf.convert_to_tensor(bilinear_torch.relation[i].detach().numpy()) + + _common(bilinear_torch, bilinear_tf)