from typing import List, Union, Callable
import types

import torch

from .data import Data, RelationFamily
from .trainprep import PreparedData, PreparedRelationFamily
from .weights import init_glorot
from .normalize import _sparse_coo_tensor


def _sparse_diag_cat(matrices: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate sparse 2D matrices into one block-diagonal sparse matrix.

    Each input matrix becomes a diagonal block; the result has shape
    (sum of row counts, sum of column counts).

    Raises:
        ValueError: if the list is empty, any matrix is dense, or any
            matrix is not 2D.
    """
    if len(matrices) == 0:
        raise ValueError('The list of matrices must be non-empty')
    if not all(m.is_sparse for m in matrices):
        raise ValueError('All matrices must be sparse')
    if not all(len(m.shape) == 2 for m in matrices):
        raise ValueError('All matrices must be 2D')

    indices = []
    values = []
    row_offset = 0
    col_offset = 0

    for m in matrices:
        # Shift this block's coordinates past all previously placed blocks.
        ind = m._indices().clone()
        ind[0] += row_offset
        ind[1] += col_offset
        indices.append(ind)
        values.append(m._values())
        row_offset += m.shape[0]
        col_offset += m.shape[1]

    indices = torch.cat(indices, dim=1)
    values = torch.cat(values)

    return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))


def _cat(matrices: List[torch.Tensor]) -> torch.Tensor:
    """Concatenate matrices along dimension 0, supporting sparse inputs.

    All inputs must share the same layout (all dense or all sparse) and
    identical trailing dimensions. Dense inputs delegate to ``torch.cat``;
    sparse inputs are stacked by shifting their row indices.

    Raises:
        ValueError: on an empty list, mixed layouts, or mismatched
            trailing dimensions.
    """
    if len(matrices) == 0:
        raise ValueError('Empty list passed to _cat()')

    n = sum(a.is_sparse for a in matrices)
    if n != 0 and n != len(matrices):
        raise ValueError('All matrices must have the same layout (dense or sparse)')

    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
        raise ValueError('All matrices must have the same dimensions apart from dimension 0')

    if not matrices[0].is_sparse:
        return torch.cat(matrices)

    total_rows = sum(a.shape[0] for a in matrices)

    indices = []
    values = []
    row_offset = 0

    for a in matrices:
        ind = a._indices().clone()
        val = a._values()
        # Only the row coordinate shifts; columns are shared.
        ind[0] += row_offset
        ind = ind.transpose(0, 1)
        indices.append(ind)
        values.append(val)
        row_offset += a.shape[0]

    indices = torch.cat(indices).transpose(0, 1)
    values = torch.cat(values)

    res = _sparse_coo_tensor(indices, values,
        size=(row_offset, matrices[0].shape[1]))

    return res


class FastGraphConv(torch.nn.Module):
    """Graph convolution applied to every relation type of one family at once.

    The per-relation weight matrices are concatenated side by side and the
    per-relation adjacency matrices are stacked block-diagonally, so a single
    matrix multiplication covers the whole family.
    """

    def __init__(self,
        in_channels: int,
        out_channels: int,
        data: Union[Data, PreparedData],
        relation_family: Union[RelationFamily, PreparedRelationFamily],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:
        # Fixes over the previous revision: missing comma after
        # relation_family in the signature (SyntaxError), parameter typo
        # `acivation` vs. the `activation` name used in the body
        # (NameError), and the missing torch.nn.Module initialization.

        super().__init__(**kwargs)

        in_channels = int(in_channels)
        out_channels = int(out_channels)
        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of Data or PreparedData')
        if not isinstance(relation_family, RelationFamily) and \
            not isinstance(relation_family, PreparedRelationFamily):
            raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily')
        keep_prob = float(keep_prob)
        if not isinstance(activation, types.FunctionType):
            raise TypeError('activation must be a function')

        n_nodes_row = data.node_types[relation_family.node_type_row].count
        n_nodes_column = data.node_types[relation_family.node_type_column].count

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.data = data
        self.relation_family = relation_family
        self.keep_prob = keep_prob
        self.activation = activation

        # One (in_channels, out_channels) Glorot-initialized slice per
        # relation type, concatenated along the column dimension.
        self.weight = torch.cat([
            init_glorot(in_channels, out_channels) \
                for _ in range(len(relation_family.relation_types))
        ], dim=1)

        self.weight_backward = torch.cat([
            init_glorot(in_channels, out_channels) \
                for _ in range(len(relation_family.relation_types))
        ], dim=1)

        # Missing adjacency matrices are replaced by empty sparse tensors so
        # the block-diagonal stack keeps one slot per relation type.
        self.adjacency_matrix = _sparse_diag_cat([
            rel.adjacency_matrix \
                if rel.adjacency_matrix is not None \
                else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \
                for rel in relation_family.relation_types
        ])

        self.adjacency_matrix_backward = _sparse_diag_cat([
            rel.adjacency_matrix_backward \
                if rel.adjacency_matrix_backward is not None \
                else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \
                for rel in relation_family.relation_types
        ])

    def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
        """Convolve row and column representations for every relation type.

        Returns a tuple (new_repr_row, new_repr_column), each of shape
        (num_relation_types, num_nodes, out_channels).

        NOTE(review): the weighted features are multiplied by the
        block-diagonal adjacency stack without being tiled per relation
        type first — the inner dimensions look inconsistent; confirm
        against the callers/tests of this layer.
        """
        repr_row = prev_layer_repr[self.relation_family.node_type_row]
        repr_column = prev_layer_repr[self.relation_family.node_type_column]

        new_repr_row = torch.sparse.mm(repr_column, self.weight) \
            if repr_column.is_sparse \
            else torch.mm(repr_column, self.weight)
        new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \
            if self.adjacency_matrix.is_sparse \
            else torch.mm(self.adjacency_matrix, new_repr_row)
        new_repr_row = new_repr_row.view(len(self.relation_family.relation_types),
            len(repr_row), self.out_channels)

        new_repr_column = torch.sparse.mm(repr_row, self.weight) \
            if repr_row.is_sparse \
            else torch.mm(repr_row, self.weight)
        new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \
            if self.adjacency_matrix_backward.is_sparse \
            else torch.mm(self.adjacency_matrix_backward, new_repr_column)
        new_repr_column = new_repr_column.view(len(self.relation_family.relation_types),
            len(repr_column), self.out_channels)

        return (new_repr_row, new_repr_column)


class FastConvLayer(torch.nn.Module):
    """Convolution layer spanning all relation families of a Data object.

    ``build()`` fills one entry per relation family in each of the four
    lists below (``None`` when a family has no adjacency matrices in the
    given direction).
    """

    adjacency_matrix: List[torch.Tensor]
    adjacency_matrix_backward: List[torch.Tensor]
    weight: List[torch.Tensor]
    weight_backward: List[torch.Tensor]

    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.adjacency_matrix = None
        self.adjacency_matrix_backward = None
        self.weight = None
        self.weight_backward = None

        self.build()

    def build(self):
        """Collect adjacency matrices and create weights per relation family."""
        self.adjacency_matrix = []
        self.adjacency_matrix_backward = []
        self.weight = []
        self.weight_backward = []

        for fam in self.data.relation_families:
            adj_mat = [ rel.adjacency_matrix \
                for rel in fam.relation_types \
                if rel.adjacency_matrix is not None ]
            adj_mat_back = [ rel.adjacency_matrix_backward \
                for rel in fam.relation_types \
                if rel.adjacency_matrix_backward is not None ]

            weight = [ init_glorot(self.input_dim[fam.node_type_column],
                self.output_dim[fam.node_type_row]) \
                for _ in range(len(adj_mat)) ]
            weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
                self.output_dim[fam.node_type_row]) \
                for _ in range(len(adj_mat_back)) ]

            adj_mat = torch.cat(adj_mat) \
                if len(adj_mat) > 0 \
                else None
            adj_mat_back = torch.cat(adj_mat_back) \
                if len(adj_mat_back) > 0 \
                else None

            self.adjacency_matrix.append(adj_mat)
            self.adjacency_matrix_backward.append(adj_mat_back)
            self.weight.append(weight)
            # Fixed: previously appended to self.weight_back, which does
            # not exist (AttributeError); the attribute is weight_backward.
            self.weight_backward.append(weight_back)

    def forward(self, prev_layer_repr):
        # NOTE(review): this method looks unfinished — it accumulates no
        # outputs and returns None; confirm intended output shape before use.
        for i, fam in enumerate(self.data.relation_families):
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]

            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]

            if adj_mat is not None:
                # NOTE(review): `dropout` is not imported in this module —
                # likely `from .dropout import dropout`; confirm.
                x = dropout(repr_column, keep_prob=self.keep_prob)
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                # Fixed: previously multiplied adj_mat by repr_row, silently
                # discarding the dropout/weight product computed above.
                x = torch.sparse.mm(adj_mat, x) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, x)
                x = self.rel_activation(x)
                x = x.view(len(fam.relation_types), len(repr_row), -1)

            if adj_mat_back is not None:
                x = torch.sparse.mm(adj_mat_back, repr_row) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, repr_row)

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):
        """Validate constructor arguments; raises ValueError on bad input."""
        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            # NOTE(review): this normalization only rebinds the local name
            # and is never returned, so it has no effect on the caller.
            output_dim = [output_dim] * len(data.node_types)

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')