| @@ -1,11 +1,14 @@ | |||||
| from typing import List, \ | from typing import List, \ | ||||
| Union, \ | Union, \ | ||||
| Callable | Callable | ||||
| from .data import Data | |||||
| from .trainprep import PreparedData | |||||
| from .data import Data, \ | |||||
| RelationFamily | |||||
| from .trainprep import PreparedData, \ | |||||
| PreparedRelationFamily | |||||
| import torch | import torch | ||||
| from .weights import init_glorot | from .weights import init_glorot | ||||
| from .normalize import _sparse_coo_tensor | from .normalize import _sparse_coo_tensor | ||||
| import types | |||||
| def _sparse_diag_cat(matrices: List[torch.Tensor]): | def _sparse_diag_cat(matrices: List[torch.Tensor]): | ||||
| @@ -75,28 +78,80 @@ def _cat(matrices: List[torch.Tensor]): | |||||
| class FastGraphConv(torch.nn.Module): | class FastGraphConv(torch.nn.Module): | ||||
| def __init__(self, | def __init__(self, | ||||
| in_channels: int, | |||||
| out_channels: int, | |||||
| adjacency_matrix: List[torch.Tensor], | |||||
| **kwargs): | |||||
| in_channels: List[int], | |||||
| out_channels: List[int], | |||||
| data: Union[Data, PreparedData], | |||||
| relation_family: Union[RelationFamily, PreparedRelationFamily] | |||||
| keep_prob: float = 1., | |||||
| acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||||
| **kwargs) -> None: | |||||
| in_channels = int(in_channels) | |||||
| out_channels = int(out_channels) | |||||
| if not isinstance(data, Data) and not isinstance(data, PreparedData): | |||||
| raise TypeError('data must be an instance of Data or PreparedData') | |||||
| if not isinstance(relation_family, RelationFamily) and \ | |||||
| not isinstance(relation_family, PreparedRelationFamily): | |||||
| raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily') | |||||
| keep_prob = float(keep_prob) | |||||
| if not isinstance(activation, types.FunctionType): | |||||
| raise TypeError('activation must be a function') | |||||
| n_nodes_row = data.node_types[relation_family.node_type_row].count | |||||
| n_nodes_column = data.node_types[relation_family.node_type_column].count | |||||
| self.in_channels = in_channels | self.in_channels = in_channels | ||||
| self.out_channels = out_channels | self.out_channels = out_channels | ||||
| self.data = data | |||||
| self.relation_family = relation_family | |||||
| self.keep_prob = keep_prob | |||||
| self.activation = activation | |||||
| self.weight = torch.cat([ | self.weight = torch.cat([ | ||||
| init_glorot(in_channels, out_channels) \ | init_glorot(in_channels, out_channels) \ | ||||
| for _ in adjacency_matrix | |||||
| for _ in range(len(relation_family.relation_types)) | |||||
| ], dim=1) | ], dim=1) | ||||
| self.adjacency_matrix = _cat(adjacency_matrix) | |||||
| def forward(self, x): | |||||
| x = torch.sparse.mm(x, self.weight) \ | |||||
| if x.is_sparse \ | |||||
| else torch.mm(x, self.weight) | |||||
| x = torch.sparse.mm(self.adjacency_matrix, x) \ | |||||
| if self.adjacency_matrix.is_sparse \ | |||||
| else torch.mm(self.adjacency_matrix, x) | |||||
| return x | |||||
| self.weight_backward = torch.cat([ | |||||
| init_glorot(in_channels, out_channels) \ | |||||
| for _ in range(len(relation_family.relation_types)) | |||||
| ], dim=1) | |||||
| self.adjacency_matrix = _sparse_diag_cat([ | |||||
| rel.adjacency_matrix \ | |||||
| if rel.adjacency_matrix is not None \ | |||||
| else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \ | |||||
| for rel in relation_family.relation_types ]) | |||||
| self.adjacency_matrix_backward = _sparse_diag_cat([ | |||||
| rel.adjacency_matrix_backward \ | |||||
| if rel.adjacency_matrix_backward is not None \ | |||||
| else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \ | |||||
| for rel in relation_family.relation_types ]) | |||||
| def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]: | |||||
| repr_row = prev_layer_repr[self.relation_family.node_type_row] | |||||
| repr_column = prev_layer_repr[self.relation_family.node_type_column] | |||||
| new_repr_row = torch.sparse.mm(repr_column, self.weight) \ | |||||
| if repr_column.is_sparse \ | |||||
| else torch.mm(repr_column, self.weight) | |||||
| new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \ | |||||
| if self.adjacency_matrix.is_sparse \ | |||||
| else torch.mm(self.adjacency_matrix, new_repr_row) | |||||
| new_repr_row = new_repr_row.view(len(self.relation_family.relation_types), | |||||
| len(repr_row), self.out_channels) | |||||
| new_repr_column = torch.sparse.mm(repr_row, self.weight) \ | |||||
| if repr_row.is_sparse \ | |||||
| else torch.mm(repr_row, self.weight) | |||||
| new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \ | |||||
| if self.adjacency_matrix_backward.is_sparse \ | |||||
| else torch.mm(self.adjacency_matrix_backward, new_repr_column) | |||||
| new_repr_column = new_repr_column.view(len(self.relation_family.relation_types), | |||||
| len(repr_column), self.out_channels) | |||||
| return (new_repr_row, new_repr_column) | |||||
| class FastConvLayer(torch.nn.Module): | class FastConvLayer(torch.nn.Module): | ||||