|
|
@@ -22,51 +22,58 @@ |
|
|
|
|
|
|
|
import torch
|
|
|
|
from .convolve import SparseMultiDGCA
|
|
|
|
from .data import Data
|
|
|
|
from typing import List, Union
|
|
|
|
|
|
|
|
|
|
|
|
class InputLayer(torch.nn.Module):
|
|
|
|
def __init__(self, data, dimensionality=None, **kwargs):
|
|
|
|
class Layer(torch.nn.Module):
|
|
|
|
def __init__(self, output_dim: Union[int, List[int]], **kwargs) -> None:
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.data = data
|
|
|
|
dimensionality = dimensionality or \
|
|
|
|
self.output_dim = output_dim
|
|
|
|
|
|
|
|
|
|
|
|
class InputLayer(Layer):
|
|
|
|
def __init__(self, data: Data, output_dim: Union[int, List[int]]= None, **kwargs) -> None:
|
|
|
|
output_dim = output_dim or \
|
|
|
|
list(map(lambda a: a.count, data.node_types))
|
|
|
|
if not isinstance(dimensionality, list):
|
|
|
|
dimensionality = [dimensionality,] * len(self.data.node_types)
|
|
|
|
self.dimensionality = dimensionality
|
|
|
|
if not isinstance(output_dim, list):
|
|
|
|
output_dim = [output_dim,] * len(data.node_types)
|
|
|
|
|
|
|
|
super().__init__(output_dim, **kwargs)
|
|
|
|
self.data = data
|
|
|
|
self.node_reps = None
|
|
|
|
self.build()
|
|
|
|
|
|
|
|
def build(self):
|
|
|
|
def build(self) -> None:
|
|
|
|
self.node_reps = []
|
|
|
|
for i, nt in enumerate(self.data.node_types):
|
|
|
|
reps = torch.rand(nt.count, self.dimensionality[i])
|
|
|
|
reps = torch.rand(nt.count, self.output_dim[i])
|
|
|
|
reps = torch.nn.Parameter(reps)
|
|
|
|
self.register_parameter('node_reps[%d]' % i, reps)
|
|
|
|
self.node_reps.append(reps)
|
|
|
|
|
|
|
|
def forward(self):
|
|
|
|
def forward(self) -> List[torch.nn.Parameter]:
|
|
|
|
return self.node_reps
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
def __repr__(self) -> str:
|
|
|
|
s = ''
|
|
|
|
s += 'GNN input layer with dimensionality: %d\n' % self.dimensionality
|
|
|
|
s += 'GNN input layer with output_dim: %s\n' % self.output_dim
|
|
|
|
s += ' # of node types: %d\n' % len(self.data.node_types)
|
|
|
|
for nt in self.data.node_types:
|
|
|
|
s += ' - %s (%d)\n' % (nt.name, nt.count)
|
|
|
|
return s.strip()
|
|
|
|
|
|
|
|
|
|
|
|
class DecagonLayer(torch.nn.Module):
|
|
|
|
def __init__(self, data,
|
|
|
|
class DecagonLayer(Layer):
|
|
|
|
def __init__(self, data: Data,
|
|
|
|
input_dim, output_dim,
|
|
|
|
keep_prob=1.,
|
|
|
|
rel_activation=lambda x: x,
|
|
|
|
layer_activation=torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
super().__init__(output_dim, **kwargs)
|
|
|
|
self.data = data
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.output_dim = output_dim
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.rel_activation = rel_activation
|
|
|
|
self.layer_activation = layer_activation
|
|
|
|