diff --git a/src/icosagon/decode.py b/src/icosagon/decode.py index 16efd83..dccf508 100644 --- a/src/icosagon/decode.py +++ b/src/icosagon/decode.py @@ -11,13 +11,13 @@ from .dropout import dropout class DEDICOMDecoder(torch.nn.Module): """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" - def __init__(self, input_dim, num_relation_types, drop_prob=0., + def __init__(self, input_dim, num_relation_types, keep_prob=1., activation=torch.sigmoid, **kwargs): super().__init__(**kwargs) self.input_dim = input_dim self.num_relation_types = num_relation_types - self.drop_prob = drop_prob + self.keep_prob = keep_prob self.activation = activation self.global_interaction = init_glorot(input_dim, input_dim) @@ -29,8 +29,8 @@ class DEDICOMDecoder(torch.nn.Module): def forward(self, inputs_row, inputs_col): outputs = [] for k in range(self.num_relation_types): - inputs_row = dropout(inputs_row, 1.-self.drop_prob) - inputs_col = dropout(inputs_col, 1.-self.drop_prob) + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) relation = torch.diag(self.local_variation[k]) @@ -46,13 +46,13 @@ class DEDICOMDecoder(torch.nn.Module): class DistMultDecoder(torch.nn.Module): """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" - def __init__(self, input_dim, num_relation_types, drop_prob=0., + def __init__(self, input_dim, num_relation_types, keep_prob=1., activation=torch.sigmoid, **kwargs): super().__init__(**kwargs) self.input_dim = input_dim self.num_relation_types = num_relation_types - self.drop_prob = drop_prob + self.keep_prob = keep_prob self.activation = activation self.relation = [ @@ -63,8 +63,8 @@ class DistMultDecoder(torch.nn.Module): def forward(self, inputs_row, inputs_col): outputs = [] for k in range(self.num_relation_types): - inputs_row = dropout(inputs_row, 1.-self.drop_prob) - inputs_col = dropout(inputs_col, 1.-self.drop_prob) + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) relation = torch.diag(self.relation[k]) @@ -78,13 +78,13 @@ class DistMultDecoder(torch.nn.Module): class BilinearDecoder(torch.nn.Module): """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" - def __init__(self, input_dim, num_relation_types, drop_prob=0., + def __init__(self, input_dim, num_relation_types, keep_prob=1., activation=torch.sigmoid, **kwargs): super().__init__(**kwargs) self.input_dim = input_dim self.num_relation_types = num_relation_types - self.drop_prob = drop_prob + self.keep_prob = keep_prob self.activation = activation self.relation = [ @@ -95,8 +95,8 @@ class BilinearDecoder(torch.nn.Module): def forward(self, inputs_row, inputs_col): outputs = [] for k in range(self.num_relation_types): - inputs_row = dropout(inputs_row, 1.-self.drop_prob) - inputs_col = dropout(inputs_col, 1.-self.drop_prob) + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) intermediate_product = torch.mm(inputs_row, self.relation[k]) rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), @@ -108,21 +108,21 @@ class BilinearDecoder(torch.nn.Module): class InnerProductDecoder(torch.nn.Module): """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" - def __init__(self, input_dim, num_relation_types, drop_prob=0., + def __init__(self, input_dim, num_relation_types, keep_prob=1., activation=torch.sigmoid, **kwargs): super().__init__(**kwargs) self.input_dim = input_dim self.num_relation_types = num_relation_types - self.drop_prob = drop_prob + self.keep_prob = keep_prob self.activation = activation def forward(self, inputs_row, inputs_col): outputs = [] for k in range(self.num_relation_types): - inputs_row = dropout(inputs_row, 1.-self.drop_prob) - inputs_col = dropout(inputs_col, 1.-self.drop_prob) + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]), inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) diff --git a/tests/icosagon/test_decode.py b/tests/icosagon/test_decode.py new file mode 100644 index 0000000..001f411 --- /dev/null +++ b/tests/icosagon/test_decode.py @@ -0,0 +1,86 @@ +from icosagon.decode import DEDICOMDecoder, \ + DistMultDecoder, \ + BilinearDecoder, \ + InnerProductDecoder +import decagon_pytorch.decode.pairwise +import torch + + +def test_dedicom_decoder_01(): + repr_ = torch.rand(20, 32) + dec_1 = DEDICOMDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + dec_2 = decagon_pytorch.decode.pairwise.DEDICOMDecoder(32, 7, drop_prob=0., + activation=torch.sigmoid) + dec_2.global_interaction = dec_1.global_interaction + dec_2.local_variation = dec_1.local_variation + + res_1 = dec_1(repr_, repr_) + res_2 = dec_2(repr_, repr_) + + assert isinstance(res_1, list) + assert isinstance(res_2, list) + + assert len(res_1) == len(res_2) + + for i in range(len(res_1)): + assert torch.all(res_1[i] == res_2[i]) + + +def test_dist_mult_decoder_01(): + repr_ = torch.rand(20, 32) + dec_1 = DistMultDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + dec_2 = decagon_pytorch.decode.pairwise.DistMultDecoder(32, 7, drop_prob=0., + activation=torch.sigmoid) + dec_2.relation = dec_1.relation + + res_1 = dec_1(repr_, repr_) + res_2 = dec_2(repr_, repr_) + + assert isinstance(res_1, list) + assert isinstance(res_2, list) + + assert len(res_1) == len(res_2) + + for i in range(len(res_1)): + assert torch.all(res_1[i] == res_2[i]) + + +def test_bilinear_decoder_01(): + repr_ = torch.rand(20, 32) + dec_1 = BilinearDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + dec_2 = decagon_pytorch.decode.pairwise.BilinearDecoder(32, 7, drop_prob=0., + activation=torch.sigmoid) + dec_2.relation = dec_1.relation + + res_1 = dec_1(repr_, repr_) + res_2 = dec_2(repr_, repr_) + + assert isinstance(res_1, list) + assert isinstance(res_2, list) + + assert len(res_1) == len(res_2) + + for i in range(len(res_1)): + assert torch.all(res_1[i] == res_2[i]) + + +def test_inner_product_decoder_01(): + repr_ = torch.rand(20, 32) + dec_1 = InnerProductDecoder(32, 7, keep_prob=1., + activation=torch.sigmoid) + dec_2 = decagon_pytorch.decode.pairwise.InnerProductDecoder(32, 7, drop_prob=0., + activation=torch.sigmoid) + + res_1 = dec_1(repr_, repr_) + res_2 = dec_2(repr_, repr_) + + assert isinstance(res_1, list) + assert isinstance(res_2, list) + + assert len(res_1) == len(res_2) + + for i in range(len(res_1)): + assert torch.all(res_1[i] == res_2[i])