From e2ae97add31b6d38f1f982d8701c8f52726dba2b Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 24 Jul 2020 13:26:30 +0200 Subject: [PATCH] Work on FastGraphConv. --- src/icosagon/fastconv.py | 87 ++++++++++++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 16 deletions(-) diff --git a/src/icosagon/fastconv.py b/src/icosagon/fastconv.py index 74e0319..02086c4 100644 --- a/src/icosagon/fastconv.py +++ b/src/icosagon/fastconv.py @@ -1,11 +1,14 @@ from typing import List, \ Union, \ Callable -from .data import Data -from .trainprep import PreparedData +from .data import Data, \ + RelationFamily +from .trainprep import PreparedData, \ + PreparedRelationFamily import torch from .weights import init_glorot from .normalize import _sparse_coo_tensor +import types def _sparse_diag_cat(matrices: List[torch.Tensor]): @@ -75,28 +78,80 @@ def _cat(matrices: List[torch.Tensor]): class FastGraphConv(torch.nn.Module): def __init__(self, - in_channels: int, - out_channels: int, - adjacency_matrix: List[torch.Tensor], - **kwargs): + in_channels: List[int], + out_channels: List[int], + data: Union[Data, PreparedData], + relation_family: Union[RelationFamily, PreparedRelationFamily] + keep_prob: float = 1., + acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, + **kwargs) -> None: + + in_channels = int(in_channels) + out_channels = int(out_channels) + if not isinstance(data, Data) and not isinstance(data, PreparedData): + raise TypeError('data must be an instance of Data or PreparedData') + if not isinstance(relation_family, RelationFamily) and \ + not isinstance(relation_family, PreparedRelationFamily): + raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily') + keep_prob = float(keep_prob) + if not isinstance(activation, types.FunctionType): + raise TypeError('activation must be a function') + + n_nodes_row = data.node_types[relation_family.node_type_row].count + n_nodes_column = data.node_types[relation_family.node_type_column].count self.in_channels = in_channels self.out_channels = out_channels + self.data = data + self.relation_family = relation_family + self.keep_prob = keep_prob + self.activation = activation + self.weight = torch.cat([ init_glorot(in_channels, out_channels) \ - for _ in adjacency_matrix + for _ in range(len(relation_family.relation_types)) ], dim=1) - self.adjacency_matrix = _cat(adjacency_matrix) - def forward(self, x): - x = torch.sparse.mm(x, self.weight) \ - if x.is_sparse \ - else torch.mm(x, self.weight) - x = torch.sparse.mm(self.adjacency_matrix, x) \ - if self.adjacency_matrix.is_sparse \ - else torch.mm(self.adjacency_matrix, x) - return x + self.weight_backward = torch.cat([ + init_glorot(in_channels, out_channels) \ + for _ in range(len(relation_family.relation_types)) + ], dim=1) + self.adjacency_matrix = _sparse_diag_cat([ + rel.adjacency_matrix \ + if rel.adjacency_matrix is not None \ + else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \ + for rel in relation_family.relation_types ]) + + self.adjacency_matrix_backward = _sparse_diag_cat([ + rel.adjacency_matrix_backward \ + if rel.adjacency_matrix_backward is not None \ + else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \ + for rel in relation_family.relation_types ]) + + def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]: + repr_row = prev_layer_repr[self.relation_family.node_type_row] + repr_column = prev_layer_repr[self.relation_family.node_type_column] + + new_repr_row = torch.sparse.mm(repr_column, self.weight) \ + if repr_column.is_sparse \ + else torch.mm(repr_column, self.weight) + new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \ + if self.adjacency_matrix.is_sparse \ + else torch.mm(self.adjacency_matrix, new_repr_row) + new_repr_row = new_repr_row.view(len(self.relation_family.relation_types), + len(repr_row), self.out_channels) + + new_repr_column = torch.sparse.mm(repr_row, self.weight) \ + if repr_row.is_sparse \ + else torch.mm(repr_row, self.weight) + new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \ + if self.adjacency_matrix_backward.is_sparse \ + else torch.mm(self.adjacency_matrix_backward, new_repr_column) + new_repr_column = new_repr_column.view(len(self.relation_family.relation_types), + len(repr_column), self.out_channels) + + return (new_repr_row, new_repr_column) class FastConvLayer(torch.nn.Module):