import types
from typing import List, Union, Callable

import torch

from .data import Data, RelationFamily
from .trainprep import PreparedData, PreparedRelationFamily
from .weights import init_glorot
from .normalize import _sparse_coo_tensor
# Assumption: the dropout() helper used by FastGraphConv.forward() lives in a
# .dropout module alongside the other helpers of this package.
from .dropout import dropout


def _sparse_diag_cat(matrices: List[torch.Tensor]):
    """Concatenate 2D sparse matrices into a single block-diagonal sparse matrix."""
    if len(matrices) == 0:
        raise ValueError('The list of matrices must be non-empty')

    if not all(m.is_sparse for m in matrices):
        raise ValueError('All matrices must be sparse')

    if not all(len(m.shape) == 2 for m in matrices):
        raise ValueError('All matrices must be 2D')

    indices = []
    values = []
    row_offset = 0
    col_offset = 0

    for m in matrices:
        ind = m._indices().clone()
        ind[0] += row_offset
        ind[1] += col_offset
        indices.append(ind)
        values.append(m._values())
        row_offset += m.shape[0]
        col_offset += m.shape[1]

    indices = torch.cat(indices, dim=1)
    values = torch.cat(values)

    return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))


def _cat(matrices: List[torch.Tensor]):
    """Concatenate matrices along dimension 0; inputs must be all dense or all sparse."""
    if len(matrices) == 0:
        raise ValueError('Empty list passed to _cat()')

    n = sum(a.is_sparse for a in matrices)
    if n != 0 and n != len(matrices):
        raise ValueError('All matrices must have the same layout (dense or sparse)')

    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
        raise ValueError('All matrices must have the same dimensions apart from dimension 0')

    if not matrices[0].is_sparse:
        return torch.cat(matrices)

    total_rows = sum(a.shape[0] for a in matrices)

    indices = []
    values = []
    row_offset = 0

    for a in matrices:
        ind = a._indices().clone()
        val = a._values()
        ind[0] += row_offset
        ind = ind.transpose(0, 1)
        indices.append(ind)
        values.append(val)
        row_offset += a.shape[0]

    indices = torch.cat(indices).transpose(0, 1)
    values = torch.cat(values)

    res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))

    return res
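
# Illustrative sketch (not part of the original module): shows what
# _sparse_diag_cat() and _cat() produce for two small sparse matrices. The
# helper name _example_sparse_cat is hypothetical and the function can be
# removed without affecting the rest of the module.
def _example_sparse_cat():
    a = torch.eye(2).to_sparse()                          # 2x2 identity
    b = torch.tensor([[0., 1.], [1., 0.]]).to_sparse()    # 2x2 anti-diagonal

    block_diag = _sparse_diag_cat([a, b])   # 4x4 block-diagonal sparse matrix
    stacked = _cat([a, b])                  # 4x2 row-wise concatenation

    assert block_diag.shape == (4, 4)
    assert stacked.shape == (4, 2)
    return block_diag, stacked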
class FastGraphConv(torch.nn.Module):
    """Graph convolution applied across multiple relation types at once,
    implemented with a single block-diagonal sparse adjacency matrix."""

    def __init__(self,
        in_channels: int,
        out_channels: int,
        adjacency_matrices: List[torch.Tensor],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        in_channels = int(in_channels)
        out_channels = int(out_channels)

        if not isinstance(adjacency_matrices, list):
            raise TypeError('adjacency_matrices must be a list')

        if len(adjacency_matrices) == 0:
            raise ValueError('adjacency_matrices must not be empty')

        if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
            raise TypeError('adjacency_matrices elements must be of class torch.Tensor')

        if not all(m.is_sparse for m in adjacency_matrices):
            raise ValueError('adjacency_matrices elements must be sparse')

        keep_prob = float(keep_prob)

        if not isinstance(activation, types.FunctionType):
            raise TypeError('activation must be a function')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.keep_prob = keep_prob
        self.activation = activation

        self.num_row_nodes = len(adjacency_matrices[0])
        self.num_relation_types = len(adjacency_matrices)

        # Stack the per-relation adjacency matrices block-diagonally so that a
        # single sparse matmul covers all relation types.
        self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)

        # One (in_channels x out_channels) weight matrix per relation type,
        # concatenated along the column dimension.
        self.weights = torch.cat([
            init_glorot(in_channels, out_channels)
            for _ in range(self.num_relation_types)
        ], dim=1)

    def forward(self, x) -> torch.Tensor:
        if self.keep_prob < 1.:
            x = dropout(x, self.keep_prob)

        # Project the input once with the concatenated weights.
        res = torch.sparse.mm(x, self.weights) \
            if x.is_sparse \
            else torch.mm(x, self.weights)

        # Split the per-relation projections and stack them row-wise so they
        # line up with the block-diagonal adjacency matrix.
        res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
        res = torch.cat(res)

        res = torch.sparse.mm(self.adjacency_matrices, res) \
            if self.adjacency_matrices.is_sparse \
            else torch.mm(self.adjacency_matrices, res)

        res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)

        if self.activation is not None:
            res = self.activation(res)

        return res
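
# Illustrative sketch (not part of the original module): builds a FastGraphConv
# over two relation types on a single node type and runs one forward pass. All
# sizes and the helper name _example_fast_graph_conv are arbitrary/hypothetical
# and exist only for demonstration.
def _example_fast_graph_conv():
    num_nodes, in_channels, out_channels = 10, 16, 8
    adjacency_matrices = [
        torch.eye(num_nodes).to_sparse(),   # relation type 0: self-loops only
        torch.eye(num_nodes).to_sparse()    # relation type 1: self-loops only
    ]
    conv = FastGraphConv(in_channels, out_channels, adjacency_matrices)
    x = torch.rand(num_nodes, in_channels)
    out = conv(x)
    # One representation per relation type:
    assert out.shape == (2, num_nodes, out_channels)
    return out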
class FastConvLayer(torch.nn.Module):
    """Convolution layer that computes next-layer representations for every
    node type from the relation families in the supplied data object."""

    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.is_sparse = False
        self.next_layer_repr = None
        self.build()

    def build(self):
        # One list of convolutions per output node type.
        self.next_layer_repr = torch.nn.ModuleList([
            torch.nn.ModuleList()
            for _ in range(len(self.data.node_types))
        ])
        for fam in self.data.relation_families:
            self.build_family(fam)

    def build_family(self, fam) -> None:
        if fam.node_type_row == fam.node_type_column:
            self.build_fam_one_node_type(fam)
        else:
            self.build_fam_two_node_types(fam)

    def build_fam_one_node_type(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix
            for r in fam.relation_types
        ]
        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)
        conv.input_node_type = fam.node_type_column
        self.next_layer_repr[fam.node_type_row].append(conv)

    def build_fam_two_node_types(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix
            for r in fam.relation_types
            if r.adjacency_matrix is not None
        ]

        adjacency_matrices_backward = [
            r.adjacency_matrix_backward
            for r in fam.relation_types
            if r.adjacency_matrix_backward is not None
        ]

        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)

        conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
            self.output_dim[fam.node_type_column],
            adjacency_matrices_backward,
            self.keep_prob,
            self.rel_activation)

        conv.input_node_type = fam.node_type_column
        conv_backward.input_node_type = fam.node_type_row

        self.next_layer_repr[fam.node_type_row].append(conv)
        self.next_layer_repr[fam.node_type_column].append(conv_backward)

    def forward(self, prev_layer_repr):
        next_layer_repr = [[] for _ in range(len(self.data.node_types))]

        for output_node_type in range(len(self.data.node_types)):
            for conv in self.next_layer_repr[output_node_type]:
                rep = conv(prev_layer_repr[conv.input_node_type])
                rep = torch.sum(rep, dim=0)
                rep = torch.nn.functional.normalize(rep, p=2, dim=1)
                next_layer_repr[output_node_type].append(rep)
            if len(next_layer_repr[output_node_type]) == 0:
                # No incoming relations for this node type: fall back to zeros.
                next_layer_repr[output_node_type] = \
                    torch.zeros(self.data.node_types[output_node_type].count,
                        self.output_dim[output_node_type])
            else:
                next_layer_repr[output_node_type] = \
                    sum(next_layer_repr[output_node_type])
                next_layer_repr[output_node_type] = \
                    self.layer_activation(next_layer_repr[output_node_type])

        return next_layer_repr

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            raise ValueError('output_dim must be a list')

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')
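
# Usage note (a sketch, not from the original module): FastConvLayer takes one
# input representation per node type and returns one output representation per
# node type, e.g.:
#
#     layer = FastConvLayer(input_dim=[32, 32], output_dim=[16, 16],
#         data=prep_data, keep_prob=0.9)
#     out = layer([repr_node_type_0, repr_node_type_1])
#
# where prep_data is a Data/PreparedData instance describing the node types and
# relation families, and each repr_node_type_* tensor has shape
# (node_count, input_dim) for its node type.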