IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add tests for decode.

master
Stanislaw Adaszewski 3 years ago
parent
commit
dd0eb81251
2 changed files with 102 additions and 16 deletions
  1. +16
    -16
      src/icosagon/decode.py
  2. +86
    -0
      tests/icosagon/test_decode.py

+ 16
- 16
src/icosagon/decode.py View File

@@ -11,13 +11,13 @@ from .dropout import dropout
class DEDICOMDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.drop_prob = drop_prob
self.keep_prob = keep_prob
self.activation = activation
self.global_interaction = init_glorot(input_dim, input_dim)
@@ -29,8 +29,8 @@ class DEDICOMDecoder(torch.nn.Module):
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
relation = torch.diag(self.local_variation[k])
@@ -46,13 +46,13 @@ class DEDICOMDecoder(torch.nn.Module):
class DistMultDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.drop_prob = drop_prob
self.keep_prob = keep_prob
self.activation = activation
self.relation = [
@@ -63,8 +63,8 @@ class DistMultDecoder(torch.nn.Module):
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
relation = torch.diag(self.relation[k])
@@ -78,13 +78,13 @@ class DistMultDecoder(torch.nn.Module):
class BilinearDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.drop_prob = drop_prob
self.keep_prob = keep_prob
self.activation = activation
self.relation = [
@@ -95,8 +95,8 @@ class BilinearDecoder(torch.nn.Module):
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
intermediate_product = torch.mm(inputs_row, self.relation[k])
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
@@ -108,21 +108,21 @@ class BilinearDecoder(torch.nn.Module):
class InnerProductDecoder(torch.nn.Module):
"""DEDICOM Tensor Factorization Decoder model layer for link prediction."""
def __init__(self, input_dim, num_relation_types, drop_prob=0.,
def __init__(self, input_dim, num_relation_types, keep_prob=1.,
activation=torch.sigmoid, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_relation_types = num_relation_types
self.drop_prob = drop_prob
self.keep_prob = keep_prob
self.activation = activation
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))


+ 86
- 0
tests/icosagon/test_decode.py View File

@@ -0,0 +1,86 @@
from icosagon.decode import DEDICOMDecoder, \
DistMultDecoder, \
BilinearDecoder, \
InnerProductDecoder
import decagon_pytorch.decode.pairwise
import torch
def test_dedicom_decoder_01():
repr_ = torch.rand(20, 32)
dec_1 = DEDICOMDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
dec_2 = decagon_pytorch.decode.pairwise.DEDICOMDecoder(32, 7, drop_prob=0.,
activation=torch.sigmoid)
dec_2.global_interaction = dec_1.global_interaction
dec_2.local_variation = dec_1.local_variation
res_1 = dec_1(repr_, repr_)
res_2 = dec_2(repr_, repr_)
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] == res_2[i])
def test_dist_mult_decoder_01():
repr_ = torch.rand(20, 32)
dec_1 = DistMultDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
dec_2 = decagon_pytorch.decode.pairwise.DistMultDecoder(32, 7, drop_prob=0.,
activation=torch.sigmoid)
dec_2.relation = dec_1.relation
res_1 = dec_1(repr_, repr_)
res_2 = dec_2(repr_, repr_)
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] == res_2[i])
def test_bilinear_decoder_01():
repr_ = torch.rand(20, 32)
dec_1 = BilinearDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
dec_2 = decagon_pytorch.decode.pairwise.BilinearDecoder(32, 7, drop_prob=0.,
activation=torch.sigmoid)
dec_2.relation = dec_1.relation
res_1 = dec_1(repr_, repr_)
res_2 = dec_2(repr_, repr_)
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] == res_2[i])
def test_inner_product_decoder_01():
repr_ = torch.rand(20, 32)
dec_1 = InnerProductDecoder(32, 7, keep_prob=1.,
activation=torch.sigmoid)
dec_2 = decagon_pytorch.decode.pairwise.InnerProductDecoder(32, 7, drop_prob=0.,
activation=torch.sigmoid)
res_1 = dec_1(repr_, repr_)
res_2 = dec_2(repr_, repr_)
assert isinstance(res_1, list)
assert isinstance(res_2, list)
assert len(res_1) == len(res_2)
for i in range(len(res_1)):
assert torch.all(res_1[i] == res_2[i])

Loading…
Cancel
Save