# Origin: src/icosagon/fastconv.py (added as a new file, diff 0000000..31428b2).
from typing import Callable, List, Optional, Union

import torch

from .data import Data
from .trainprep import PreparedData
from .weights import init_glorot


class FastConvLayer(torch.nn.Module):
    """Convolution layer that stacks the adjacency matrices of each relation family.

    For every relation family in ``data`` the layer concatenates the forward
    and the backward adjacency matrices of the family's relation types (one
    concatenated tensor per direction) and creates one Glorot-initialised
    weight matrix per concatenated adjacency matrix.

    NOTE(review): ``forward`` looks like work in progress -- see the notes in
    its body before relying on this layer.
    """

    # One entry per relation family. An adjacency entry is None when no
    # relation type of the family provides a matrix for that direction.
    adjacency_matrix: List[Optional[torch.Tensor]]
    adjacency_matrix_backward: List[Optional[torch.Tensor]]
    # One list of weight matrices per relation family.
    weight: List[List[torch.Tensor]]
    weight_backward: List[List[torch.Tensor]]

    def __init__(self,
                 input_dim: List[int],
                 output_dim: Union[int, List[int]],
                 data: Union[Data, PreparedData],
                 keep_prob: float = 1.,
                 rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
                 layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs):
        """Validate the arguments, store them and build the layer.

        Args:
            input_dim: Input dimensionality, indexed by node type.
            output_dim: Output dimensionality, indexed by node type. A
                non-list value is repeated once per entry of
                ``data.node_types``.
            data: Graph description whose ``relation_families`` are used to
                build the adjacency matrices and the weights.
            keep_prob: Passed as ``keep_prob`` to ``dropout`` in ``forward``.
            rel_activation: Applied in ``forward`` to the per-family product.
            layer_activation: Stored on the instance; not used by the code
                in this file.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.

        Raises:
            ValueError: If ``input_dim`` is not a list, ``output_dim`` is
                falsy, or ``data`` is neither ``Data`` nor ``PreparedData``.
        """
        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
                           rel_activation, layer_activation)

        # Broadcast here rather than in _check_params: rebinding the argument
        # inside the static validator never reached the instance, which left
        # self.output_dim a scalar that build() cannot index.
        if not isinstance(output_dim, list):
            output_dim = [output_dim] * len(data.node_types)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.adjacency_matrix = None
        self.adjacency_matrix_backward = None
        self.weight = None
        self.weight_backward = None
        self.build()

    def build(self) -> None:
        """(Re)create the per-family adjacency matrices and weights.

        For each relation family, the non-None ``adjacency_matrix`` /
        ``adjacency_matrix_backward`` tensors of its relation types are
        concatenated with ``torch.cat`` (``None`` is stored when there are
        none), and one ``init_glorot`` weight is created per collected
        matrix, sized ``input_dim[node_type_column]`` by
        ``output_dim[node_type_row]``.
        """
        self.adjacency_matrix = []
        self.adjacency_matrix_backward = []
        self.weight = []
        self.weight_backward = []

        for fam in self.data.relation_families:
            in_dim = self.input_dim[fam.node_type_column]
            out_dim = self.output_dim[fam.node_type_row]

            adj_mat = [rel.adjacency_matrix
                       for rel in fam.relation_types
                       if rel.adjacency_matrix is not None]
            adj_mat_back = [rel.adjacency_matrix_backward
                            for rel in fam.relation_types
                            if rel.adjacency_matrix_backward is not None]

            # NOTE(review): the backward weights use the same
            # (column -> row) dimensions as the forward ones -- confirm this
            # is intended for the reverse direction.
            weight = [init_glorot(in_dim, out_dim) for _ in adj_mat]
            weight_back = [init_glorot(in_dim, out_dim) for _ in adj_mat_back]

            self.adjacency_matrix.append(
                torch.cat(adj_mat) if adj_mat else None)
            self.adjacency_matrix_backward.append(
                torch.cat(adj_mat_back) if adj_mat_back else None)
            self.weight.append(weight)
            # Was `self.weight_back`, an attribute that does not exist.
            self.weight_backward.append(weight_back)

    def forward(self, prev_layer_repr):
        """Iterate over the relation families of ``data``.

        ``prev_layer_repr`` is indexed by node type to obtain the row and
        column representations of each family.

        NOTE(review): this method appears unfinished and its statements are
        kept as found. It returns nothing, and the inline notes below list
        what has to be resolved before it can run.
        """
        for i, fam in enumerate(self.data.relation_families):
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]

            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]

            if adj_mat is not None:
                # NOTE(review): `dropout` is not imported anywhere in this
                # file as shown -- presumably a project helper; confirm and
                # import it.
                x = dropout(repr_column, keep_prob=self.keep_prob)
                # NOTE(review): self.weight[i] is a list of tensors (see
                # build()), not a single tensor -- the weights probably need
                # to be concatenated first.
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                # NOTE(review): this overwrites the product computed above
                # and multiplies by repr_row instead of using it.
                x = torch.sparse.mm(adj_mat, repr_row) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, repr_row)
                x = self.rel_activation(x)
                x = x.view(len(fam.relation_types), len(repr_row), -1)

            if adj_mat_back is not None:
                x = torch.sparse.mm(adj_mat_back, repr_row) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, repr_row)

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
                      rel_activation, layer_activation):
        """Validate constructor arguments.

        ``keep_prob``, ``rel_activation`` and ``layer_activation`` are
        accepted but not checked.

        Raises:
            ValueError: If ``input_dim`` is not a list, ``output_dim`` is
                falsy, or ``data`` is neither ``Data`` nor ``PreparedData``.
        """
        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(data, (Data, PreparedData)):
            raise ValueError('data must be of type Data or PreparedData')