IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

59 rindas
1.9KB

  1. import decagon_pytorch.decode
  2. import decagon.deep.layers
  3. import numpy as np
  4. import tensorflow as tf
  5. import torch
  6. def _common(decoder_torch, decoder_tf):
  7. inputs = np.random.rand(20, 10).astype(np.float32)
  8. inputs_torch = torch.tensor(inputs)
  9. inputs_tf = {
  10. 0: tf.convert_to_tensor(inputs)
  11. }
  12. out_torch = decoder_torch(inputs_torch, inputs_torch)
  13. out_tf = decoder_tf(inputs_tf)
  14. assert len(out_torch) == len(out_tf)
  15. assert len(out_tf) == 7
  16. for i in range(len(out_torch)):
  17. assert out_torch[i].shape == out_tf[i].shape
  18. sess = tf.Session()
  19. for i in range(len(out_torch)):
  20. item_torch = out_torch[i].detach().numpy()
  21. item_tf = out_tf[i].eval(session=sess)
  22. print('item_torch:', item_torch)
  23. print('item_tf:', item_tf)
  24. assert np.all(item_torch - item_tf < .000001)
  25. sess.close()
  26. def test_dedicom_decoder():
  27. dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10,
  28. num_relation_types=7)
  29. dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
  30. edge_type=(0, 0))
  31. dedicom_tf.vars['global_interaction'] = \
  32. tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
  33. for i in range(dedicom_tf.num_types):
  34. dedicom_tf.vars['local_variation_%d' % i] = \
  35. tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
  36. _common(dedicom_torch, dedicom_tf)
  37. def test_dist_mult_decoder():
  38. distmult_torch = decagon_pytorch.decode.DistMultDecoder(input_dim=10,
  39. num_relation_types=7)
  40. distmult_tf = decagon.deep.layers.DistMultDecoder(input_dim=10, num_types=7,
  41. edge_type=(0, 0))
  42. for i in range(distmult_tf.num_types):
  43. distmult_tf.vars['relation_%d' % i] = \
  44. tf.convert_to_tensor(distmult_torch.relation[i].detach().numpy())
  45. _common(distmult_torch, distmult_tf)