diff --git a/tests/icosagon/test_weights.py b/tests/icosagon/test_weights.py new file mode 100644 index 0000000..7456076 --- /dev/null +++ b/tests/icosagon/test_weights.py @@ -0,0 +1,13 @@ +from icosagon.weights import init_glorot +import torch +import numpy as np + + +def test_init_glorot_01(): + torch.random.manual_seed(0) + res = init_glorot(10, 20) + torch.random.manual_seed(0) + rnd = torch.rand((10, 20)) + init_range = np.sqrt(6.0 / 30) + expected = -init_range + 2 * init_range * rnd + assert torch.all(res == expected)