import decagon_pytorch.decode
import decagon.deep.layers
import numpy as np
import tensorflow as tf
import torch


def test_dedicom():
    """Compare the PyTorch DEDICOM decoder against the TensorFlow one.

    Both decoders are built with the same input dimension and number of
    relation types. The PyTorch decoder's parameters are copied into the
    TensorFlow decoder's variables, the same random input is fed to both,
    and the per-relation outputs are required to agree in count, shape and
    value (element-wise absolute difference below 1e-6).
    """
    dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(
        input_dim=10, num_relation_types=7)
    dedicom_tf = decagon.deep.layers.DEDICOMDecoder(
        input_dim=10, num_types=7, edge_type=(0, 0))

    # Overwrite the TensorFlow decoder's variables with the PyTorch
    # parameters so that both decoders compute with identical weights.
    dedicom_tf.vars['global_interaction'] = tf.convert_to_tensor(
        dedicom_torch.global_interaction.detach().numpy())
    for i in range(dedicom_tf.num_types):
        dedicom_tf.vars['local_variation_%d' % i] = tf.convert_to_tensor(
            dedicom_torch.local_variation[i].detach().numpy())

    # One shared random float32 input, wrapped for each framework.
    inputs = np.random.rand(20, 10).astype(np.float32)
    inputs_torch = torch.tensor(inputs)
    inputs_tf = {0: tf.convert_to_tensor(inputs)}

    out_torch = dedicom_torch(inputs_torch, inputs_torch)
    out_tf = dedicom_tf(inputs_tf)

    assert len(out_torch) == len(out_tf)
    assert len(out_tf) == 7

    for item_torch, item_tf in zip(out_torch, out_tf):
        assert item_torch.shape == item_tf.shape

    # The context manager closes the session even when an assertion below
    # fails, so a failing test does not leak the session.
    with tf.Session() as sess:
        for tensor_torch, tensor_tf in zip(out_torch, out_tf):
            item_torch = tensor_torch.detach().numpy()
            item_tf = tensor_tf.eval(session=sess)
            print('item_torch:', item_torch)
            print('item_tf:', item_tf)
            # Compare the absolute difference: a signed difference would
            # let arbitrarily large negative deviations pass unnoticed.
            assert np.all(np.abs(item_torch - item_tf) < .000001)