|
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
- import numpy as np
- import scipy.sparse as sp
- import torch
-
-
- def _check_tensor(adj_mat):
- if not isinstance(adj_mat, torch.Tensor):
- raise ValueError('adj_mat must be a torch.Tensor')
-
-
- def _check_sparse(adj_mat):
- if not adj_mat.is_sparse:
- raise ValueError('adj_mat must be sparse')
-
-
- def _check_dense(adj_mat):
- if adj_mat.is_sparse:
- raise ValueError('adj_mat must be dense')
-
-
- def _check_square(adj_mat):
- if len(adj_mat.shape) != 2 or \
- adj_mat.shape[0] != adj_mat.shape[1]:
- raise ValueError('adj_mat must be a square matrix')
-
-
- def _check_2d(adj_mat):
- if len(adj_mat.shape) != 2:
- raise ValueError('adj_mat must be a square matrix')
-
-
- def _sparse_coo_tensor(indices, values, size):
- ctor = { torch.float32: torch.sparse.FloatTensor,
- torch.float32: torch.sparse.DoubleTensor,
- torch.uint8: torch.sparse.ByteTensor,
- torch.long: torch.sparse.LongTensor,
- torch.int: torch.sparse.IntTensor,
- torch.short: torch.sparse.ShortTensor,
- torch.bool: torch.sparse.ByteTensor }[values.dtype]
- return ctor(indices, values, size)
-
-
def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    """Return ``adj_mat + I`` for a sparse square matrix.

    One entry of value 1 is appended at every diagonal position ``(i, i)``.
    The result is built with ``_sparse_coo_tensor`` and is NOT coalesced.
    Where ``adj_mat`` already has a diagonal entry, the duplicate coordinates
    are summed once the result is coalesced.

    Args:
        adj_mat: sparse, square ``torch.Tensor``.

    Returns:
        Sparse tensor with the same shape as ``adj_mat``. Its values keep
        ``adj_mat``'s dtype.

    Raises:
        ValueError: if ``adj_mat`` is not a tensor, not sparse, or not square.
    """
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_square(adj_mat)

    coalesced = adj_mat.coalesce()
    old_indices = coalesced.indices()
    old_values = coalesced.values()
    n = coalesced.shape[0]
    device = coalesced.device

    # Diagonal coordinates (i, i) as a (2, n) index tensor.
    diag = torch.arange(n, dtype=old_indices.dtype, device=device)
    diag_indices = torch.stack((diag, diag))
    diag_values = torch.ones(n, dtype=old_values.dtype, device=device)

    return _sparse_coo_tensor(
        torch.cat((old_indices, diag_indices), 1),
        torch.cat((old_values, diag_values), 0),
        coalesced.shape,
    )
-
-
def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    """Add self-loops to a sparse square adjacency matrix, then normalize it.

    Equivalent to ``norm_adj_mat_two_node_types_sparse(add_eye_sparse(adj_mat))``,
    i.e. each entry of ``A + I`` is divided by the square root of the product
    of its row sum and column sum.

    Raises:
        ValueError: if ``adj_mat`` is not a tensor, not sparse, or not square.
    """
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_square(adj_mat)

    with_self_loops = add_eye_sparse(adj_mat)
    return norm_adj_mat_two_node_types_sparse(with_self_loops)
-
-
def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
    """Add self-loops to a dense square adjacency matrix, then normalize it.

    An identity matrix (same dtype and device as ``adj_mat``) is added. The
    sum is passed to ``norm_adj_mat_two_node_types_dense``.

    Raises:
        ValueError: if ``adj_mat`` is not a tensor, not dense, or not square.
    """
    _check_tensor(adj_mat)
    _check_dense(adj_mat)
    _check_square(adj_mat)

    size = adj_mat.shape[0]
    identity = torch.eye(size, dtype=adj_mat.dtype, device=adj_mat.device)
    return norm_adj_mat_two_node_types_dense(adj_mat + identity)
-
-
def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
    """Add self-loops to a square adjacency matrix and normalize it.

    Dispatches to the sparse or dense implementation based on
    ``adj_mat.is_sparse``.

    Raises:
        ValueError: if ``adj_mat`` is not a tensor or not square.
    """
    _check_tensor(adj_mat)
    _check_square(adj_mat)

    normalize = (norm_adj_mat_one_node_type_sparse if adj_mat.is_sparse
                 else norm_adj_mat_one_node_type_dense)
    return normalize(adj_mat)
