import torch
from .convolve import DropoutGraphConvActivation
from .data import Data
from .trainprep import PreparedData
from typing import List, \
    Union, \
    Callable
from collections import defaultdict
from dataclasses import dataclass


class Convolutions(torch.nn.Module):
    """Group of convolution modules that all read from one node type.

    The group stores the index of the node type whose representation is
    fed to every convolution, together with the convolutions themselves.
    """

    # Index into the previous layer's representation list.
    node_type_column: int
    # Holds DropoutGraphConvActivation instances.
    convolutions: torch.nn.ModuleList  # [DropoutGraphConvActivation]

    def __init__(self,
                 node_type_column: int,
                 convolutions: torch.nn.ModuleList,
                 **kwargs):
        super().__init__(**kwargs)
        self.node_type_column = node_type_column
        self.convolutions = convolutions


class DecagonLayer(torch.nn.Module):
    """Graph-convolution layer over several node types.

    For every relation family in ``data`` one ``Convolutions`` group is
    registered per direction.  The forward pass sums the convolutions of
    each group, L2-normalizes the sum, adds up all groups that target the
    same node type and applies ``layer_activation`` to the result.
    """

    def __init__(self,
                 input_dim: List[int],
                 output_dim: List[int],
                 data: Union[Data, PreparedData],
                 keep_prob: float = 1.,
                 rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
                 layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs):
        """Validate the arguments, store them and build the convolutions.

        Args:
            input_dim: Per-node-type input dimensions; must be a list.
            output_dim: Per-node-type output dimensions.  A non-list value
                is repeated once for every node type in ``data``.
            data: Object providing ``node_types`` and ``relation_families``.
            keep_prob: Passed to every ``DropoutGraphConvActivation``.
            rel_activation: Passed to every ``DropoutGraphConvActivation``.
            layer_activation: Applied to the summed per-node-type output.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.

        Raises:
            ValueError: If ``input_dim`` is not a list, ``output_dim`` is
                falsy, or ``data`` is neither ``Data`` nor ``PreparedData``.
        """
        super().__init__(**kwargs)

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        # Validate `data` before touching `data.node_types` below, so that a
        # wrong type is reported as the intended ValueError instead of an
        # AttributeError.
        if not isinstance(data, (Data, PreparedData)):
            raise ValueError('data must be of type Data or PreparedData')

        if not isinstance(output_dim, list):
            output_dim = [output_dim] * len(data.node_types)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = float(keep_prob)
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.is_sparse = False
        self.next_layer_repr = None
        self.build()

    def build_fam_one_node_type(self, fam):
        """Register convolutions for a family whose row and column types match."""
        convolutions = torch.nn.ModuleList()

        for r in fam.relation_types:
            conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
                                              self.output_dim[fam.node_type_row],
                                              r.adjacency_matrix,
                                              self.keep_prob,
                                              self.rel_activation)
            convolutions.append(conv)

        self.next_layer_repr[fam.node_type_row].append(
            Convolutions(fam.node_type_column, convolutions))

    def build_fam_two_node_types(self, fam) -> None:
        """Register forward and backward convolutions for a two-type family.

        Relation types whose ``adjacency_matrix`` (or
        ``adjacency_matrix_backward``) is ``None`` contribute nothing to the
        respective direction, so either group may end up empty.
        """
        convolutions_row = torch.nn.ModuleList()
        convolutions_column = torch.nn.ModuleList()

        for r in fam.relation_types:
            if r.adjacency_matrix is not None:
                conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
                                                  self.output_dim[fam.node_type_row],
                                                  r.adjacency_matrix,
                                                  self.keep_prob,
                                                  self.rel_activation)
                convolutions_row.append(conv)

            if r.adjacency_matrix_backward is not None:
                conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row],
                                                  self.output_dim[fam.node_type_column],
                                                  r.adjacency_matrix_backward,
                                                  self.keep_prob,
                                                  self.rel_activation)
                convolutions_column.append(conv)

        self.next_layer_repr[fam.node_type_row].append(
            Convolutions(fam.node_type_column, convolutions_row))

        self.next_layer_repr[fam.node_type_column].append(
            Convolutions(fam.node_type_row, convolutions_column))

    def build_family(self, fam) -> None:
        """Dispatch to the one- or two-node-type builder for ``fam``."""
        if fam.node_type_row == fam.node_type_column:
            self.build_fam_one_node_type(fam)
        else:
            self.build_fam_two_node_types(fam)

    def build(self):
        """Create one module list per node type and fill it from all families."""
        self.next_layer_repr = torch.nn.ModuleList([
            torch.nn.ModuleList() for _ in range(len(self.data.node_types))
        ])
        for fam in self.data.relation_families:
            self.build_family(fam)

    def forward(self, prev_layer_repr):
        """Compute this layer's representation for every node type.

        Defined as ``forward`` (not ``__call__``) so that calling the layer
        goes through ``torch.nn.Module.__call__`` and registered hooks run.

        Args:
            prev_layer_repr: Sequence indexed by node type; each entry is
                passed to the convolutions that read from that node type.

        Returns:
            A list with one entry per node type.
        """
        next_layer_repr = []

        for node_type_row, conv_groups in enumerate(self.next_layer_repr):
            contributions = []

            for group in conv_groups:
                # A direction without any adjacency matrix yields an empty
                # group; summing it would give the int 0 and make
                # normalize() fail, so it is skipped.
                if len(group.convolutions) == 0:
                    continue
                summed = sum(conv(prev_layer_repr[group.node_type_column])
                             for conv in group.convolutions)
                contributions.append(
                    torch.nn.functional.normalize(summed, p=2, dim=1))

            if not contributions:
                # NOTE(review): this placeholder is 1-D (output_dim,), whereas
                # the branch below normalizes along dim=1 and so yields at
                # least 2-D tensors - confirm callers expect this shape.
                next_layer_repr.append(
                    torch.zeros(self.output_dim[node_type_row]))
            else:
                # NOTE(review): the activation is assumed to belong to this
                # branch only (original indentation was lost) - confirm.
                next_layer_repr.append(
                    self.layer_activation(sum(contributions)))

        return next_layer_repr