import numpy as np
import torch

from icosagon.weights import init_glorot


def test_init_glorot_01():
    """Check ``init_glorot`` against a reference Glorot-uniform draw.

    The call under test and the reference draw are both preceded by the same
    RNG seed. The test therefore assumes ``init_glorot`` consumes the torch
    RNG exactly like the reference below: presumably a single
    ``torch.rand((in, out))`` draw. Verify against ``icosagon.weights`` if
    this test starts failing.
    """
    in_channels, out_channels = 10, 20

    torch.manual_seed(0)
    res = init_glorot(in_channels, out_channels)

    # Re-seed so the reference consumes the identical random stream.
    torch.manual_seed(0)
    rnd = torch.rand((in_channels, out_channels))
    # Glorot/Xavier uniform bound: sqrt(6 / (fan_in + fan_out)).
    init_range = np.sqrt(6.0 / (in_channels + out_channels))
    # Map U[0, 1) onto U[-init_range, init_range).
    expected = -init_range + 2 * init_range * rnd

    assert res.shape == (in_channels, out_channels)
    # Use a tolerance-based comparison. Exact float equality would break if
    # init_glorot evaluated the same affine map in a different order.
    assert torch.allclose(res, expected)