|
|
@@ -21,6 +21,7 @@ |
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from .convole import SparseMultiDGCA
|
|
|
|
|
|
|
|
|
|
|
|
class InputLayer(torch.nn.Module):
|
|
|
@@ -52,9 +53,48 @@ class InputLayer(torch.nn.Module): |
|
|
|
|
|
|
|
|
|
|
|
class DecagonLayer(torch.nn.Module):
|
|
|
|
def __init__(self, data, **kwargs):
|
|
|
|
def __init__(self, data,
|
|
|
|
input_dim, output_dim,
|
|
|
|
keep_prob=1.,
|
|
|
|
rel_activation=lambda x: x,
|
|
|
|
layer_activation=torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.data = data
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.output_dim = output_dim
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.rel_activation = rel_activation
|
|
|
|
self.layer_activation = layer_activation
|
|
|
|
self.convolutions = None
|
|
|
|
self.build()
|
|
|
|
|
|
|
|
def build(self):
|
|
|
|
self.convolutions = {}
|
|
|
|
for key in self.data.relation_types.keys():
|
|
|
|
adjacency_matrices = \
|
|
|
|
self.data.get_adjacency_matrices(*key)
|
|
|
|
self.convolutions[key] = SparseMultiDGCA(self.input_dim,
|
|
|
|
self.output_dim, adjacency_matrices,
|
|
|
|
self.keep_prob, self.rel_activation)
|
|
|
|
|
|
|
|
def __call__(self, previous_layer):
|
|
|
|
pass
|
|
|
|
# for node_type_row, node_type_col in enumerate(self.data.node_
|
|
|
|
# if rt.node_type_row == i or rt.node_type_col == i:
|
|
|
|
|
|
|
|
def __call__(self, prev_layer_repr):
|
|
|
|
new_layer_repr = []
|
|
|
|
for i, nt in enumerate(self.data.node_types):
|
|
|
|
new_repr = []
|
|
|
|
for key in self.data.relation_types.keys():
|
|
|
|
nt_row, nt_col = key
|
|
|
|
if nt_row != i and nt_col != i:
|
|
|
|
continue
|
|
|
|
if nt_row == i:
|
|
|
|
x = prev_layer_repr[nt_col]
|
|
|
|
else:
|
|
|
|
x = prev_layer_repr[nt_row]
|
|
|
|
conv = self.convolutions[key]
|
|
|
|
new_repr.append(conv(x))
|
|
|
|
new_repr = sum(new_repr)
|
|
|
|
new_layer_repr.append(new_repr)
|
|
|
|
return new_layer_repr
|