-
-
def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    """Normalize a sparse 2D adjacency matrix by its row and column sums.

    Each stored entry ``a_ij`` becomes ``a_ij / sqrt(r_i * c_j)``, where
    ``r_i`` is the sum of stored values in row ``i`` and ``c_j`` is the sum
    of stored values in column ``j``. The matrix need not be square.

    Sums are accumulated in torch's default floating dtype, because
    ``torch.zeros`` is called without ``dtype``. The output values use that
    same dtype.

    A zero row or column sum is replaced by 1 before dividing. Otherwise a
    stored entry there (e.g. an explicit zero) would produce 0/0 = NaN. With
    non-negative weights (assumed — TODO confirm for signed graphs), every
    entry in a zero-sum row or column is 0, so such entries normalize to 0.

    Args:
        adj_mat: sparse 2D ``torch.Tensor``.

    Returns:
        Sparse tensor of the same shape with the normalized values.

    Raises:
        ValueError: if ``adj_mat`` is not a tensor, not sparse, or not 2D.
    """
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_2d(adj_mat)

    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    row_idx, col_idx = indices[0], indices[1]

    degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
    degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
    # Cast once; both degree vectors share the default floating dtype.
    values = adj_mat.values().to(degrees_row.dtype)
    degrees_row = degrees_row.index_add(0, row_idx, values)
    degrees_col = degrees_col.index_add(0, col_idx, values)

    # Guard 0/0 -> NaN for entries in rows/columns whose sum is zero.
    degrees_row = degrees_row.masked_fill(degrees_row == 0, 1)
    degrees_col = degrees_col.masked_fill(degrees_col == 0, 1)

    values = values / torch.sqrt(degrees_row[row_idx] * degrees_col[col_idx])
    return _sparse_coo_tensor(indices, values, adj_mat.shape)
-
-
def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
    """Normalize a dense 2D adjacency matrix by its row and column sums.

    Each entry ``a_ij`` becomes ``a_ij / sqrt(r_i) / sqrt(c_j)``, where
    ``r_i`` is the sum of row ``i`` and ``c_j`` is the sum of column ``j``.
    The matrix need not be square. The result is always ``torch.float32``.

    A zero row or column sum (e.g. an isolated node) is replaced by 1 before
    dividing. Previously every entry of such a row or column became
    0/0 = NaN. With non-negative weights (assumed — TODO confirm for signed
    graphs), those entries are all 0 and now stay 0. This matches how the
    sparse variant leaves empty rows untouched.

    Args:
        adj_mat: dense 2D ``torch.Tensor``.

    Returns:
        Dense ``float32`` tensor with the same shape as ``adj_mat``.

    Raises:
        ValueError: if ``adj_mat`` is not a tensor, not dense, or not 2D.
    """
    _check_tensor(adj_mat)
    _check_dense(adj_mat)
    _check_2d(adj_mat)

    # Shape (n_rows, 1) and (1, n_cols) so the divisions below broadcast.
    degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
    degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)

    # Avoid 0/0 = NaN for rows/columns whose sum is zero.
    degrees_row = degrees_row.masked_fill(degrees_row == 0, 1)
    degrees_col = degrees_col.masked_fill(degrees_col == 0, 1)

    adj_mat = adj_mat.to(torch.float32) / torch.sqrt(degrees_row)
    adj_mat = adj_mat / torch.sqrt(degrees_col)

    return adj_mat
-
-
def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
    """Normalize a 2D adjacency matrix by its row and column sums.

    Dispatches to the sparse or dense implementation based on
    ``adj_mat.is_sparse``. No self-loops are added.

    Raises:
        ValueError: if ``adj_mat`` is not a tensor or not 2D.
    """
    _check_tensor(adj_mat)
    _check_2d(adj_mat)

    normalize = (norm_adj_mat_two_node_types_sparse if adj_mat.is_sparse
                 else norm_adj_mat_two_node_types_dense)
    return normalize(adj_mat)
|