IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

Work on declayer.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
1868014ca3
共有 3 個文件被更改,包括 114 次插入3 次删除
  1. +103
    -2
      src/icosagon/declayer.py
  2. +1
    -1
      tests/icosagon/test_sampling.py
  3. +10
    -0
      tests/icosagon/test_weights.py

+ 103
- 2
src/icosagon/declayer.py 查看文件

@@ -1,2 +1,103 @@
# from .layer import DecagonLayer
# from .input import OneHotInputLayer
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
import torch
from .data import Data
from .trainprep import PreparedData, \
TrainValTest
from typing import Type, \
List, \
Callable, \
Union, \
Dict, \
Tuple
from .decode import DEDICOMDecoder
class DecodeLayer(torch.nn.Module):
def __init__(self,
input_dim: List[int],
data: Union[Data, PreparedData],
keep_prob: float = 1.,
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
**kwargs) -> None:
super().__init__(**kwargs)
assert all([ a == input_dim[0] \
for a in input_dim ])
self.input_dim = input_dim
self.output_dim = 1
self.data = data
self.keep_prob = keep_prob
self.activation = activation
self.decoder_class = decoder_class
self.decoders = None
self.build()
def build(self) -> None:
self.decoders = {}
n = len(self.data.node_types)
for node_type_row in range(n):
if node_type_row not in relation_types:
continue
for node_type_column in range(n):
if node_type_column not in relation_types[node_type_row]:
continue
rels = relation_types[node_type_row][node_type_column]
if len(rels) == 0:
continue
if isinstance(self.decoder_class, dict):
if (node_type_row, node_type_column) in self.decoder_class:
decoder_class = self.decoder_class[node_type_row, node_type_column]
elif (node_type_column, node_type_row) in self.decoder_class:
decoder_class = self.decoder_class[node_type_column, node_type_row]
else:
raise KeyError('Decoder not specified for edge type: %s -- %s' % (
self.data.node_types[node_type_row].name,
self.data.node_types[node_type_column].name))
else:
decoder_class = self.decoder_class
self.decoders[node_type_row, node_type_column] = \
decoder_class(self.input_dim,
num_relation_types = len(rels),
drop_prob = 1. - self.keep_prob,
activation = self.activation)
def forward(self, last_layer_repr: List[torch.Tensor]) -> TrainValTest:
# n = len(self.data.node_types)
# relation_types = self.data.relation_types
# for node_type_row in range(n):
# if node_type_row not in relation_types:
# continue
#
# for node_type_column in range(n):
# if node_type_column not in relation_types[node_type_row]:
# continue
#
# rels = relation_types[node_type_row][node_type_column]
#
# for mode in ['train', 'val', 'test']:
# getattr(relation_types[node_type_row][node_type_column].edges_pos, mode)
# getattr(self.data.edges_neg, mode)
# last_layer[]
res = {}
for (node_type_row, node_type_column), dec in self.decoders.items():
inputs_row = last_layer_repr[node_type_row]
inputs_column = last_layer_repr[node_type_column]
pred_adj_matrices = dec(inputs_row, inputs_col)
res[node_type_row, node_type_col] = pred_adj_matrices
return res

+ 1
- 1
tests/icosagon/test_sampling.py 查看文件

@@ -132,7 +132,7 @@ def test_unigram_03():
counts_tf = defaultdict(list)
counts_torch = defaultdict(list)
for i in range(100):
for i in range(10):
neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
true_classes=true_classes_tf,
num_true=num_true,


+ 10
- 0
tests/icosagon/test_weights.py 查看文件

@@ -11,3 +11,13 @@ def test_init_glorot_01():
init_range = np.sqrt(6.0 / 30)
expected = -init_range + 2 * init_range * rnd
assert torch.all(res == expected)
def test_init_glorot_02():
torch.random.manual_seed(0)
res = init_glorot(20, 10)
torch.random.manual_seed(0)
rnd = torch.rand((20, 10))
init_range = np.sqrt(6.0 / 30)
expected = -init_range + 2 * init_range * rnd
assert torch.all(res == expected)

Loading…
取消
儲存