diff --git a/src/decagon_pytorch/decode.py b/src/decagon_pytorch/decode.py
index 8ff6096..3643652 100644
--- a/src/decagon_pytorch/decode.py
+++ b/src/decagon_pytorch/decode.py
@@ -14,7 +14,6 @@ class DEDICOMDecoder(torch.nn.Module):
         self.drop_prob = drop_prob
         self.activation = activation
 
-
         self.global_interaction = init_glorot(input_dim, input_dim)
         self.local_variation = [
             torch.flatten(init_glorot(input_dim, 1)) \
@@ -35,3 +34,33 @@ class DEDICOMDecoder(torch.nn.Module):
             rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1))
             outputs.append(self.activation(rec))
         return outputs
+
+
+class DistMultDecoder(torch.nn.Module):
+    """DistMult Decoder model layer for link prediction."""
+    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
+        activation=torch.sigmoid, **kwargs):
+
+        super().__init__(**kwargs)
+        self.input_dim = input_dim
+        self.num_relation_types = num_relation_types
+        self.drop_prob = drop_prob
+        self.activation = activation
+
+        self.relation = [
+            torch.flatten(init_glorot(input_dim, 1)) \
+                for _ in range(num_relation_types)
+        ]
+
+    def forward(self, inputs_row, inputs_col):
+        outputs = []
+        for k in range(self.num_relation_types):
+            inputs_row = dropout(inputs_row, 1.-self.drop_prob)
+            inputs_col = dropout(inputs_col, 1.-self.drop_prob)
+
+            relation = torch.diag(self.relation[k])
+
+            intermediate_product = torch.mm(inputs_row, relation)
+            rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
+            outputs.append(self.activation(rec))
+        return outputs
diff --git a/tests/decagon_pytorch/test_decode.py b/tests/decagon_pytorch/test_decode.py
index 7e8c4ed..a374480 100644
--- a/tests/decagon_pytorch/test_decode.py
+++ b/tests/decagon_pytorch/test_decode.py
@@ -5,25 +5,14 @@ import tensorflow as tf
 import torch
 
 
-def test_dedicom():
-    dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10,
-        num_relation_types=7)
-    dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
-        edge_type=(0, 0))
-
-    dedicom_tf.vars['global_interaction'] = \
-        tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
-    for i in range(dedicom_tf.num_types):
-        dedicom_tf.vars['local_variation_%d' % i] = \
-            tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
-
+def _common(decoder_torch, decoder_tf):
     inputs = np.random.rand(20, 10).astype(np.float32)
     inputs_torch = torch.tensor(inputs)
     inputs_tf = {
         0: tf.convert_to_tensor(inputs)
     }
-    out_torch = dedicom_torch(inputs_torch, inputs_torch)
-    out_tf = dedicom_tf(inputs_tf)
+    out_torch = decoder_torch(inputs_torch, inputs_torch)
+    out_tf = decoder_tf(inputs_tf)
 
     assert len(out_torch) == len(out_tf)
     assert len(out_tf) == 7
@@ -39,3 +28,31 @@
             print('item_tf:', item_tf)
             assert np.all(item_torch - item_tf < .000001)
         sess.close()
+
+
+def test_dedicom_decoder():
+    dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10,
+        num_relation_types=7)
+    dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
+        edge_type=(0, 0))
+
+    dedicom_tf.vars['global_interaction'] = \
+        tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
+    for i in range(dedicom_tf.num_types):
+        dedicom_tf.vars['local_variation_%d' % i] = \
+            tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
+
+    _common(dedicom_torch, dedicom_tf)
+
+
+def test_dist_mult_decoder():
+    distmult_torch = decagon_pytorch.decode.DistMultDecoder(input_dim=10,
+        num_relation_types=7)
+    distmult_tf = decagon.deep.layers.DistMultDecoder(input_dim=10, num_types=7,
+        edge_type=(0, 0))
+
+    for i in range(distmult_tf.num_types):
+        distmult_tf.vars['relation_%d' % i] = \
+            tf.convert_to_tensor(distmult_torch.relation[i].detach().numpy())
+
+    _common(distmult_torch, distmult_tf)
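
For reference, a minimal usage sketch of the new DistMultDecoder outside the TensorFlow comparison test. It assumes the module layout introduced above (decagon_pytorch.decode) and the default drop_prob=0.; the (20, 10) embedding shape mirrors the test fixtures and is otherwise arbitrary.

import torch
import decagon_pytorch.decode

# Two batches of 20 node embeddings, each of dimension 10.
emb_row = torch.rand(20, 10)
emb_col = torch.rand(20, 10)

decoder = decagon_pytorch.decode.DistMultDecoder(input_dim=10,
    num_relation_types=7)

# One (20, 20) matrix of reconstructed edge scores per relation type,
# i.e. sigmoid(emb_row @ diag(r_k) @ emb_col.T) for k = 0..6.
scores = decoder(emb_row, emb_col)
assert len(scores) == 7
assert scores[0].shape == (20, 20)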