If you would like to get an account, please email s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_decode.py 2.7KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import decagon_pytorch.decode.cartesian
  2. import decagon.deep.layers
  3. import numpy as np
  4. import tensorflow as tf
  5. import torch
  6. def _common(decoder_torch, decoder_tf):
  7. inputs = np.random.rand(20, 10).astype(np.float32)
  8. inputs_torch = torch.tensor(inputs)
  9. inputs_tf = {
  10. 0: tf.convert_to_tensor(inputs)
  11. }
  12. out_torch = decoder_torch(inputs_torch, inputs_torch)
  13. out_tf = decoder_tf(inputs_tf)
  14. assert len(out_torch) == len(out_tf)
  15. assert len(out_tf) == 7
  16. for i in range(len(out_torch)):
  17. assert out_torch[i].shape == out_tf[i].shape
  18. sess = tf.Session()
  19. for i in range(len(out_torch)):
  20. item_torch = out_torch[i].detach().numpy()
  21. item_tf = out_tf[i].eval(session=sess)
  22. print('item_torch:', item_torch)
  23. print('item_tf:', item_tf)
  24. assert np.all(item_torch - item_tf < .000001)
  25. sess.close()
  26. def test_dedicom_decoder():
  27. dedicom_torch = decagon_pytorch.decode.cartesian.DEDICOMDecoder(input_dim=10,
  28. num_relation_types=7)
  29. dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
  30. edge_type=(0, 0))
  31. dedicom_tf.vars['global_interaction'] = \
  32. tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
  33. for i in range(dedicom_tf.num_types):
  34. dedicom_tf.vars['local_variation_%d' % i] = \
  35. tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
  36. _common(dedicom_torch, dedicom_tf)
  37. def test_dist_mult_decoder():
  38. distmult_torch = decagon_pytorch.decode.cartesian.DistMultDecoder(input_dim=10,
  39. num_relation_types=7)
  40. distmult_tf = decagon.deep.layers.DistMultDecoder(input_dim=10, num_types=7,
  41. edge_type=(0, 0))
  42. for i in range(distmult_tf.num_types):
  43. distmult_tf.vars['relation_%d' % i] = \
  44. tf.convert_to_tensor(distmult_torch.relation[i].detach().numpy())
  45. _common(distmult_torch, distmult_tf)
  46. def test_bilinear_decoder():
  47. bilinear_torch = decagon_pytorch.decode.cartesian.BilinearDecoder(input_dim=10,
  48. num_relation_types=7)
  49. bilinear_tf = decagon.deep.layers.BilinearDecoder(input_dim=10, num_types=7,
  50. edge_type=(0, 0))
  51. for i in range(bilinear_tf.num_types):
  52. bilinear_tf.vars['relation_%d' % i] = \
  53. tf.convert_to_tensor(bilinear_torch.relation[i].detach().numpy())
  54. _common(bilinear_torch, bilinear_tf)
  55. def test_inner_product_decoder():
  56. inner_torch = decagon_pytorch.decode.cartesian.InnerProductDecoder(input_dim=10,
  57. num_relation_types=7)
  58. inner_tf = decagon.deep.layers.InnerProductDecoder(input_dim=10, num_types=7,
  59. edge_type=(0, 0))
  60. _common(inner_torch, inner_tf)