If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

24 lines
650B

  1. from icosagon.weights import init_glorot
  2. import torch
  3. import numpy as np
  4. def test_init_glorot_01():
  5. torch.random.manual_seed(0)
  6. res = init_glorot(10, 20)
  7. torch.random.manual_seed(0)
  8. rnd = torch.rand((10, 20))
  9. init_range = np.sqrt(6.0 / 30)
  10. expected = -init_range + 2 * init_range * rnd
  11. assert torch.all(res == expected)
  12. def test_init_glorot_02():
  13. torch.random.manual_seed(0)
  14. res = init_glorot(20, 10)
  15. torch.random.manual_seed(0)
  16. rnd = torch.rand((20, 10))
  17. init_range = np.sqrt(6.0 / 30)
  18. expected = -init_range + 2 * init_range * rnd
  19. assert torch.all(res == expected)