IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
소스 검색

Use torch.nn.Parameter(List) in decode.

master
Stanislaw Adaszewski 4 년 전
부모
커밋
25e05cf1c2
1개의 변경된 파일10개의 추가작업 그리고 10개의 파일을 삭제
  1. +10
    -10
      src/icosagon/decode.py

+ 10
- 10
src/icosagon/decode.py 파일 보기

@@ -20,11 +20,11 @@ class DEDICOMDecoder(torch.nn.Module):
self.keep_prob = keep_prob
self.activation = activation
self.global_interaction = init_glorot(input_dim, input_dim)
self.local_variation = [
torch.flatten(init_glorot(input_dim, 1)) \
self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
self.local_variation = torch.nn.ParameterList([
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
for _ in range(num_relation_types)
]
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
@@ -53,10 +53,10 @@ class DistMultDecoder(torch.nn.Module):
self.keep_prob = keep_prob
self.activation = activation
self.relation = [
torch.flatten(init_glorot(input_dim, 1)) \
self.relation = torch.nn.ParameterList([
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
for _ in range(num_relation_types)
]
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
@@ -83,10 +83,10 @@ class BilinearDecoder(torch.nn.Module):
self.keep_prob = keep_prob
self.activation = activation
self.relation = [
init_glorot(input_dim, input_dim) \
self.relation = torch.nn.ParameterList([
torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
for _ in range(num_relation_types)
]
])
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)


불러오는 중...
취소
저장