|
|
@@ -60,11 +60,11 @@ class OneHotInputLayer(torch.nn.Module): |
|
|
|
self.build()
|
|
|
|
|
|
|
|
def build(self) -> None:
|
|
|
|
self.node_reps = []
|
|
|
|
self.node_reps = torch.nn.ParameterList()
|
|
|
|
for i, nt in enumerate(self.data.node_types):
|
|
|
|
reps = torch.eye(nt.count).to_sparse()
|
|
|
|
reps = torch.nn.Parameter(reps)
|
|
|
|
self.register_parameter('node_reps[%d]' % i, reps)
|
|
|
|
reps = torch.nn.Parameter(reps, requires_grad=False)
|
|
|
|
# self.register_parameter('node_reps[%d]' % i, reps)
|
|
|
|
self.node_reps.append(reps)
|
|
|
|
|
|
|
|
def forward(self, x) -> List[torch.nn.Parameter]:
|
|
|
|