from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, List, Union

import torch

from .convolve import DropoutGraphConvActivation
from .data import Data
from .trainprep import PreparedData


@dataclass
class Convolutions(object):
    """Convolutions that read from a single column node type.

    Attributes:
        node_type_column: index of the node type whose representations
            these convolutions consume.
        convolutions: one DropoutGraphConvActivation per relation type
            between the row node type and ``node_type_column``.
    """
    node_type_column: int
    convolutions: List[DropoutGraphConvActivation]


class DecagonLayer(torch.nn.Module):
    """One graph-convolution layer of a Decagon-style heterogeneous GCN.

    For every (row, column) node-type pair that has at least one relation
    type, the layer builds one dropout graph convolution per relation and,
    on the forward pass, sums and L2-normalizes their outputs, then sums
    over all column node types and applies ``layer_activation``.
    """

    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs) -> None:
        """Initialize the layer and build its convolutions.

        Args:
            input_dim: per-node-type input feature dimensions.
            output_dim: per-node-type output feature dimensions; a scalar
                is broadcast to all node types.
            data: graph description (node types and relation types).
            keep_prob: dropout keep probability for the convolutions.
            rel_activation: activation applied inside each convolution.
            layer_activation: activation applied to the summed layer output.

        Raises:
            ValueError: if ``input_dim`` is not a list, ``output_dim`` is
                falsy, or ``data`` is not a Data/PreparedData instance.
        """
        super().__init__(**kwargs)
        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')
        if not output_dim:
            raise ValueError('output_dim must be specified')
        # Validate `data` BEFORE touching data.node_types below; otherwise a
        # wrong-typed `data` surfaces as AttributeError instead of the
        # intended ValueError.
        if not isinstance(data, (Data, PreparedData)):
            raise ValueError('data must be of type Data or PreparedData')
        if not isinstance(output_dim, list):
            # Broadcast a scalar output dimension to every node type.
            output_dim = [output_dim] * len(data.node_types)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = float(keep_prob)
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.is_sparse = False
        self.next_layer_repr = None
        self.build()

    def build(self) -> None:
        """Create one DropoutGraphConvActivation per relation type.

        Populates ``self.next_layer_repr`` with, for each row node type,
        a list of Convolutions bundles — one per column node type that
        has at least one relation with the row type.
        """
        n = len(self.data.node_types)
        rel_types = self.data.relation_types
        self.next_layer_repr = [[] for _ in range(n)]
        for node_type_row in range(n):
            if node_type_row not in rel_types:
                continue
            for node_type_column in range(n):
                if node_type_column not in rel_types[node_type_row]:
                    continue
                rels = rel_types[node_type_row][node_type_column]
                if not rels:
                    continue
                convolutions = []
                for r in rels:
                    # Each relation gets its own convolution over its
                    # adjacency matrix; input dim follows the column type,
                    # output dim the row type.
                    conv = DropoutGraphConvActivation(
                        self.input_dim[node_type_column],
                        self.output_dim[node_type_row],
                        r.adjacency_matrix,
                        self.keep_prob,
                        self.rel_activation)
                    convolutions.append(conv)
                self.next_layer_repr[node_type_row].append(
                    Convolutions(node_type_column, convolutions))

    # NOTE(review): overriding __call__ directly bypasses nn.Module's hook
    # machinery; consider renaming to forward() — confirm no caller relies
    # on the hook-free behavior before changing.
    def __call__(self, prev_layer_repr):
        """Compute next-layer node representations.

        Args:
            prev_layer_repr: per-node-type input representations, indexed
                by node type.

        Returns:
            List of per-node-type output representations after summing,
            L2-normalizing, and applying ``layer_activation``.
        """
        n = len(self.data.node_types)
        next_layer_repr = [[] for _ in range(n)]
        for node_type_row in range(n):
            for convolutions in self.next_layer_repr[node_type_row]:
                # Sum the per-relation convolution outputs for this
                # (row, column) pair, then L2-normalize each node's vector.
                repr_ = [conv(prev_layer_repr[convolutions.node_type_column])
                    for conv in convolutions.convolutions]
                repr_ = sum(repr_)
                repr_ = torch.nn.functional.normalize(repr_, p=2, dim=1)
                next_layer_repr[node_type_row].append(repr_)
            # Aggregate over all contributing column node types.
            next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row])
            next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row])
        return next_layer_repr