#
# This module implements a single layer of the Decagon
# model. This is already going to be quite complex, as we
# will be using all of the graph-convolutional building
# blocks.
#
# h_i^(k+1) = ϕ( Σ_r Σ_{j∈N_r^i} c_r^{ij} W_r^(k) h_j^(k)
#     + c_r^i h_i^(k) )
#
# N_r^i    - set of neighbors of node i under relation r
# W_r^(k)  - relation-type-specific weight matrix
# h_i^(k)  - hidden state of node i in layer k;
#            h_i^(k) ∈ R^{d(k)}, where d(k) is the
#            dimensionality of the representation in the
#            k-th layer
# ϕ        - activation function
# c_r^{ij} - edge normalization constant,
#            c_r^{ij} = 1 / sqrt(|N_r^i| * |N_r^j|)
# c_r^i    - node normalization constant,
#            c_r^i = 1 / |N_r^i|
#
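# As a quick numeric example of the constants above: if node i
# has 4 neighbors under relation r, and one of them, node j, has
# 9 neighbors under r, then:
#
#     c_r^{ij} = 1 / sqrt(4 * 9) = 1/6
#     c_r^i    = 1 / 4
#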


import torch
from .convolve import DropoutGraphConvActivation
from .data import Data
from typing import List, Union, Callable
from collections import defaultdict


class Layer(torch.nn.Module):
    def __init__(self,
        output_dim: Union[int, List[int]],
        is_sparse: bool,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.is_sparse = is_sparse


class InputLayer(Layer):
    def __init__(self, data: Data,
        output_dim: Union[int, List[int], None] = None,
        **kwargs) -> None:

        output_dim = output_dim or \
            list(map(lambda a: a.count, data.node_types))
        if not isinstance(output_dim, list):
            output_dim = [ output_dim, ] * len(data.node_types)

        super().__init__(output_dim, is_sparse=False, **kwargs)
        self.data = data
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = []
        for i, nt in enumerate(self.data.node_types):
            reps = torch.rand(nt.count, self.output_dim[i])
            reps = torch.nn.Parameter(reps)
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'GNN input layer with output_dim: %s\n' % self.output_dim
        s += '  # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += '  - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()
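
# A minimal usage sketch for InputLayer. `add_node_type()` is a
# hypothetical Data helper used purely for illustration - substitute
# whatever the actual API in .data provides:
#
#     d = Data()
#     d.add_node_type('Gene', 1000)
#     layer = InputLayer(d, output_dim=32)
#     reps = layer()  # list with one trainable (1000, 32) tensor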


class OneHotInputLayer(Layer):
    def __init__(self, data: Data, **kwargs) -> None:
        output_dim = [ a.count for a in data.node_types ]
        super().__init__(output_dim, is_sparse=True, **kwargs)
        self.data = data
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = []
        for i, nt in enumerate(self.data.node_types):
            reps = torch.eye(nt.count).to_sparse()
            reps = torch.nn.Parameter(reps)
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'One-hot GNN input layer\n'
        s += '  # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += '  - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()
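
# Note: with one-hot (identity) inputs, the first convolution's weight
# matrix acts as a free per-node embedding table, since A @ (I @ W)
# == A @ W for any weights W. A minimal check of the identity part:
#
#     I = torch.eye(5).to_sparse()
#     W = torch.rand(5, 3)
#     assert torch.allclose(torch.sparse.mm(I, W), W)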


class DecagonLayer(Layer):
    def __init__(self,
        data: Data,
        previous_layer: Layer,
        output_dim: Union[int, List[int]],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs) -> None:

        if not isinstance(output_dim, list):
            output_dim = [ output_dim ] * len(data.node_types)
        super().__init__(output_dim, is_sparse=False, **kwargs)
        self.data = data
        self.previous_layer = previous_layer
        self.input_dim = previous_layer.output_dim
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.next_layer_repr = None
        self.build()

    def build(self) -> None:
        self.next_layer_repr = defaultdict(list)

        for (nt_row, nt_col), relation_types in self.data.relation_types.items():
            for rel in relation_types:
                # Messages flowing from nt_col-type neighbors
                # into nt_row-type nodes.
                conv = DropoutGraphConvActivation(self.input_dim[nt_col],
                    self.output_dim[nt_row], rel.adjacency_matrix,
                    self.keep_prob, self.rel_activation)
                self.next_layer_repr[nt_row].append((conv, nt_col))

                if nt_row == nt_col:
                    continue

                # For relations between two distinct node types, also
                # propagate in the opposite direction, along the
                # transposed adjacency matrix.
                conv = DropoutGraphConvActivation(self.input_dim[nt_row],
                    self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
                    self.keep_prob, self.rel_activation)
                self.next_layer_repr[nt_col].append((conv, nt_row))

    def forward(self) -> List[torch.Tensor]:
        prev_layer_repr = self.previous_layer()
        next_layer_repr = [ None ] * len(self.data.node_types)
        for i in range(len(self.data.node_types)):
            # Sum the per-relation convolutions feeding node type i,
            # apply the layer activation (ϕ in the header equation)
            # and L2-normalize each node's representation.
            next_layer_repr[i] = [
                conv(prev_layer_repr[neighbor_type])
                for (conv, neighbor_type) in self.next_layer_repr[i]
            ]
            next_layer_repr[i] = sum(next_layer_repr[i])
            next_layer_repr[i] = self.layer_activation(next_layer_repr[i])
            next_layer_repr[i] = torch.nn.functional.normalize(next_layer_repr[i], p=2, dim=1)
        return next_layer_repr
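
# A sketch of stacking the layers end-to-end. `add_node_type()` and
# `add_relation_type()` are hypothetical Data helpers shown only to
# make the flow concrete - adjust to the actual API in .data:
#
#     d = Data()
#     d.add_node_type('Gene', 100)
#     adj = (torch.rand(100, 100) < 0.1).float().to_sparse()
#     d.add_relation_type('Interacts', 0, 0, adj)
#
#     in_layer = OneHotInputLayer(d)
#     hidden = DecagonLayer(d, in_layer, output_dim=32)
#     reps = hidden()  # list of L2-normalized (100, 32) tensors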