from typing import List, Union, Callable
from .data import Data
from .trainprep import PreparedData
import torch
from .weights import init_glorot
from .dropout import dropout  # forward() calls dropout(); assumed to live in a sibling module, like init_glorot
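
# The forward pass below multiplies a whole relation family with a single
# mm(). For that to keep the relations separate, the per-relation adjacency
# matrices have to be concatenated block-diagonally rather than stacked with
# a plain torch.cat(). A minimal sketch of such a helper (hypothetical name,
# not part of the original file; handles dense and sparse COO inputs):

def _diag_cat(matrices: List[torch.Tensor]) -> torch.Tensor:
    if len(matrices) == 0:
        raise ValueError('matrices must be non-empty')

    if not matrices[0].is_sparse:
        return torch.block_diag(*matrices)

    # sparse COO case: shift the indices of every matrix by the total
    # size of the blocks preceding it
    indices, values = [], []
    row_offset, col_offset = 0, 0
    for m in matrices:
        m = m.coalesce()
        ind = m.indices().clone()
        ind[0] += row_offset
        ind[1] += col_offset
        indices.append(ind)
        values.append(m.values())
        row_offset += m.shape[0]
        col_offset += m.shape[1]

    return torch.sparse_coo_tensor(torch.cat(indices, dim=1),
        torch.cat(values), size=(row_offset, col_offset)).coalesce()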


class FastConvLayer(torch.nn.Module):
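    """Graph convolution layer that processes each relation family in one go.

    Rather than issuing one mm() per relation type, it concatenates the
    adjacency matrices of a family block-diagonally and the corresponding
    weight matrices column-wise, so that a single (sparse) mm() propagates
    all relations of a family at once.
    """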

    adjacency_matrix: List[torch.Tensor]
    adjacency_matrix_backward: List[torch.Tensor]
    weight: List[torch.Tensor]
    weight_backward: List[torch.Tensor]

    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)
        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.adjacency_matrix = None
        self.adjacency_matrix_backward = None
        self.weight = None
        self.weight_backward = None

        self.build()
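
    # build() materializes the fused per-family matrices that forward()
    # consumes: one block-diagonal adjacency matrix and one column-wise
    # concatenation of Glorot-initialized weights per direction.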
    def build(self):
        self.adjacency_matrix = []
        self.adjacency_matrix_backward = []
        self.weight = []
        self.weight_backward = []

        for fam in self.data.relation_families:
            adj_mat = [ rel.adjacency_matrix \
                for rel in fam.relation_types \
                if rel.adjacency_matrix is not None ]
            adj_mat_back = [ rel.adjacency_matrix_backward \
                for rel in fam.relation_types \
                if rel.adjacency_matrix_backward is not None ]

            weight = [ init_glorot(self.input_dim[fam.node_type_column],
                self.output_dim[fam.node_type_row]) \
                for _ in range(len(adj_mat)) ]
            # the backward direction maps row features onto column nodes,
            # hence the dimensions are swapped relative to `weight`
            weight_back = [ init_glorot(self.input_dim[fam.node_type_row],
                self.output_dim[fam.node_type_column]) \
                for _ in range(len(adj_mat_back)) ]

            # concatenate block-diagonally (see _diag_cat() above) so that
            # one mm() in forward() keeps the relations independent
            adj_mat = _diag_cat(adj_mat) \
                if len(adj_mat) > 0 \
                else None
            adj_mat_back = _diag_cat(adj_mat_back) \
                if len(adj_mat_back) > 0 \
                else None

            # concatenate the per-relation weights column-wise; forward()
            # splits the product back into per-relation chunks
            weight = torch.cat(weight, dim=1) \
                if len(weight) > 0 \
                else None
            weight_back = torch.cat(weight_back, dim=1) \
                if len(weight_back) > 0 \
                else None

            self.adjacency_matrix.append(adj_mat)
            self.adjacency_matrix_backward.append(adj_mat_back)
            self.weight.append(weight)
            self.weight_backward.append(weight_back)
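
    # Shape walkthrough for one family with R relation types, n_row x n_col
    # adjacency matrices and d_in/d_out feature dimensions:
    #   repr_column                 (n_col, d_in)
    #   mm with weight              -> (n_col, R * d_out)
    #   split + cat                 -> (R * n_col, d_out)
    #   mm with block-diag adj_mat  -> (R * n_row, d_out)
    #   view                        -> (R, n_row, d_out)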
    def forward(self, prev_layer_repr):
        next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
        for i, fam in enumerate(self.data.relation_families):
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]
            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]
            if adj_mat is not None:
                x = dropout(repr_column, keep_prob=self.keep_prob)
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                # re-stack the per-relation chunks row-wise so that they
                # line up with the block-diagonal adjacency matrix
                x = torch.cat(x.split(x.shape[1] // len(fam.relation_types), dim=1))
                x = torch.sparse.mm(adj_mat, x) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, x)
                x = self.rel_activation(x)
                x = x.view(len(fam.relation_types), len(repr_row), -1)
                next_layer_repr[fam.node_type_row].append(x.sum(dim=0))
            if adj_mat_back is not None:
                # mirror image of the forward direction: propagate row
                # features onto the column nodes
                x = dropout(repr_row, keep_prob=self.keep_prob)
                x = torch.sparse.mm(x, self.weight_backward[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight_backward[i])
                x = torch.cat(x.split(x.shape[1] // len(fam.relation_types), dim=1))
                x = torch.sparse.mm(adj_mat_back, x) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, x)
                x = self.rel_activation(x)
                x = x.view(len(fam.relation_types), len(repr_column), -1)
                next_layer_repr[fam.node_type_column].append(x.sum(dim=0))
        # sum the per-family contributions for every node type and apply
        # the layer activation (aggregation scheme assumed)
        next_layer_repr = [ self.layer_activation(sum(xs)) \
            for xs in next_layer_repr ]
        return next_layer_repr

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            raise ValueError('output_dim must be a list')

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')
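
# Usage sketch (hypothetical wiring; `d` stands for a Data or PreparedData
# instance with its relation families already defined, and the input list
# carries one feature matrix per node type):
#
#     layer = FastConvLayer(input_dim=[32], output_dim=[64], data=d,
#         keep_prob=.9)
#     next_repr = layer([ torch.rand(100, 32) ])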