import torch
import numpy as np


def init_glorot(input_dim, output_dim):
    """Return a trainable weight tensor drawn with Glorot (Xavier) uniform init.

    Follows Glorot & Bengio (AISTATS 2010): entries are sampled uniformly
    from ``[-r, r)`` with ``r = sqrt(6 / (input_dim + output_dim))``.

    Args:
        input_dim: Number of rows of the weight tensor (fan-in).
        output_dim: Number of columns of the weight tensor (fan-out).

    Returns:
        A tensor of shape ``(input_dim, output_dim)``, sampled as
        ``torch.float32``, with ``requires_grad`` switched on.
    """
    half_width = np.sqrt(6.0 / (input_dim + output_dim))

    # torch.rand samples from [0, 1); stretch that interval to a width of
    # 2 * half_width and then shift it so it is centred on zero.
    unit_samples = torch.rand((input_dim, output_dim), dtype=torch.float32)
    stretched = 2 * half_width * unit_samples
    weights = -half_width + stretched

    # requires_grad_ flips the flag in place and returns the same tensor.
    return weights.requires_grad_(True)