If you would like an account, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_normalize.py 919B

4 年之前
12345678910111213141516171819202122232425
  1. import decagon_pytorch.normalize
  2. import decagon.deep.minibatch
  3. import numpy as np
  4. def test_normalize_adjacency_matrix_square():
  5. mx = np.random.rand(10, 10)
  6. mx[mx < .5] = 0
  7. mx = np.ceil(mx)
  8. res_torch = decagon_pytorch.normalize.normalize_adjacency_matrix(mx)
  9. res_tf = decagon.deep.minibatch.EdgeMinibatchIterator.preprocess_graph(None, mx)
  10. assert len(res_torch) == len(res_tf)
  11. for i in range(len(res_torch)):
  12. assert np.all(res_torch[i] == res_tf[i])
  13. def test_normalize_adjacency_matrix_nonsquare():
  14. mx = np.random.rand(5, 10)
  15. mx[mx < .5] = 0
  16. mx = np.ceil(mx)
  17. res_torch = decagon_pytorch.normalize.normalize_adjacency_matrix(mx)
  18. res_tf = decagon.deep.minibatch.EdgeMinibatchIterator.preprocess_graph(None, mx)
  19. assert len(res_torch) == len(res_tf)
  20. for i in range(len(res_torch)):
  21. assert np.all(res_torch[i] == res_tf[i])