IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

135 lines
4.0KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import numpy as np
  6. import scipy.sparse as sp
  7. import torch
  8. def _check_tensor(adj_mat):
  9. if not isinstance(adj_mat, torch.Tensor):
  10. raise ValueError('adj_mat must be a torch.Tensor')
  11. def _check_sparse(adj_mat):
  12. if not adj_mat.is_sparse:
  13. raise ValueError('adj_mat must be sparse')
  14. def _check_dense(adj_mat):
  15. if adj_mat.is_sparse:
  16. raise ValueError('adj_mat must be dense')
  17. def _check_square(adj_mat):
  18. if len(adj_mat.shape) != 2 or \
  19. adj_mat.shape[0] != adj_mat.shape[1]:
  20. raise ValueError('adj_mat must be a square matrix')
  21. def _check_2d(adj_mat):
  22. if len(adj_mat.shape) != 2:
  23. raise ValueError('adj_mat must be a square matrix')
  24. def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  25. _check_tensor(adj_mat)
  26. _check_sparse(adj_mat)
  27. _check_square(adj_mat)
  28. adj_mat = adj_mat.coalesce()
  29. indices = adj_mat.indices()
  30. values = adj_mat.values()
  31. eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
  32. device=adj_mat.device).view(1, -1)
  33. eye_indices = torch.cat((eye_indices, eye_indices), 0)
  34. eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
  35. device=adj_mat.device)
  36. indices = torch.cat((indices, eye_indices), 1)
  37. values = torch.cat((values, eye_values), 0)
  38. adj_mat = torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape)
  39. return adj_mat
  40. def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  41. _check_tensor(adj_mat)
  42. _check_sparse(adj_mat)
  43. _check_square(adj_mat)
  44. adj_mat = add_eye_sparse(adj_mat)
  45. adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat)
  46. return adj_mat
  47. def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
  48. _check_tensor(adj_mat)
  49. _check_dense(adj_mat)
  50. _check_square(adj_mat)
  51. adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
  52. device=adj_mat.device)
  53. adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)
  54. return adj_mat
  55. def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
  56. _check_tensor(adj_mat)
  57. _check_square(adj_mat)
  58. if adj_mat.is_sparse:
  59. return norm_adj_mat_one_node_type_sparse(adj_mat)
  60. else:
  61. return norm_adj_mat_one_node_type_dense(adj_mat)
  62. def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  63. _check_tensor(adj_mat)
  64. _check_sparse(adj_mat)
  65. _check_2d(adj_mat)
  66. adj_mat = adj_mat.coalesce()
  67. indices = adj_mat.indices()
  68. values = adj_mat.values()
  69. degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
  70. degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
  71. degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
  72. degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
  73. values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
  74. adj_mat = torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape)
  75. return adj_mat
  76. def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
  77. _check_tensor(adj_mat)
  78. _check_dense(adj_mat)
  79. _check_2d(adj_mat)
  80. degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
  81. degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)
  82. degrees_row = torch.sqrt(degrees_row)
  83. degrees_col = torch.sqrt(degrees_col)
  84. adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row
  85. adj_mat = adj_mat / degrees_col
  86. return adj_mat
  87. def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
  88. _check_tensor(adj_mat)
  89. _check_2d(adj_mat)
  90. if adj_mat.is_sparse:
  91. return norm_adj_mat_two_node_types_sparse(adj_mat)
  92. else:
  93. return norm_adj_mat_two_node_types_dense(adj_mat)