IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Start implementing DecagonLayer.

master
Stanislaw Adaszewski 4 years ago
parent
commit
f398b72267
1 changed files with 43 additions and 3 deletions
  1. +43
    -3
      src/decagon_pytorch/layer.py

+ 43
- 3
src/decagon_pytorch/layer.py View File

@@ -21,6 +21,7 @@
import torch
from .convole import SparseMultiDGCA
class InputLayer(torch.nn.Module):
@@ -52,9 +53,48 @@ class InputLayer(torch.nn.Module):
class DecagonLayer(torch.nn.Module):
def __init__(self, data, **kwargs):
def __init__(self, data,
input_dim, output_dim,
keep_prob=1.,
rel_activation=lambda x: x,
layer_activation=torch.nn.functional.relu,
**kwargs):
super().__init__(**kwargs)
self.data = data
self.input_dim = input_dim
self.output_dim = output_dim
self.keep_prob = keep_prob
self.rel_activation = rel_activation
self.layer_activation = layer_activation
self.convolutions = None
self.build()
def build(self):
self.convolutions = {}
for key in self.data.relation_types.keys():
adjacency_matrices = \
self.data.get_adjacency_matrices(*key)
self.convolutions[key] = SparseMultiDGCA(self.input_dim,
self.output_dim, adjacency_matrices,
self.keep_prob, self.rel_activation)
def __call__(self, previous_layer):
pass
# for node_type_row, node_type_col in enumerate(self.data.node_
# if rt.node_type_row == i or rt.node_type_col == i:
def __call__(self, prev_layer_repr):
new_layer_repr = []
for i, nt in enumerate(self.data.node_types):
new_repr = []
for key in self.data.relation_types.keys():
nt_row, nt_col = key
if nt_row != i and nt_col != i:
continue
if nt_row == i:
x = prev_layer_repr[nt_col]
else:
x = prev_layer_repr[nt_row]
conv = self.convolutions[key]
new_repr.append(conv(x))
new_repr = sum(new_repr)
new_layer_repr.append(new_repr)
return new_layer_repr

Loading…
Cancel
Save