diff --git a/src/decagon_pytorch/data.py b/src/decagon_pytorch/data.py index 56ede10..a375f75 100644 --- a/src/decagon_pytorch/data.py +++ b/src/decagon_pytorch/data.py @@ -5,7 +5,6 @@ from collections import defaultdict -from .decode import BilinearDecoder from .weights import init_glorot diff --git a/src/decagon_pytorch/decode/__init__.py b/src/decagon_pytorch/decode/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/decagon_pytorch/decode.py b/src/decagon_pytorch/decode/cartesian.py similarity index 95% rename from src/decagon_pytorch/decode.py rename to src/decagon_pytorch/decode/cartesian.py index 678eee2..910a8ce 100644 --- a/src/decagon_pytorch/decode.py +++ b/src/decagon_pytorch/decode/cartesian.py @@ -5,8 +5,8 @@ import torch -from .weights import init_glorot -from .dropout import dropout +from ..weights import init_glorot +from ..dropout import dropout class DEDICOMDecoder(torch.nn.Module): diff --git a/src/decagon_pytorch/decode/pairwise.py b/src/decagon_pytorch/decode/pairwise.py new file mode 100644 index 0000000..910a8ce --- /dev/null +++ b/src/decagon_pytorch/decode/pairwise.py @@ -0,0 +1,123 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import torch +from ..weights import init_glorot +from ..dropout import dropout + + +class DEDICOMDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, drop_prob=0., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = num_relation_types + self.drop_prob = drop_prob + self.activation = activation + + self.global_interaction = init_glorot(input_dim, input_dim) + self.local_variation = [ + torch.flatten(init_glorot(input_dim, 1)) \ + for _ in range(num_relation_types) + ] + + def forward(self, inputs_row, inputs_col): + outputs = [] + for k in range(self.num_relation_types): + inputs_row = dropout(inputs_row, 1.-self.drop_prob) + inputs_col = dropout(inputs_col, 1.-self.drop_prob) + + relation = torch.diag(self.local_variation[k]) + + product1 = torch.mm(inputs_row, relation) + product2 = torch.mm(product1, self.global_interaction) + product3 = torch.mm(product2, relation) + rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1)) + outputs.append(self.activation(rec)) + return outputs + + +class DistMultDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, drop_prob=0., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = num_relation_types + self.drop_prob = drop_prob + self.activation = activation + + self.relation = [ + torch.flatten(init_glorot(input_dim, 1)) \ + for _ in range(num_relation_types) + ] + + def forward(self, inputs_row, inputs_col): + outputs = [] + for k in range(self.num_relation_types): + inputs_row = dropout(inputs_row, 1.-self.drop_prob) + inputs_col = dropout(inputs_col, 1.-self.drop_prob) + + relation = torch.diag(self.relation[k]) + + intermediate_product = torch.mm(inputs_row, relation) + rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1)) + outputs.append(self.activation(rec)) + return outputs + + +class BilinearDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, drop_prob=0., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = num_relation_types + self.drop_prob = drop_prob + self.activation = activation + + self.relation = [ + init_glorot(input_dim, input_dim) \ + for _ in range(num_relation_types) + ] + + def forward(self, inputs_row, inputs_col): + outputs = [] + for k in range(self.num_relation_types): + inputs_row = dropout(inputs_row, 1.-self.drop_prob) + inputs_col = dropout(inputs_col, 1.-self.drop_prob) + + intermediate_product = torch.mm(inputs_row, self.relation[k]) + rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1)) + outputs.append(self.activation(rec)) + return outputs + + +class InnerProductDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, drop_prob=0., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = num_relation_types + self.drop_prob = drop_prob + self.activation = activation + + + def forward(self, inputs_row, inputs_col): + outputs = [] + for k in range(self.num_relation_types): + inputs_row = dropout(inputs_row, 1.-self.drop_prob) + inputs_col = dropout(inputs_col, 1.-self.drop_prob) + + rec = torch.mm(inputs_row, torch.transpose(inputs_col, 0, 1)) + outputs.append(self.activation(rec)) + return outputs diff --git a/src/decagon_pytorch/layer/decode.py b/src/decagon_pytorch/layer/decode.py index 52a07b1..f354142 100644 --- a/src/decagon_pytorch/layer/decode.py +++ b/src/decagon_pytorch/layer/decode.py @@ -13,7 +13,7 @@ from typing import Type, \ Union, \ Dict, \ Tuple -from ..decode import DEDICOMDecoder +from ..decode.cartesian import DEDICOMDecoder class DecodeLayer(torch.nn.Module): diff --git a/tests/decagon_pytorch/layer/test_layer_decode.py b/tests/decagon_pytorch/layer/test_layer_decode.py index 57c554b..8c2e336 100644 --- a/tests/decagon_pytorch/layer/test_layer_decode.py +++ b/tests/decagon_pytorch/layer/test_layer_decode.py @@ -1,7 +1,7 @@ from decagon_pytorch.layer import OneHotInputLayer, \ DecagonLayer, \ DecodeLayer -from decagon_pytorch.decode import DEDICOMDecoder +from decagon_pytorch.decode.cartesian import DEDICOMDecoder from decagon_pytorch.data import Data import torch diff --git a/tests/decagon_pytorch/test_decode.py b/tests/decagon_pytorch/test_decode.py index 798fc50..de7d1a4 100644 --- a/tests/decagon_pytorch/test_decode.py +++ b/tests/decagon_pytorch/test_decode.py @@ -1,4 +1,4 @@ -import decagon_pytorch.decode +import decagon_pytorch.decode.cartesian import decagon.deep.layers import numpy as np import tensorflow as tf @@ -31,7 +31,7 @@ def _common(decoder_torch, decoder_tf): def test_dedicom_decoder(): - dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10, + dedicom_torch = decagon_pytorch.decode.cartesian.DEDICOMDecoder(input_dim=10, num_relation_types=7) dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7, edge_type=(0, 0)) @@ -46,7 +46,7 @@ def test_dedicom_decoder(): def test_dist_mult_decoder(): - distmult_torch = decagon_pytorch.decode.DistMultDecoder(input_dim=10, + distmult_torch = decagon_pytorch.decode.cartesian.DistMultDecoder(input_dim=10, num_relation_types=7) distmult_tf = decagon.deep.layers.DistMultDecoder(input_dim=10, num_types=7, edge_type=(0, 0)) @@ -59,7 +59,7 @@ def test_dist_mult_decoder(): def test_bilinear_decoder(): - bilinear_torch = decagon_pytorch.decode.BilinearDecoder(input_dim=10, + bilinear_torch = decagon_pytorch.decode.cartesian.BilinearDecoder(input_dim=10, num_relation_types=7) bilinear_tf = decagon.deep.layers.BilinearDecoder(input_dim=10, num_types=7, edge_type=(0, 0)) @@ -72,7 +72,7 @@ def test_bilinear_decoder(): def test_inner_product_decoder(): - inner_torch = decagon_pytorch.decode.InnerProductDecoder(input_dim=10, + inner_torch = decagon_pytorch.decode.cartesian.InnerProductDecoder(input_dim=10, num_relation_types=7) inner_tf = decagon.deep.layers.InnerProductDecoder(input_dim=10, num_types=7, edge_type=(0, 0)) diff --git a/tests/decagon_pytorch/test_decode_dims.py b/tests/decagon_pytorch/test_decode_dims.py index 03d73cc..2c3a144 100644 --- a/tests/decagon_pytorch/test_decode_dims.py +++ b/tests/decagon_pytorch/test_decode_dims.py @@ -1,4 +1,4 @@ -from decagon_pytorch.decode import DEDICOMDecoder, \ +from decagon_pytorch.decode.cartesian import DEDICOMDecoder, \ DistMultDecoder, \ BilinearDecoder, \ InnerProductDecoder