|
@@ -44,7 +44,7 @@ def convert_decoder(dec): |
|
|
local_variation = torch.eye(dec.input_dim, dec.input_dim)
|
|
|
local_variation = torch.eye(dec.input_dim, dec.input_dim)
|
|
|
local_variation = [ local_variation ] * dec.num_relation_types
|
|
|
local_variation = [ local_variation ] * dec.num_relation_types
|
|
|
else:
|
|
|
else:
|
|
|
raise TypeError('Unknown decoder type in covert_decoder()')
|
|
|
|
|
|
|
|
|
raise TypeError('Unknown decoder type in convert_decoder()')
|
|
|
|
|
|
|
|
|
if not isinstance(local_variation, torch.Tensor):
|
|
|
if not isinstance(local_variation, torch.Tensor):
|
|
|
local_variation = map(lambda a: a.view(1, *a.shape), local_variation)
|
|
|
local_variation = map(lambda a: a.view(1, *a.shape), local_variation)
|
|
|