diff --git a/src/icosagon/fastconv.py b/src/icosagon/fastconv.py index 6f90eb8..81519fb 100644 --- a/src/icosagon/fastconv.py +++ b/src/icosagon/fastconv.py @@ -78,80 +78,61 @@ def _cat(matrices: List[torch.Tensor]): class FastGraphConv(torch.nn.Module): def __init__(self, - in_channels: List[int], - out_channels: List[int], - data: Union[Data, PreparedData], - relation_family: Union[RelationFamily, PreparedRelationFamily], + in_channels: int, + out_channels: int, + adjacency_matrices: List[torch.Tensor], keep_prob: float = 1., - acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, + activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, **kwargs) -> None: + super().__init__(**kwargs) + in_channels = int(in_channels) out_channels = int(out_channels) - if not isinstance(data, Data) and not isinstance(data, PreparedData): - raise TypeError('data must be an instance of Data or PreparedData') - if not isinstance(relation_family, RelationFamily) and \ - not isinstance(relation_family, PreparedRelationFamily): - raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily') + if not isinstance(adjacency_matrices, list): + raise TypeError('adjacency_matrices must be a list') + if len(adjacency_matrices) == 0: + raise ValueError('adjacency_matrices must not be empty') + if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices): + raise TypeError('adjacency_matrices elements must be of class torch.Tensor') + if not all(m.is_sparse for m in adjacency_matrices): + raise ValueError('adjacency_matrices elements must be sparse') keep_prob = float(keep_prob) if not isinstance(activation, types.FunctionType): raise TypeError('activation must be a function') - n_nodes_row = data.node_types[relation_family.node_type_row].count - n_nodes_column = data.node_types[relation_family.node_type_column].count - self.in_channels = in_channels self.out_channels = out_channels - self.data = data - self.relation_family = relation_family + self.adjacency_matrices = adjacency_matrices self.keep_prob = keep_prob self.activation = activation - self.weight = torch.cat([ - init_glorot(in_channels, out_channels) \ - for _ in range(len(relation_family.relation_types)) - ], dim=1) + self.num_row_nodes = len(adjacency_matrices[0]) + self.num_relation_types = len(adjacency_matrices) + + self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices) - self.weight_backward = torch.cat([ + self.weights = torch.cat([ init_glorot(in_channels, out_channels) \ - for _ in range(len(relation_family.relation_types)) + for _ in range(self.num_relation_types) ], dim=1) - self.adjacency_matrix = _sparse_diag_cat([ - rel.adjacency_matrix \ - if rel.adjacency_matrix is not None \ - else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \ - for rel in relation_family.relation_types ]) - - self.adjacency_matrix_backward = _sparse_diag_cat([ - rel.adjacency_matrix_backward \ - if rel.adjacency_matrix_backward is not None \ - else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \ - for rel in relation_family.relation_types ]) - - def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]: - repr_row = prev_layer_repr[self.relation_family.node_type_row] - repr_column = prev_layer_repr[self.relation_family.node_type_column] - - new_repr_row = torch.sparse.mm(repr_column, self.weight) \ - if repr_column.is_sparse \ - else torch.mm(repr_column, self.weight) - new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \ - if self.adjacency_matrix.is_sparse \ - else torch.mm(self.adjacency_matrix, new_repr_row) - new_repr_row = new_repr_row.view(len(self.relation_family.relation_types), - len(repr_row), self.out_channels) - - new_repr_column = torch.sparse.mm(repr_row, self.weight) \ - if repr_row.is_sparse \ - else torch.mm(repr_row, self.weight) - new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \ - if self.adjacency_matrix_backward.is_sparse \ - else torch.mm(self.adjacency_matrix_backward, new_repr_column) - new_repr_column = new_repr_column.view(len(self.relation_family.relation_types), - len(repr_column), self.out_channels) - - return (new_repr_row, new_repr_column) + def forward(self, x) -> torch.Tensor: + if self.keep_prob < 1.: + x = dropout(x, self.keep_prob) + res = torch.sparse.mm(x, self.weights) \ + if x.is_sparse \ + else torch.mm(x, self.weights) + res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1) + res = torch.cat(res) + res = torch.sparse.mm(self.adjacency_matrices, res) \ + if self.adjacency_matrices.is_sparse \ + else torch.mm(self.adjacency_matrices, res) + res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels) + if self.activation is not None: + res = self.activation(res) + + return res class FastConvLayer(torch.nn.Module): diff --git a/tests/icosagon/test_fastconv.py b/tests/icosagon/test_fastconv.py index 799da5d..407248a 100644 --- a/tests/icosagon/test_fastconv.py +++ b/tests/icosagon/test_fastconv.py @@ -1,7 +1,10 @@ from icosagon.fastconv import _sparse_diag_cat, \ - _cat + _cat, \ + FastGraphConv from icosagon.data import _equal import torch +import pdb +import time def test_sparse_diag_cat_01(): @@ -58,3 +61,28 @@ def test_cat_02(): assert res.shape == (35, 10) assert res.is_sparse assert torch.all(res.to_dense() == ground_truth) + + +def test_fast_graph_conv_01(): + # pdb.set_trace() + adj_mats = [ torch.rand(10, 15).round().to_sparse() \ + for _ in range(23) ] + fgc = FastGraphConv(32, 64, adj_mats) + in_repr = torch.rand(15, 32) + _ = fgc(in_repr) + + +def test_fast_graph_conv_02(): + t = time.time() + m = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse() + adj_mats = [ m for _ in range(1300) ] + print('Generating adj_mats took:', time.time() - t) + t = time.time() + fgc = FastGraphConv(32, 64, adj_mats) + print('FGC constructor took:', time.time() - t) + in_repr = torch.rand(2000, 32) + + for _ in range(3): + t = time.time() + _ = fgc(in_repr) + print('FGC forward pass took:', time.time() - t)