IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Ver código fonte

Use torch.nn.Parameter(List) in decode.

master
Stanislaw Adaszewski 4 anos atrás
pai
commit
25e05cf1c2
1 arquivos alterados com 10 adições e 10 exclusões
  1. +10
    -10
      src/icosagon/decode.py

+ 10
- 10
src/icosagon/decode.py Ver arquivo

@@ -20,11 +20,11 @@ class DEDICOMDecoder(torch.nn.Module):
self.keep_prob = keep_prob
self.activation = activation
self.global_interaction = init_glorot(input_dim, input_dim)
self.local_variation = [
torch.flatten(init_glorot(input_dim, 1)) \
self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
self.local_variation = torch.nn.ParameterList([
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
for _ in range(num_relation_types)
]
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
@@ -53,10 +53,10 @@ class DistMultDecoder(torch.nn.Module):
self.keep_prob = keep_prob
self.activation = activation
self.relation = [
torch.flatten(init_glorot(input_dim, 1)) \
self.relation = torch.nn.ParameterList([
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
for _ in range(num_relation_types)
]
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
@@ -83,10 +83,10 @@ class BilinearDecoder(torch.nn.Module):
self.keep_prob = keep_prob
self.activation = activation
self.relation = [
init_glorot(input_dim, input_dim) \
self.relation = torch.nn.ParameterList([
torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
for _ in range(num_relation_types)
]
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)


Carregando…
Cancelar
Salvar