|
- import decagon_pytorch.decode
- import decagon.deep.layers
- import numpy as np
- import tensorflow as tf
- import torch
-
-
def _common(decoder_torch, decoder_tf, num_relation_types=7, atol=1e-6):
    """Check that a PyTorch decoder and a TF decoder produce matching outputs.

    Both decoders are fed the same random ``(20, 10)`` float32 matrix (used
    as both the row and column inputs). Callers are expected to have copied
    the PyTorch decoder's weights into the TF decoder beforehand.

    Args:
        decoder_torch: PyTorch decoder, called as ``decoder_torch(x, x)``.
            It must return a sequence of tensors.
        decoder_tf: TF decoder, called with the dict ``{0: x}``. It must
            return a sequence of TF tensors.
        num_relation_types: expected number of outputs from each decoder.
        atol: exclusive upper bound on the absolute element-wise difference
            between corresponding outputs.

    Raises:
        AssertionError: if the output counts, shapes or values disagree.
    """
    inputs = np.random.rand(20, 10).astype(np.float32)
    inputs_torch = torch.tensor(inputs)
    # The TF layer takes a dict of inputs. Presumably it is keyed by node
    # type, matching the edge_type=(0, 0) the callers construct with.
    # TODO confirm against decagon.deep.layers.
    inputs_tf = {
        0: tf.convert_to_tensor(inputs)
    }
    out_torch = decoder_torch(inputs_torch, inputs_torch)
    out_tf = decoder_tf(inputs_tf)

    assert len(out_torch) == len(out_tf)
    assert len(out_tf) == num_relation_types

    for item_torch, item_tf in zip(out_torch, out_tf):
        assert item_torch.shape == item_tf.shape

    # The context manager closes the session even if an assertion fails.
    with tf.Session() as sess:
        for i, (item_torch, item_tf) in enumerate(zip(out_torch, out_tf)):
            value_torch = item_torch.detach().numpy()
            value_tf = item_tf.eval(session=sess)
            # Compare the magnitude of the difference. A one-sided
            # `a - b < tol` would accept any output where torch < tf.
            diff = np.abs(value_torch - value_tf)
            assert np.all(diff < atol), \
                'relation type %d: max abs difference %g exceeds %g' % (
                    i, diff.max(), atol)
-
-
def test_dedicom_decoder():
    """PyTorch DEDICOM decoder must match the TF reference implementation.

    The TF layer's ``global_interaction`` and per-relation
    ``local_variation_<k>`` entries are overwritten with the PyTorch
    decoder's parameters. The shared ``_common`` check then compares outputs.
    """
    torch_decoder = decagon_pytorch.decode.DEDICOMDecoder(
        input_dim=10, num_relation_types=7)
    tf_decoder = decagon.deep.layers.DEDICOMDecoder(
        input_dim=10, num_types=7, edge_type=(0, 0))

    def to_tf(param):
        # Detach from autograd before converting the values to a TF tensor.
        return tf.convert_to_tensor(param.detach().numpy())

    tf_decoder.vars['global_interaction'] = \
        to_tf(torch_decoder.global_interaction)
    for rel in range(tf_decoder.num_types):
        tf_decoder.vars['local_variation_%d' % rel] = \
            to_tf(torch_decoder.local_variation[rel])

    _common(torch_decoder, tf_decoder)
-
-
def test_dist_mult_decoder():
    """PyTorch DistMult decoder must match the TF reference implementation.

    Each PyTorch ``relation[k]`` parameter is copied into the TF layer's
    ``relation_<k>`` entry. The outputs are then compared via ``_common``.
    """
    torch_decoder = decagon_pytorch.decode.DistMultDecoder(
        input_dim=10, num_relation_types=7)
    tf_decoder = decagon.deep.layers.DistMultDecoder(
        input_dim=10, num_types=7, edge_type=(0, 0))

    # Mirror every per-relation weight so both decoders use identical values.
    for rel in range(tf_decoder.num_types):
        weight = torch_decoder.relation[rel].detach().numpy()
        tf_decoder.vars['relation_%d' % rel] = tf.convert_to_tensor(weight)

    _common(torch_decoder, tf_decoder)
-
-
def test_bilinear_decoder():
    """PyTorch bilinear decoder must match the TF reference implementation.

    The PyTorch ``relation[k]`` parameters replace the TF layer's
    ``relation_<k>`` entries before the shared ``_common`` comparison runs.
    """
    torch_decoder = decagon_pytorch.decode.BilinearDecoder(
        input_dim=10, num_relation_types=7)
    tf_decoder = decagon.deep.layers.BilinearDecoder(
        input_dim=10, num_types=7, edge_type=(0, 0))

    def as_tf_tensor(param):
        # Strip the autograd graph, then hand the raw array to TF.
        return tf.convert_to_tensor(param.detach().numpy())

    for rel in range(tf_decoder.num_types):
        key = 'relation_%d' % rel
        tf_decoder.vars[key] = as_tf_tensor(torch_decoder.relation[rel])

    _common(torch_decoder, tf_decoder)
-
-
def test_inner_product_decoder():
    """PyTorch inner-product decoder must match the TF reference.

    No parameters are copied between the two decoders here (presumably the
    inner-product decoder has no trainable weights; confirm if that changes).
    The outputs are compared directly via ``_common``.
    """
    shared_kwargs = dict(input_dim=10)
    torch_decoder = decagon_pytorch.decode.InnerProductDecoder(
        num_relation_types=7, **shared_kwargs)
    tf_decoder = decagon.deep.layers.InnerProductDecoder(
        num_types=7, edge_type=(0, 0), **shared_kwargs)

    _common(torch_decoder, tf_decoder)
|