#
# This module implements a single layer of the Decagon
# model. Even a single layer is already fairly complex,
# because it combines all of the graph convolutional
# building blocks.
#
# h_{i}^(k+1) = ϕ(∑_r ∑_{j∈N_{r}^{i}} c_{r}^{ij} W_{r}^(k) h_{j}^(k)
#   + c_{r}^{i} h_{i}^(k))
#
# N_{r}^{i} - set of neighbors of node i under relation r
# W_{r}^(k) - relation-type specific weight matrix
# h_{i}^(k) - hidden state of node i in layer k
#   h_{i}^(k)∈R^{d(k)} where d(k) is the dimensionality
#   of the representation in k-th layer
# ϕ - activation function
# c_{r}^{ij} - normalization constants
#   c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
# c_{r}^{i} - normalization constants
#   c_{r}^{i} = 1/|N_{r}^{i}|
#
import torch
from .convolve import SparseMultiDGCA
from .data import Data
from typing import List, Union
class Layer(torch.nn.Module):
    """Common base for the GNN layers of this module.

    It only records the layer's output dimensionality; concrete
    subclasses build and run the actual computation.
    """

    def __init__(self, output_dim: Union[int, List[int]], **kwargs) -> None:
        """Initialise the underlying ``torch.nn.Module``.

        Args:
            output_dim: a single output width, or one width per node type.
            **kwargs: forwarded unchanged to ``torch.nn.Module``.
        """
        super().__init__(**kwargs)
        # Stored verbatim (int or list); subclasses interpret it.
        self.output_dim = output_dim
class InputLayer(Layer):
    """Trainable input layer: one free embedding matrix per node type.

    For the i-th entry ``nt`` of ``data.node_types`` a ``torch.nn.Parameter``
    of shape ``(nt.count, output_dim[i])`` is created with ``torch.rand``
    (uniform on [0, 1)). Calling the layer returns the list of these
    parameters, in node-type order.
    """

    def __init__(self, data: Data, output_dim: Union[int, List[int]] = None, **kwargs) -> None:
        """Build the per-node-type embeddings.

        Args:
            data: graph description; only ``data.node_types`` (items with
                ``name`` and ``count`` attributes) is read here.
            output_dim: embedding width(s). ``None`` (default) uses each node
                type's ``count`` as its width; a single int is applied to
                every node type; a list or tuple gives one width per node type.
            **kwargs: forwarded to ``torch.nn.Module``.

        Raises:
            ValueError: if a per-node-type list/tuple does not contain exactly
                one width per node type.
        """
        num_types = len(data.node_types)
        if output_dim is None:
            output_dim = [nt.count for nt in data.node_types]
        elif isinstance(output_dim, (list, tuple)):
            output_dim = list(output_dim)
        else:
            # A single width shared by all node types.
            output_dim = [output_dim] * num_types
        if len(output_dim) != num_types:
            raise ValueError(
                'output_dim has %d entries but data has %d node types'
                % (len(output_dim), num_types))
        super().__init__(output_dim, **kwargs)
        self.data = data
        # Filled by build(): one Parameter per node type, in node-type order.
        self.node_reps = None
        self.build()

    def build(self) -> None:
        """(Re)create the embedding parameters and register them on the module."""
        self.node_reps = []
        for i, (nt, dim) in enumerate(zip(self.data.node_types, self.output_dim)):
            reps = torch.nn.Parameter(torch.rand(nt.count, dim))
            # Module parameter names may not contain '.', hence the indexed
            # name; the plain list keeps positional access for forward().
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self) -> List[torch.nn.Parameter]:
        """Return the embedding parameters, one per node type."""
        return self.node_reps

    def __repr__(self) -> str:
        lines = ['GNN input layer with output_dim: %s' % self.output_dim,
                 '  # of node types: %d' % len(self.data.node_types)]
        lines.extend('    - %s (%d)' % (nt.name, nt.count)
                     for nt in self.data.node_types)
        return '\n'.join(lines).strip()
class DecagonLayer(Layer):
    """One Decagon graph-convolution layer.

    For every relation-type key ``(nt_row, nt_col)`` in ``data.relation_types``
    a ``SparseMultiDGCA`` convolution is built over that relation's adjacency
    matrices. The new representation of node type ``i`` is the sum of the
    outputs of every convolution whose relation touches ``i``, passed through
    ``layer_activation`` (ϕ in the layer formula).
    """

    def __init__(self, data: Data,
                 input_dim, output_dim,
                 keep_prob=1.,
                 rel_activation=lambda x: x,
                 layer_activation=torch.nn.functional.relu,
                 **kwargs):
        """
        Args:
            data: graph description; ``relation_types`` and
                ``get_adjacency_matrices(nt_row, nt_col)`` are used in build(),
                ``node_types`` in forward().
            input_dim: passed unchanged to each ``SparseMultiDGCA``.
            output_dim: passed unchanged to each ``SparseMultiDGCA`` and stored
                on the layer.
            keep_prob: passed to each ``SparseMultiDGCA`` (presumably a dropout
                keep probability -- TODO confirm in convolve.py).
            rel_activation: per-relation activation, given to each convolution;
                identity by default.
            layer_activation: applied to the summed per-node-type
                representation (ϕ); relu by default.
            **kwargs: forwarded to ``torch.nn.Module``.
        """
        super().__init__(output_dim, **kwargs)
        self.data = data
        self.input_dim = input_dim
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        # Filled by build(): relation-type key -> convolution.
        self.convolutions = None
        self.build()

    def build(self):
        """Create one convolution per relation type and register it."""
        self.convolutions = {}
        for key in self.data.relation_types.keys():
            adjacency_matrices = self.data.get_adjacency_matrices(*key)
            conv = SparseMultiDGCA(self.input_dim, self.output_dim,
                                   adjacency_matrices, self.keep_prob,
                                   self.rel_activation)
            # Tuple-keyed dict kept for direct lookup by relation key.
            self.convolutions[key] = conv
            # A plain dict is invisible to torch.nn.Module, so register each
            # convolution explicitly; otherwise its parameters would be missing
            # from .parameters(), state_dict() and .to(device).
            if isinstance(conv, torch.nn.Module):
                self.add_module('convolutions[%s,%s]' % tuple(key), conv)

    def forward(self, prev_layer_repr):
        """Compute the next representation of every node type.

        Args:
            prev_layer_repr: sequence indexed by node type holding the previous
                layer's representations (e.g. the list returned by
                ``InputLayer``).

        Returns:
            A list with one entry per node type, in ``data.node_types`` order.
        """
        new_layer_repr = []
        for i, _ in enumerate(self.data.node_types):
            contributions = []
            for key in self.data.relation_types.keys():
                nt_row, nt_col = key
                if nt_row == i:
                    x = prev_layer_repr[nt_col]
                elif nt_col == i:
                    # NOTE(review): this reuses the convolution built from the
                    # (nt_row, nt_col) adjacency matrices for the column-side
                    # node type; confirm whether a transposed adjacency is
                    # intended here when nt_row != nt_col.
                    x = prev_layer_repr[nt_row]
                else:
                    continue
                contributions.append(self.convolutions[key](x))
            if not contributions:
                # No relation touches this node type. Keep the previous result
                # (sum of an empty list, i.e. the int 0) rather than guess a shape.
                new_layer_repr.append(0)
                continue
            # ϕ(∑ relation contributions), as in the layer formula.
            new_layer_repr.append(self.layer_activation(sum(contributions)))
        return new_layer_repr
 |