IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

Use torch.nn.Parameter(List) in decode.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
25e05cf1c2
共有 1 個文件被更改,包括 10 次插入10 次删除
  1. +10
    -10
      src/icosagon/decode.py

+ 10
- 10
src/icosagon/decode.py 查看文件

@@ -20,11 +20,11 @@ class DEDICOMDecoder(torch.nn.Module):
self.keep_prob = keep_prob
self.activation = activation
self.global_interaction = init_glorot(input_dim, input_dim)
self.local_variation = [
torch.flatten(init_glorot(input_dim, 1)) \
self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
self.local_variation = torch.nn.ParameterList([
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
for _ in range(num_relation_types)
]
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
@@ -53,10 +53,10 @@ class DistMultDecoder(torch.nn.Module):
self.keep_prob = keep_prob
self.activation = activation
self.relation = [
torch.flatten(init_glorot(input_dim, 1)) \
self.relation = torch.nn.ParameterList([
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
for _ in range(num_relation_types)
]
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
@@ -83,10 +83,10 @@ class BilinearDecoder(torch.nn.Module):
self.keep_prob = keep_prob
self.activation = activation
self.relation = [
init_glorot(input_dim, input_dim) \
self.relation = torch.nn.ParameterList([
torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
for _ in range(num_relation_types)
]
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)


Loading…
取消
儲存