|
12345678910111213 |
- import torch
- import numpy as np
-
-
def init_glorot(input_dim, output_dim, *, dtype=torch.float32, device=None,
                generator=None):
    """Return a Glorot/Xavier-uniform initialised weight tensor.

    Initialisation follows Glorot & Bengio (AISTATS 2010): entries are drawn
    uniformly from ``[-r, r)`` with ``r = sqrt(6 / (input_dim + output_dim))``.

    Args:
        input_dim: Number of rows (fan-in). Must be non-negative.
        output_dim: Number of columns (fan-out). Must be non-negative.
        dtype: Floating dtype of the result (default ``torch.float32``, as
            before).
        device: Device to allocate on (default: torch's current default).
        generator: Optional ``torch.Generator`` for reproducible draws;
            ``None`` uses the global RNG, as before.

    Returns:
        A leaf tensor of shape ``(input_dim, output_dim)`` with
        ``requires_grad=True``.

    Raises:
        ValueError: If either dimension is negative or both are zero (the
            Glorot bound is undefined in that case).
    """
    if input_dim < 0 or output_dim < 0:
        raise ValueError(
            f"dimensions must be non-negative, got ({input_dim}, {output_dim})"
        )
    fan_sum = input_dim + output_dim
    if fan_sum == 0:
        raise ValueError("input_dim + output_dim must be positive")

    # Plain Python float: avoids mixing numpy scalars into torch arithmetic.
    init_range = (6.0 / fan_sum) ** 0.5
    # torch.rand is kept (rather than Tensor.uniform_) so that, under the same
    # seed, the drawn values are identical to the previous implementation.
    unit = torch.rand((input_dim, output_dim), generator=generator,
                      dtype=dtype, device=device)
    initial = -init_range + 2 * init_range * unit
    # requires_grad_ is in-place; the result of the arithmetic above is a
    # fresh tensor, so this makes it a trainable leaf.
    initial.requires_grad_(True)
    return initial
|