IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

132 lines
3.8KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import numpy as np
  6. import scipy.sparse as sp
  7. import torch
  8. def _check_tensor(adj_mat):
  9. if not isinstance(adj_mat, torch.Tensor):
  10. raise ValueError('adj_mat must be a torch.Tensor')
  11. def _check_sparse(adj_mat):
  12. if not adj_mat.is_sparse:
  13. raise ValueError('adj_mat must be sparse')
  14. def _check_dense(adj_mat):
  15. if adj_mat.is_sparse:
  16. raise ValueError('adj_mat must be dense')
  17. def _check_square(adj_mat):
  18. if len(adj_mat.shape) != 2 or \
  19. adj_mat.shape[0] != adj_mat.shape[1]:
  20. raise ValueError('adj_mat must be a square matrix')
  21. def _check_2d(adj_mat):
  22. if len(adj_mat.shape) != 2:
  23. raise ValueError('adj_mat must be a square matrix')
  24. def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  25. _check_tensor(adj_mat)
  26. _check_sparse(adj_mat)
  27. _check_square(adj_mat)
  28. adj_mat = adj_mat.coalesce()
  29. indices = adj_mat.indices()
  30. values = adj_mat.values()
  31. eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype).view(1, -1)
  32. eye_indices = torch.cat((eye_indices, eye_indices), 0)
  33. eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype)
  34. indices = torch.cat((indices, eye_indices), 1)
  35. values = torch.cat((values, eye_values), 0)
  36. adj_mat = torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape)
  37. return adj_mat
  38. def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  39. _check_tensor(adj_mat)
  40. _check_sparse(adj_mat)
  41. _check_square(adj_mat)
  42. adj_mat = add_eye_sparse(adj_mat)
  43. adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat)
  44. return adj_mat
  45. def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
  46. _check_tensor(adj_mat)
  47. _check_dense(adj_mat)
  48. _check_square(adj_mat)
  49. adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype)
  50. adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)
  51. return adj_mat
  52. def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
  53. _check_tensor(adj_mat)
  54. _check_square(adj_mat)
  55. if adj_mat.is_sparse:
  56. return norm_adj_mat_one_node_type_sparse(adj_mat)
  57. else:
  58. return norm_adj_mat_one_node_type_dense(adj_mat)
  59. def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  60. _check_tensor(adj_mat)
  61. _check_sparse(adj_mat)
  62. _check_2d(adj_mat)
  63. adj_mat = adj_mat.coalesce()
  64. indices = adj_mat.indices()
  65. values = adj_mat.values()
  66. degrees_row = torch.zeros(adj_mat.shape[0])
  67. degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
  68. degrees_col = torch.zeros(adj_mat.shape[1])
  69. degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
  70. values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
  71. adj_mat = torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape)
  72. return adj_mat
  73. def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
  74. _check_tensor(adj_mat)
  75. _check_dense(adj_mat)
  76. _check_2d(adj_mat)
  77. degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
  78. degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)
  79. degrees_row = torch.sqrt(degrees_row)
  80. degrees_col = torch.sqrt(degrees_col)
  81. adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row
  82. adj_mat = adj_mat / degrees_col
  83. return adj_mat
  84. def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
  85. _check_tensor(adj_mat)
  86. _check_2d(adj_mat)
  87. if adj_mat.is_sparse:
  88. return norm_adj_mat_two_node_types_sparse(adj_mat)
  89. else:
  90. return norm_adj_mat_two_node_types_dense(adj_mat)