IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
No puede seleccionar más de 25 temas Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

42 líneas
1.4KB

  1. import decagon_pytorch.decode
  2. import decagon.deep.layers
  3. import numpy as np
  4. import tensorflow as tf
  5. import torch
  6. def test_dedicom():
  7. dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10,
  8. num_relation_types=7)
  9. dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
  10. edge_type=(0, 0))
  11. dedicom_tf.vars['global_interaction'] = \
  12. tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
  13. for i in range(dedicom_tf.num_types):
  14. dedicom_tf.vars['local_variation_%d' % i] = \
  15. tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
  16. inputs = np.random.rand(20, 10).astype(np.float32)
  17. inputs_torch = torch.tensor(inputs)
  18. inputs_tf = {
  19. 0: tf.convert_to_tensor(inputs)
  20. }
  21. out_torch = dedicom_torch(inputs_torch, inputs_torch)
  22. out_tf = dedicom_tf(inputs_tf)
  23. assert len(out_torch) == len(out_tf)
  24. assert len(out_tf) == 7
  25. for i in range(len(out_torch)):
  26. assert out_torch[i].shape == out_tf[i].shape
  27. sess = tf.Session()
  28. for i in range(len(out_torch)):
  29. item_torch = out_torch[i].detach().numpy()
  30. item_tf = out_tf[i].eval(session=sess)
  31. print('item_torch:', item_torch)
  32. print('item_tf:', item_tf)
  33. assert np.all(item_torch - item_tf < .000001)
  34. sess.close()