from collections import defaultdict
from typing import Callable, List, Union

import torch

from .layer import Layer
from ..convolve import DropoutGraphConvActivation
from ..data import Data


class DecagonLayer(Layer):
    """Layer that combines per-relation graph convolutions for each node type.

    For every ``(row, column)`` node-type pair in ``data.relation_types`` one
    ``DropoutGraphConvActivation`` is built per relation type, plus a second
    one on the transposed adjacency matrix when the two node types differ.
    Calling the layer evaluates the previous layer and returns one
    representation per node type.
    """

    def __init__(self,
                 data: Data,
                 previous_layer: Layer,
                 output_dim: Union[int, List[int]],
                 keep_prob: float = 1.,
                 rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
                 layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        """Store the configuration and build the convolutions.

        Args:
            data: Object exposing ``node_types`` and ``relation_types``.
            previous_layer: Layer whose call result is the input of this
                layer; its ``output_dim`` becomes this layer's ``input_dim``.
            output_dim: Output dimension per node type. A non-list value is
                repeated once for every entry of ``data.node_types``.
            keep_prob: Forwarded unchanged to every
                ``DropoutGraphConvActivation`` (presumably the dropout keep
                probability -- confirm against that class).
            rel_activation: Forwarded to every ``DropoutGraphConvActivation``.
            layer_activation: Applied to the summed representation of each
                node type in ``__call__``.
            **kwargs: Forwarded to ``Layer.__init__``.
        """
        if not isinstance(output_dim, list):
            output_dim = [output_dim] * len(data.node_types)

        super().__init__(output_dim, is_sparse=False, **kwargs)
        self.data = data
        self.previous_layer = previous_layer
        self.input_dim = previous_layer.output_dim
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.next_layer_repr = None
        self.build()

    def build(self) -> None:
        """Create the convolutions and group them by target node type.

        After this call ``self.next_layer_repr`` maps a node type to a list of
        ``(convs, neighbor_type)`` tuples, where ``convs`` holds one
        convolution per relation type between the two node types.
        """
        self.next_layer_repr = defaultdict(list)

        for (nt_row, nt_col), relation_types in self.data.relation_types.items():
            # Same node type on both sides: only the row direction is built.
            is_self_relation = nt_row == nt_col
            row_convs = []
            col_convs = []

            for rel in relation_types:
                # Row direction: inputs of the column type, outputs of the
                # row type.
                row_convs.append(DropoutGraphConvActivation(
                    self.input_dim[nt_col], self.output_dim[nt_row],
                    rel.adjacency_matrix,
                    self.keep_prob, self.rel_activation))

                if is_self_relation:
                    continue

                # Column direction: same relation on the transposed
                # adjacency matrix.
                col_convs.append(DropoutGraphConvActivation(
                    self.input_dim[nt_row], self.output_dim[nt_col],
                    rel.adjacency_matrix.transpose(0, 1),
                    self.keep_prob, self.rel_activation))

            self.next_layer_repr[nt_row].append((row_convs, nt_col))
            if not is_self_relation:
                self.next_layer_repr[nt_col].append((col_convs, nt_row))

    def __call__(self) -> List[torch.Tensor]:
        """Evaluate the previous layer and compute this layer's output.

        For each node type the outputs of the convolutions towards one
        neighbor type are summed and L2-normalized along ``dim=1``; the
        results for all neighbor types are then summed and passed through
        ``layer_activation``.

        Returns:
            A list with one entry per element of ``data.node_types``, in the
            same order (assumes node types are keyed by their integer index
            in ``data.relation_types`` -- TODO confirm).
        """
        prev_layer_repr = self.previous_layer()
        next_layer_repr = []

        for node_type in range(len(self.data.node_types)):
            contributions = []
            # ``.get`` instead of indexing: a lookup on the defaultdict would
            # insert an empty entry and mutate the state built by ``build``.
            for convs, neighbor_type in self.next_layer_repr.get(node_type, ()):
                rel_outputs = [conv(prev_layer_repr[neighbor_type])
                               for conv in convs]
                combined = torch.nn.functional.normalize(
                    sum(rel_outputs), p=2, dim=1)
                contributions.append(combined)

            # NOTE(review): for a node type without any incoming relation
            # ``sum([])`` is the int 0, which is handed to
            # ``layer_activation`` as-is -- confirm whether that can occur.
            next_layer_repr.append(self.layer_activation(sum(contributions)))

        return next_layer_repr