|
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
- import torch
- from ..weights import init_glorot
- from ..dropout import dropout
-
-
class DEDICOMDecoder(torch.nn.Module):
    """DEDICOM tensor factorization decoder layer for link prediction.

    For relation type ``k`` every (row node, column node) pair is scored as::

        activation(row @ D_k @ R @ D_k @ col.T)

    where ``R`` (``global_interaction``) is an ``(input_dim, input_dim)``
    matrix shared by all relation types and ``D_k`` is the diagonal matrix
    built from the length-``input_dim`` vector ``local_variation[k]``.

    Args:
        input_dim: Embedding width of both ``inputs_row`` and ``inputs_col``.
        num_relation_types: Number of relation types, i.e. the number of
            score matrices returned by ``forward``.
        drop_prob: Dropout rate. The project ``dropout`` helper is called
            with ``1. - drop_prob`` (presumably its keep probability).
            Dropout is applied only while the module is in training mode.
        activation: Element-wise function applied to each score matrix.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                 activation=torch.sigmoid, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        # Wrapping in Parameter / ParameterList registers the weights with
        # the module. They then show up in parameters() (and so reach the
        # optimizer), follow .to(device), and are saved in state_dict().
        # A plain Python list, or the non-leaf tensor returned by
        # torch.flatten, is invisible to all of these, so such weights
        # never get trained.
        self.global_interaction = torch.nn.Parameter(
            init_glorot(input_dim, input_dim))
        self.local_variation = torch.nn.ParameterList([
            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1)))
            for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col):
        """Return one score matrix per relation type.

        Args:
            inputs_row: ``(n, input_dim)`` matrix of row-node embeddings.
            inputs_col: ``(m, input_dim)`` matrix of column-node embeddings.

        Returns:
            A list of ``num_relation_types`` tensors, each of shape
            ``(n, m)``, with ``activation`` already applied.
        """
        keep_prob = 1. - self.drop_prob
        outputs = []
        for k in range(self.num_relation_types):
            # Draw a fresh dropout mask from the ORIGINAL inputs for each
            # relation type. Re-binding inputs_row / inputs_col instead
            # would compound dropout, applying it k + 1 times for type k.
            if self.training:
                row = dropout(inputs_row, keep_prob)
                col = dropout(inputs_col, keep_prob)
            else:
                row, col = inputs_row, inputs_col

            relation = torch.diag(self.local_variation[k])

            product1 = torch.mm(row, relation)
            product2 = torch.mm(product1, self.global_interaction)
            product3 = torch.mm(product2, relation)
            rec = torch.mm(product3, torch.transpose(col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
-
-
class DistMultDecoder(torch.nn.Module):
    """DistMult (diagonal bilinear) decoder layer for link prediction.

    For relation type ``k`` every (row node, column node) pair is scored as::

        activation(row @ D_k @ col.T)

    where ``D_k`` is the diagonal matrix built from the length-``input_dim``
    vector ``relation[k]``.

    Args:
        input_dim: Embedding width of both ``inputs_row`` and ``inputs_col``.
        num_relation_types: Number of relation types, i.e. the number of
            score matrices returned by ``forward``.
        drop_prob: Dropout rate. The project ``dropout`` helper is called
            with ``1. - drop_prob`` (presumably its keep probability).
            Dropout is applied only while the module is in training mode.
        activation: Element-wise function applied to each score matrix.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                 activation=torch.sigmoid, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        # ParameterList / Parameter register the per-relation vectors with
        # the module (parameters(), .to(device), state_dict()). A plain list
        # of non-leaf torch.flatten results would never be trained.
        self.relation = torch.nn.ParameterList([
            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1)))
            for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col):
        """Return one score matrix per relation type.

        Args:
            inputs_row: ``(n, input_dim)`` matrix of row-node embeddings.
            inputs_col: ``(m, input_dim)`` matrix of column-node embeddings.

        Returns:
            A list of ``num_relation_types`` tensors, each of shape
            ``(n, m)``, with ``activation`` already applied.
        """
        keep_prob = 1. - self.drop_prob
        outputs = []
        for k in range(self.num_relation_types):
            # Fresh dropout mask from the ORIGINAL inputs for each relation
            # type. Re-binding the inputs here would compound dropout.
            if self.training:
                row = dropout(inputs_row, keep_prob)
                col = dropout(inputs_col, keep_prob)
            else:
                row, col = inputs_row, inputs_col

            relation = torch.diag(self.relation[k])

            intermediate_product = torch.mm(row, relation)
            rec = torch.mm(intermediate_product, torch.transpose(col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
-
-
class BilinearDecoder(torch.nn.Module):
    """Bilinear decoder layer for link prediction.

    For relation type ``k`` every (row node, column node) pair is scored as::

        activation(row @ W_k @ col.T)

    where ``W_k`` (``relation[k]``) is a full ``(input_dim, input_dim)``
    matrix.

    Args:
        input_dim: Embedding width of both ``inputs_row`` and ``inputs_col``.
        num_relation_types: Number of relation types, i.e. the number of
            score matrices returned by ``forward``.
        drop_prob: Dropout rate. The project ``dropout`` helper is called
            with ``1. - drop_prob`` (presumably its keep probability).
            Dropout is applied only while the module is in training mode.
        activation: Element-wise function applied to each score matrix.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                 activation=torch.sigmoid, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        # ParameterList / Parameter register the per-relation matrices with
        # the module (parameters(), .to(device), state_dict()). A plain
        # Python list hides them from the optimizer.
        self.relation = torch.nn.ParameterList([
            torch.nn.Parameter(init_glorot(input_dim, input_dim))
            for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col):
        """Return one score matrix per relation type.

        Args:
            inputs_row: ``(n, input_dim)`` matrix of row-node embeddings.
            inputs_col: ``(m, input_dim)`` matrix of column-node embeddings.

        Returns:
            A list of ``num_relation_types`` tensors, each of shape
            ``(n, m)``, with ``activation`` already applied.
        """
        keep_prob = 1. - self.drop_prob
        outputs = []
        for k in range(self.num_relation_types):
            # Fresh dropout mask from the ORIGINAL inputs for each relation
            # type. Re-binding the inputs here would compound dropout.
            if self.training:
                row = dropout(inputs_row, keep_prob)
                col = dropout(inputs_col, keep_prob)
            else:
                row, col = inputs_row, inputs_col

            intermediate_product = torch.mm(row, self.relation[k])
            rec = torch.mm(intermediate_product, torch.transpose(col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
-
-
class InnerProductDecoder(torch.nn.Module):
    """Inner-product decoder layer for link prediction.

    Every (row node, column node) pair is scored as
    ``activation(row @ col.T)``. The layer has no trainable weights, so the
    score is the same for every relation type; only the dropout masks
    differ between the returned matrices.

    Args:
        input_dim: Stored on the instance; not used by ``forward``.
        num_relation_types: Number of score matrices returned by
            ``forward``.
        drop_prob: Dropout rate. The project ``dropout`` helper is called
            with ``1. - drop_prob`` (presumably its keep probability).
            Dropout is applied only while the module is in training mode.
        activation: Element-wise function applied to each score matrix.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
                 activation=torch.sigmoid, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

    def forward(self, inputs_row, inputs_col):
        """Return one score matrix per relation type.

        Args:
            inputs_row: ``(n, d)`` matrix of row-node embeddings.
            inputs_col: ``(m, d)`` matrix of column-node embeddings.

        Returns:
            A list of ``num_relation_types`` tensors, each of shape
            ``(n, m)``, with ``activation`` already applied.
        """
        keep_prob = 1. - self.drop_prob
        outputs = []
        for _ in range(self.num_relation_types):
            # Fresh dropout mask from the ORIGINAL inputs for each relation
            # type. Re-binding the inputs here would compound dropout.
            if self.training:
                row = dropout(inputs_row, keep_prob)
                col = dropout(inputs_col, keep_prob)
            else:
                row, col = inputs_row, inputs_col

            rec = torch.mm(row, torch.transpose(col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
|