| @@ -0,0 +1,110 @@ | |||||
| from typing import List, \ | |||||
| Union, \ | |||||
| Callable | |||||
| from .data import Data | |||||
| from .trainprep import PreparedData | |||||
| import torch | |||||
| from .weights import init_glorot | |||||
| class FastConvLayer(torch.nn.Module): | |||||
| adjacency_matrix: List[torch.Tensor] | |||||
| adjacency_matrix_backward: List[torch.Tensor] | |||||
| weight: List[torch.Tensor] | |||||
| weight_backward: List[torch.Tensor] | |||||
| def __init__(self, | |||||
| input_dim: List[int], | |||||
| output_dim: List[int], | |||||
| data: Union[Data, PreparedData], | |||||
| keep_prob: float = 1., | |||||
| rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x | |||||
| layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, | |||||
| **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self._check_params(input_dim, output_dim, data, keep_prob, | |||||
| rel_activation, layer_activation) | |||||
| self.input_dim = input_dim | |||||
| self.output_dim = output_dim | |||||
| self.data = data | |||||
| self.keep_prob = keep_prob | |||||
| self.rel_activation = rel_activation | |||||
| self.layer_activation = layer_activation | |||||
| self.adjacency_matrix = None | |||||
| self.adjacency_matrix_backward = None | |||||
| self.weight = None | |||||
| self.weight_backward = None | |||||
| self.build() | |||||
| def build(self): | |||||
| self.adjacency_matrix = [] | |||||
| self.adjacency_matrix_backward = [] | |||||
| self.weight = [] | |||||
| self.weight_backward = [] | |||||
| for fam in self.data.relation_families: | |||||
| adj_mat = [ rel.adjacency_matrix \ | |||||
| for rel in fam.relation_types \ | |||||
| if rel.adjacency_matrix is not None ] | |||||
| adj_mat_back = [ rel.adjacency_matrix_backward \ | |||||
| for rel in fam.relation_types \ | |||||
| if rel.adjacency_matrix_backward is not None ] | |||||
| weight = [ init_glorot(self.input_dim[fam.node_type_column], | |||||
| self.output_dim[fam.node_type_row]) \ | |||||
| for _ in range(len(adj_mat)) ] | |||||
| weight_back = [ init_glorot(self.input_dim[fam.node_type_column], | |||||
| self.output_dim[fam.node_type_row]) \ | |||||
| for _ in range(len(adj_mat_back)) ] | |||||
| adj_mat = torch.cat(adj_mat) \ | |||||
| if len(adj_mat) > 0 \ | |||||
| else None | |||||
| adj_mat_back = torch.cat(adj_mat_back) \ | |||||
| if len(adj_mat_back) > 0 \ | |||||
| else None | |||||
| self.adjacency_matrix.append(adj_mat) | |||||
| self.adjacency_matrix_backward.append(adj_mat_back) | |||||
| self.weight.append(weight) | |||||
| self.weight_back.append(weight_back) | |||||
| def forward(self, prev_layer_repr): | |||||
| for i, fam in enumerate(self.data.relation_families): | |||||
| repr_row = prev_layer_repr[fam.node_type_row] | |||||
| repr_column = prev_layer_repr[fam.node_type_column] | |||||
| adj_mat = self.adjacency_matrix[i] | |||||
| adj_mat_back = self.adjacency_matrix_backward[i] | |||||
| if adj_mat is not None: | |||||
| x = dropout(repr_column, keep_prob=self.keep_prob) | |||||
| x = torch.sparse.mm(x, self.weight[i]) \ | |||||
| if x.is_sparse \ | |||||
| else torch.mm(x, self.weight[i]) | |||||
| x = torch.sparse.mm(adj_mat, repr_row) \ | |||||
| if adj_mat.is_sparse \ | |||||
| else torch.mm(adj_mat, repr_row) | |||||
| x = self.rel_activation(x) | |||||
| x = x.view(len(fam.relation_types), len(repr_row), -1) | |||||
| if adj_mat_back is not None: | |||||
| x = torch.sparse.mm(adj_mat_back, repr_row) \ | |||||
| if adj_mat_back.is_sparse \ | |||||
| else torch.mm(adj_mat_back, repr_row) | |||||
| @staticmethod | |||||
| def _check_params(input_dim, output_dim, data, keep_prob, | |||||
| rel_activation, layer_activation): | |||||
| if not isinstance(input_dim, list): | |||||
| raise ValueError('input_dim must be a list') | |||||
| if not output_dim: | |||||
| raise ValueError('output_dim must be specified') | |||||
| if not isinstance(output_dim, list): | |||||
| output_dim = [output_dim] * len(data.node_types) | |||||
| if not isinstance(data, Data) and not isinstance(data, PreparedData): | |||||
| raise ValueError('data must be of type Data or PreparedData') | |||||