@@ -5,7 +5,6 @@ | |||||
from collections import defaultdict | from collections import defaultdict | ||||
from .decode import BilinearDecoder | |||||
from .weights import init_glorot | from .weights import init_glorot | ||||
@@ -5,8 +5,8 @@ | |||||
import torch | import torch | ||||
from .weights import init_glorot | |||||
from .dropout import dropout | |||||
from ..weights import init_glorot | |||||
from ..dropout import dropout | |||||
class DEDICOMDecoder(torch.nn.Module): | class DEDICOMDecoder(torch.nn.Module): |
@@ -0,0 +1,123 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import torch | |||||
from ..weights import init_glorot | |||||
from ..dropout import dropout | |||||
class DEDICOMDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.drop_prob = drop_prob | |||||
self.activation = activation | |||||
self.global_interaction = init_glorot(input_dim, input_dim) | |||||
self.local_variation = [ | |||||
torch.flatten(init_glorot(input_dim, 1)) \ | |||||
for _ in range(num_relation_types) | |||||
] | |||||
def forward(self, inputs_row, inputs_col): | |||||
outputs = [] | |||||
for k in range(self.num_relation_types): | |||||
inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
relation = torch.diag(self.local_variation[k]) | |||||
product1 = torch.mm(inputs_row, relation) | |||||
product2 = torch.mm(product1, self.global_interaction) | |||||
product3 = torch.mm(product2, relation) | |||||
rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1)) | |||||
outputs.append(self.activation(rec)) | |||||
return outputs | |||||
class DistMultDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.drop_prob = drop_prob | |||||
self.activation = activation | |||||
self.relation = [ | |||||
torch.flatten(init_glorot(input_dim, 1)) \ | |||||
for _ in range(num_relation_types) | |||||
] | |||||
def forward(self, inputs_row, inputs_col): | |||||
outputs = [] | |||||
for k in range(self.num_relation_types): | |||||
inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
relation = torch.diag(self.relation[k]) | |||||
intermediate_product = torch.mm(inputs_row, relation) | |||||
rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1)) | |||||
outputs.append(self.activation(rec)) | |||||
return outputs | |||||
class BilinearDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.drop_prob = drop_prob | |||||
self.activation = activation | |||||
self.relation = [ | |||||
init_glorot(input_dim, input_dim) \ | |||||
for _ in range(num_relation_types) | |||||
] | |||||
def forward(self, inputs_row, inputs_col): | |||||
outputs = [] | |||||
for k in range(self.num_relation_types): | |||||
inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
intermediate_product = torch.mm(inputs_row, self.relation[k]) | |||||
rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1)) | |||||
outputs.append(self.activation(rec)) | |||||
return outputs | |||||
class InnerProductDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.drop_prob = drop_prob | |||||
self.activation = activation | |||||
def forward(self, inputs_row, inputs_col): | |||||
outputs = [] | |||||
for k in range(self.num_relation_types): | |||||
inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
rec = torch.mm(inputs_row, torch.transpose(inputs_col, 0, 1)) | |||||
outputs.append(self.activation(rec)) | |||||
return outputs |
@@ -13,7 +13,7 @@ from typing import Type, \ | |||||
Union, \ | Union, \ | ||||
Dict, \ | Dict, \ | ||||
Tuple | Tuple | ||||
from ..decode import DEDICOMDecoder | |||||
from ..decode.cartesian import DEDICOMDecoder | |||||
class DecodeLayer(torch.nn.Module): | class DecodeLayer(torch.nn.Module): | ||||
@@ -1,7 +1,7 @@ | |||||
from decagon_pytorch.layer import OneHotInputLayer, \ | from decagon_pytorch.layer import OneHotInputLayer, \ | ||||
DecagonLayer, \ | DecagonLayer, \ | ||||
DecodeLayer | DecodeLayer | ||||
from decagon_pytorch.decode import DEDICOMDecoder | |||||
from decagon_pytorch.decode.cartesian import DEDICOMDecoder | |||||
from decagon_pytorch.data import Data | from decagon_pytorch.data import Data | ||||
import torch | import torch | ||||
@@ -1,4 +1,4 @@ | |||||
import decagon_pytorch.decode | |||||
import decagon_pytorch.decode.cartesian | |||||
import decagon.deep.layers | import decagon.deep.layers | ||||
import numpy as np | import numpy as np | ||||
import tensorflow as tf | import tensorflow as tf | ||||
@@ -31,7 +31,7 @@ def _common(decoder_torch, decoder_tf): | |||||
def test_dedicom_decoder(): | def test_dedicom_decoder(): | ||||
dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10, | |||||
dedicom_torch = decagon_pytorch.decode.cartesian.DEDICOMDecoder(input_dim=10, | |||||
num_relation_types=7) | num_relation_types=7) | ||||
dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7, | dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7, | ||||
edge_type=(0, 0)) | edge_type=(0, 0)) | ||||
@@ -46,7 +46,7 @@ def test_dedicom_decoder(): | |||||
def test_dist_mult_decoder(): | def test_dist_mult_decoder(): | ||||
distmult_torch = decagon_pytorch.decode.DistMultDecoder(input_dim=10, | |||||
distmult_torch = decagon_pytorch.decode.cartesian.DistMultDecoder(input_dim=10, | |||||
num_relation_types=7) | num_relation_types=7) | ||||
distmult_tf = decagon.deep.layers.DistMultDecoder(input_dim=10, num_types=7, | distmult_tf = decagon.deep.layers.DistMultDecoder(input_dim=10, num_types=7, | ||||
edge_type=(0, 0)) | edge_type=(0, 0)) | ||||
@@ -59,7 +59,7 @@ def test_dist_mult_decoder(): | |||||
def test_bilinear_decoder(): | def test_bilinear_decoder(): | ||||
bilinear_torch = decagon_pytorch.decode.BilinearDecoder(input_dim=10, | |||||
bilinear_torch = decagon_pytorch.decode.cartesian.BilinearDecoder(input_dim=10, | |||||
num_relation_types=7) | num_relation_types=7) | ||||
bilinear_tf = decagon.deep.layers.BilinearDecoder(input_dim=10, num_types=7, | bilinear_tf = decagon.deep.layers.BilinearDecoder(input_dim=10, num_types=7, | ||||
edge_type=(0, 0)) | edge_type=(0, 0)) | ||||
@@ -72,7 +72,7 @@ def test_bilinear_decoder(): | |||||
def test_inner_product_decoder(): | def test_inner_product_decoder(): | ||||
inner_torch = decagon_pytorch.decode.InnerProductDecoder(input_dim=10, | |||||
inner_torch = decagon_pytorch.decode.cartesian.InnerProductDecoder(input_dim=10, | |||||
num_relation_types=7) | num_relation_types=7) | ||||
inner_tf = decagon.deep.layers.InnerProductDecoder(input_dim=10, num_types=7, | inner_tf = decagon.deep.layers.InnerProductDecoder(input_dim=10, num_types=7, | ||||
edge_type=(0, 0)) | edge_type=(0, 0)) | ||||
@@ -1,4 +1,4 @@ | |||||
from decagon_pytorch.decode import DEDICOMDecoder, \ | |||||
from decagon_pytorch.decode.cartesian import DEDICOMDecoder, \ | |||||
DistMultDecoder, \ | DistMultDecoder, \ | ||||
BilinearDecoder, \ | BilinearDecoder, \ | ||||
InnerProductDecoder | InnerProductDecoder | ||||