From dd41ba9bd004e8f1fa86a02228172a7218e9760f Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Wed, 22 Jul 2020 14:28:17 +0200 Subject: [PATCH] Add matrix-multiply illustration. --- docs/matrix-multiply.svg | 839 +++++++++++++++++++++++++++++++++++++++ src/icosagon/fastconv.py | 61 +++ 2 files changed, 900 insertions(+) create mode 100644 docs/matrix-multiply.svg diff --git a/docs/matrix-multiply.svg b/docs/matrix-multiply.svg new file mode 100644 index 0000000..4259ec3 --- /dev/null +++ b/docs/matrix-multiply.svg @@ -0,0 +1,839 @@ + + + + + + + + + + image/svg+xml + + + + + + + + x + + w1 + + w2 + + w3 + + w4 + + * + = + + x*w1 + + x*w2 + + x*w3 + + x*w4 + + + x*w1 + + x*w2 + + x*w3 + + x*w4 + + * + + + + + + A1 + A2 + A3 + A4 + = + + + + + + + + + + + + + + + + + + A1*x*w1 + A1*x*w2 + A1*x*w3 + A1*x*w4 + A2*x*w1 + A2*x*w2 + A2*x*w3 + A2*x*w4 + A3*x*w1 + A3*x*w2 + A3*x*w3 + A3*x*w4 + A4*x*w1 + A4*x*w2 + A4*x*w3 + A4*x*w4 + + + diff --git a/src/icosagon/fastconv.py b/src/icosagon/fastconv.py index 31428b2..ba27660 100644 --- a/src/icosagon/fastconv.py +++ b/src/icosagon/fastconv.py @@ -7,6 +7,67 @@ import torch from .weights import init_glorot +def _cat(matrices: List[torch.Tensor]): + if len(matrices) == 0: + raise ValueError('Empty list passed to _cat()') + + n = sum(a.is_sparse for a in matrices) + if n != 0 and n != len(matrices): + raise ValueError('All matrices must have the same layout (dense or sparse)') + + if not all(a.shape[1:] == matrices[0].shape[1:]): + raise ValueError('All matrices must have the same dimensions apart from dimension 0') + + if not matrices[0].is_sparse: + return torch.cat(matrices) + + total_rows = sum(a.shape[0] for a in matrices) + indices = [] + values = [] + row_offset = 0 + + for a in matrices: + ind = a._indices().clone() + val = a._values() + ind[0] += row_offset + ind = ind.transpose(0, 1) + indices.append(ind) + values.append(val) + row_offset += a.shape[0] + + indices = torch.cat(indices).transpose(0, 1) + values = torch.cat(values) + + res = _sparse_coo_tensor(indices, values) + return res + + +class FastGraphConv(torch.nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + adjacency_matrix: List[torch.Tensor], + **kwargs): + + self.in_channels = in_channels + self.out_channels = out_channels + self.weight = torch.cat([ + init_glorot(in_channels, out_channels) \ + for _ in adjacency_matrix + ], dim=1) + self.adjacency_matrix = _cat(adjacency_matrix) + + def forward(self, x): + x = torch.sparse.mm(x, self.weight) \ + if x.is_sparse \ + else torch.mm(x, self.weight) + x = torch.sparse.mm(self.adjacency_matrix, x) \ + if self.adjacency_matrix.is_sparse \ + else torch.mm(self.adjacency_matrix, x) + return x + + + class FastConvLayer(torch.nn.Module): adjacency_matrix: List[torch.Tensor] adjacency_matrix_backward: List[torch.Tensor]