|
|
@@ -23,7 +23,9 @@ |
|
|
|
import torch
|
|
|
|
from .convolve import SparseMultiDGCA
|
|
|
|
from .data import Data
|
|
|
|
from typing import List, Union
|
|
|
|
from typing import List, \
|
|
|
|
Union, \
|
|
|
|
Callable
|
|
|
|
|
|
|
|
|
|
|
|
class Layer(torch.nn.Module):
|
|
|
@@ -65,15 +67,20 @@ class InputLayer(Layer): |
|
|
|
|
|
|
|
|
|
|
|
class DecagonLayer(Layer):
|
|
|
|
def __init__(self, data: Data,
|
|
|
|
input_dim, output_dim,
|
|
|
|
keep_prob=1.,
|
|
|
|
rel_activation=lambda x: x,
|
|
|
|
layer_activation=torch.nn.functional.relu,
|
|
|
|
def __init__(self,
|
|
|
|
data: Data,
|
|
|
|
previous_layer: Layer,
|
|
|
|
output_dim: Union[int, List[int]],
|
|
|
|
keep_prob: float = 1.,
|
|
|
|
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
|
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
if not isinstance(output_dim, list):
|
|
|
|
output_dim = [ output_dim ] * len(data.node_types)
|
|
|
|
super().__init__(output_dim, **kwargs)
|
|
|
|
self.data = data
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.previous_layer = previous_layer
|
|
|
|
self.input_dim = previous_layer.output_dim
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.rel_activation = rel_activation
|
|
|
|
self.layer_activation = layer_activation
|
|
|
|