|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- from typing import List, \
- Union, \
- Callable
- from .data import Data, \
- RelationFamily
- from .trainprep import PreparedData, \
- PreparedRelationFamily
- import torch
- from .weights import init_glorot
- from .normalize import _sparse_coo_tensor
- import types
- from .util import _sparse_diag_cat,
- _cat
-
-
-
-
-
class FastGraphConv(torch.nn.Module):
    """Graph convolution that processes several relation types in one pass.

    The per-relation weight matrices are concatenated column-wise into a
    single matrix and the per-relation adjacency matrices are merged with
    ``_sparse_diag_cat`` (presumably into one block-diagonal sparse matrix --
    confirm against ``.util``), so that a forward pass needs only two matrix
    products regardless of the number of relation types.

    Args:
        in_channels: Number of input features per node.
        out_channels: Number of output features per node.
        adjacency_matrices: Non-empty list of sparse tensors, one per
            relation type. ``forward`` reshapes its result using the row
            count of the first matrix, so all matrices are expected to share
            the same shape.
        keep_prob: Probability of keeping an input element during dropout;
            ``1.`` disables dropout.
        activation: Callable applied to the stacked output, or ``None`` for
            no activation.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.

    Raises:
        TypeError: If ``adjacency_matrices`` is not a list of tensors or
            ``activation`` is neither callable nor ``None``.
        ValueError: If ``adjacency_matrices`` is empty or contains a
            non-sparse tensor.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 adjacency_matrices: List[torch.Tensor],
                 keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
                 **kwargs) -> None:

        super().__init__(**kwargs)

        in_channels = int(in_channels)
        out_channels = int(out_channels)
        if not isinstance(adjacency_matrices, list):
            raise TypeError('adjacency_matrices must be a list')
        if len(adjacency_matrices) == 0:
            raise ValueError('adjacency_matrices must not be empty')
        if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
            raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
        if not all(m.is_sparse for m in adjacency_matrices):
            raise ValueError('adjacency_matrices elements must be sparse')
        keep_prob = float(keep_prob)
        # Any callable is accepted (not only plain Python functions) so that
        # builtins such as torch.sigmoid or nn.Module instances work too.
        # None is allowed because forward() treats it as "no activation".
        if activation is not None and not callable(activation):
            raise TypeError('activation must be a function')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.keep_prob = keep_prob
        self.activation = activation

        self.num_row_nodes = adjacency_matrices[0].shape[0]
        self.num_relation_types = len(adjacency_matrices)

        # Only the merged matrix is kept; forward() never needs the list.
        self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)

        # One weight block per relation type, laid side by side so that a
        # single product x @ weights yields all per-relation projections.
        weights = torch.cat([
            init_glorot(in_channels, out_channels)
            for _ in range(self.num_relation_types)
        ], dim=1)
        # Register as a Parameter: the result of torch.cat is not a leaf
        # tensor, so without this the weights would be invisible to
        # parameters(), optimizers, state_dict() and .to(device).
        self.weights = torch.nn.Parameter(weights.detach())

    @staticmethod
    def _dropout(x: torch.Tensor, keep_prob: float) -> torch.Tensor:
        """Apply inverted dropout to a dense or sparse COO tensor."""
        drop_prob = 1. - keep_prob
        if not x.is_sparse:
            return torch.nn.functional.dropout(x, p=drop_prob, training=True)
        # Sparse input: drop (and rescale) only the stored values and keep
        # the sparsity pattern.
        x = x.coalesce()
        values = torch.nn.functional.dropout(x.values(), p=drop_prob,
                                             training=True)
        return torch.sparse_coo_tensor(x.indices(), values, x.shape)

    def forward(self, x) -> torch.Tensor:
        """Convolve ``x`` over every relation type at once.

        Args:
            x: Dense or sparse 2-D tensor of node features with
                ``in_channels`` columns.

        Returns:
            Tensor of shape ``(num_relation_types, num_row_nodes,
            out_channels)``.

        Dropout is applied only when ``keep_prob < 1`` and the module is in
        training mode.
        """
        if self.keep_prob < 1. and self.training:
            x = self._dropout(x, self.keep_prob)

        if x.is_sparse:
            res = torch.sparse.mm(x, self.weights)
        else:
            res = torch.mm(x, self.weights)

        # Cut the projection into one chunk per relation type and stack the
        # chunks vertically to line up with the merged adjacency matrix.
        res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
        res = torch.cat(res)

        if self.adjacency_matrices.is_sparse:
            res = torch.sparse.mm(self.adjacency_matrices, res)
        else:
            res = torch.mm(self.adjacency_matrices, res)

        res = res.view(self.num_relation_types, self.num_row_nodes,
                       self.out_channels)
        if self.activation is not None:
            res = self.activation(res)

        return res
-
-
class FastConvLayer(torch.nn.Module):
    """One convolution layer over all node types and relation families.

    For every relation family in ``data`` the layer creates a
    ``FastGraphConv`` (two of them, forward and backward, when the family
    connects two different node types) and registers it under the node type
    whose representation it produces.

    Args:
        input_dim: List with the input feature count of each node type.
        output_dim: Output feature count; either a list with one entry per
            node type or a single value applied to every node type.
        data: ``Data`` or ``PreparedData`` instance providing ``node_types``
            and ``relation_families``.
        keep_prob: Dropout keep probability passed to each convolution.
        rel_activation: Activation passed to each ``FastGraphConv``.
        layer_activation: Activation applied to the combined representation
            of each node type.
        **kwargs: Forwarded to ``torch.nn.Module.__init__``.

    Raises:
        ValueError: If ``input_dim`` is not a list, ``output_dim`` is falsy,
            or ``data`` is of an unsupported type.
    """

    def __init__(self,
                 input_dim: List[int],
                 output_dim: Union[int, List[int]],
                 data: Union[Data, PreparedData],
                 keep_prob: float = 1.,
                 rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
                 layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:

        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
                           rel_activation, layer_activation)

        # A scalar output_dim means "same size for every node type". This
        # must happen here: rebinding the name inside _check_params would
        # never reach the instance, and build_* index output_dim by node
        # type.
        if not isinstance(output_dim, list):
            output_dim = [output_dim] * len(data.node_types)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.is_sparse = False
        self.next_layer_repr = None
        self.build()

    def build(self):
        """Create the convolutions for every relation family in ``data``."""
        # next_layer_repr[t] holds the convolutions whose output contributes
        # to the representation of node type t.
        self.next_layer_repr = torch.nn.ModuleList([
            torch.nn.ModuleList()
            for _ in range(len(self.data.node_types))
        ])
        for fam in self.data.relation_families:
            self.build_family(fam)

    def build_family(self, fam) -> None:
        """Dispatch on whether ``fam`` links one node type or two."""
        if fam.node_type_row == fam.node_type_column:
            self.build_fam_one_node_type(fam)
        else:
            self.build_fam_two_node_types(fam)

    def build_fam_one_node_type(self, fam) -> None:
        """Add a single convolution for a family within one node type."""
        adjacency_matrices = [
            r.adjacency_matrix
            for r in fam.relation_types
        ]
        conv = FastGraphConv(self.input_dim[fam.node_type_column],
                             self.output_dim[fam.node_type_row],
                             adjacency_matrices,
                             self.keep_prob,
                             self.rel_activation)
        # forward() uses this to pick the matching input representation.
        conv.input_node_type = fam.node_type_column
        self.next_layer_repr[fam.node_type_row].append(conv)

    def build_fam_two_node_types(self, fam) -> None:
        """Add forward and backward convolutions for a two-node-type family.

        Relation types lacking a matrix for a direction are ignored for that
        direction. A direction with no matrices at all gets no convolution,
        because ``FastGraphConv`` rejects an empty list.
        """
        adjacency_matrices = [
            r.adjacency_matrix
            for r in fam.relation_types
            if r.adjacency_matrix is not None
        ]

        adjacency_matrices_backward = [
            r.adjacency_matrix_backward
            for r in fam.relation_types
            if r.adjacency_matrix_backward is not None
        ]

        if adjacency_matrices:
            conv = FastGraphConv(self.input_dim[fam.node_type_column],
                                 self.output_dim[fam.node_type_row],
                                 adjacency_matrices,
                                 self.keep_prob,
                                 self.rel_activation)
            conv.input_node_type = fam.node_type_column
            self.next_layer_repr[fam.node_type_row].append(conv)

        if adjacency_matrices_backward:
            conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
                                          self.output_dim[fam.node_type_column],
                                          adjacency_matrices_backward,
                                          self.keep_prob,
                                          self.rel_activation)
            conv_backward.input_node_type = fam.node_type_row
            self.next_layer_repr[fam.node_type_column].append(conv_backward)

    def forward(self, prev_layer_repr):
        """Compute the next representation of every node type.

        Args:
            prev_layer_repr: Sequence indexed by node type holding the
                previous layer's representation of that node type.

        Returns:
            List indexed by node type with the activated representations.
            A node type that receives no convolution gets a zero tensor of
            shape ``(node count, output_dim)`` before activation.
        """
        next_layer_repr = []
        for output_node_type, convs in enumerate(self.next_layer_repr):
            contributions = []
            for conv in convs:
                rep = conv(prev_layer_repr[conv.input_node_type])
                # Collapse the relation-type axis, then L2-normalize each
                # node's vector.
                rep = torch.sum(rep, dim=0)
                rep = torch.nn.functional.normalize(rep, p=2, dim=1)
                contributions.append(rep)

            if contributions:
                combined = sum(contributions)
            else:
                combined = torch.zeros(
                    self.data.node_types[output_node_type].count,
                    self.output_dim[output_node_type])

            next_layer_repr.append(self.layer_activation(combined))
        return next_layer_repr

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
                      rel_activation, layer_activation):
        """Validate constructor arguments; raises ``ValueError`` on failure.

        Only ``input_dim``, ``output_dim`` and ``data`` are checked; the
        remaining arguments are accepted as given.
        """
        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')
|