| @@ -1,11 +1,14 @@ | |||
| from typing import List, \ | |||
| Union, \ | |||
| Callable | |||
| from .data import Data | |||
| from .trainprep import PreparedData | |||
| from .data import Data, \ | |||
| RelationFamily | |||
| from .trainprep import PreparedData, \ | |||
| PreparedRelationFamily | |||
| import torch | |||
| from .weights import init_glorot | |||
| from .normalize import _sparse_coo_tensor | |||
| import types | |||
| def _sparse_diag_cat(matrices: List[torch.Tensor]): | |||
| @@ -75,28 +78,80 @@ def _cat(matrices: List[torch.Tensor]): | |||
| class FastGraphConv(torch.nn.Module): | |||
| def __init__(self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| adjacency_matrix: List[torch.Tensor], | |||
| **kwargs): | |||
| in_channels: List[int], | |||
| out_channels: List[int], | |||
| data: Union[Data, PreparedData], | |||
| relation_family: Union[RelationFamily, PreparedRelationFamily] | |||
| keep_prob: float = 1., | |||
| acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||
| **kwargs) -> None: | |||
| in_channels = int(in_channels) | |||
| out_channels = int(out_channels) | |||
| if not isinstance(data, Data) and not isinstance(data, PreparedData): | |||
| raise TypeError('data must be an instance of Data or PreparedData') | |||
| if not isinstance(relation_family, RelationFamily) and \ | |||
| not isinstance(relation_family, PreparedRelationFamily): | |||
| raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily') | |||
| keep_prob = float(keep_prob) | |||
| if not isinstance(activation, types.FunctionType): | |||
| raise TypeError('activation must be a function') | |||
| n_nodes_row = data.node_types[relation_family.node_type_row].count | |||
| n_nodes_column = data.node_types[relation_family.node_type_column].count | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.data = data | |||
| self.relation_family = relation_family | |||
| self.keep_prob = keep_prob | |||
| self.activation = activation | |||
| self.weight = torch.cat([ | |||
| init_glorot(in_channels, out_channels) \ | |||
| for _ in adjacency_matrix | |||
| for _ in range(len(relation_family.relation_types)) | |||
| ], dim=1) | |||
| self.adjacency_matrix = _cat(adjacency_matrix) | |||
| def forward(self, x): | |||
| x = torch.sparse.mm(x, self.weight) \ | |||
| if x.is_sparse \ | |||
| else torch.mm(x, self.weight) | |||
| x = torch.sparse.mm(self.adjacency_matrix, x) \ | |||
| if self.adjacency_matrix.is_sparse \ | |||
| else torch.mm(self.adjacency_matrix, x) | |||
| return x | |||
| self.weight_backward = torch.cat([ | |||
| init_glorot(in_channels, out_channels) \ | |||
| for _ in range(len(relation_family.relation_types)) | |||
| ], dim=1) | |||
| self.adjacency_matrix = _sparse_diag_cat([ | |||
| rel.adjacency_matrix \ | |||
| if rel.adjacency_matrix is not None \ | |||
| else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \ | |||
| for rel in relation_family.relation_types ]) | |||
| self.adjacency_matrix_backward = _sparse_diag_cat([ | |||
| rel.adjacency_matrix_backward \ | |||
| if rel.adjacency_matrix_backward is not None \ | |||
| else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \ | |||
| for rel in relation_family.relation_types ]) | |||
| def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]: | |||
| repr_row = prev_layer_repr[self.relation_family.node_type_row] | |||
| repr_column = prev_layer_repr[self.relation_family.node_type_column] | |||
| new_repr_row = torch.sparse.mm(repr_column, self.weight) \ | |||
| if repr_column.is_sparse \ | |||
| else torch.mm(repr_column, self.weight) | |||
| new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \ | |||
| if self.adjacency_matrix.is_sparse \ | |||
| else torch.mm(self.adjacency_matrix, new_repr_row) | |||
| new_repr_row = new_repr_row.view(len(self.relation_family.relation_types), | |||
| len(repr_row), self.out_channels) | |||
| new_repr_column = torch.sparse.mm(repr_row, self.weight) \ | |||
| if repr_row.is_sparse \ | |||
| else torch.mm(repr_row, self.weight) | |||
| new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \ | |||
| if self.adjacency_matrix_backward.is_sparse \ | |||
| else torch.mm(self.adjacency_matrix_backward, new_repr_column) | |||
| new_repr_column = new_repr_column.view(len(self.relation_family.relation_types), | |||
| len(repr_column), self.out_channels) | |||
| return (new_repr_row, new_repr_column) | |||
| class FastConvLayer(torch.nn.Module): | |||