IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

145 lines
4.3KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import numpy as np
  6. import scipy.sparse as sp
  7. import torch
  8. def _check_tensor(adj_mat):
  9. if not isinstance(adj_mat, torch.Tensor):
  10. raise ValueError('adj_mat must be a torch.Tensor')
  11. def _check_sparse(adj_mat):
  12. if not adj_mat.is_sparse:
  13. raise ValueError('adj_mat must be sparse')
  14. def _check_dense(adj_mat):
  15. if adj_mat.is_sparse:
  16. raise ValueError('adj_mat must be dense')
  17. def _check_square(adj_mat):
  18. if len(adj_mat.shape) != 2 or \
  19. adj_mat.shape[0] != adj_mat.shape[1]:
  20. raise ValueError('adj_mat must be a square matrix')
  21. def _check_2d(adj_mat):
  22. if len(adj_mat.shape) != 2:
  23. raise ValueError('adj_mat must be a square matrix')
  24. def _sparse_coo_tensor(indices, values, size):
  25. ctor = { torch.float32: torch.sparse.FloatTensor,
  26. torch.float32: torch.sparse.DoubleTensor,
  27. torch.uint8: torch.sparse.ByteTensor,
  28. torch.long: torch.sparse.LongTensor,
  29. torch.int: torch.sparse.IntTensor,
  30. torch.short: torch.sparse.ShortTensor }[values.dtype]
  31. return ctor(indices, values, size)
  32. def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  33. _check_tensor(adj_mat)
  34. _check_sparse(adj_mat)
  35. _check_square(adj_mat)
  36. adj_mat = adj_mat.coalesce()
  37. indices = adj_mat.indices()
  38. values = adj_mat.values()
  39. eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
  40. device=adj_mat.device).view(1, -1)
  41. eye_indices = torch.cat((eye_indices, eye_indices), 0)
  42. eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
  43. device=adj_mat.device)
  44. indices = torch.cat((indices, eye_indices), 1)
  45. values = torch.cat((values, eye_values), 0)
  46. adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
  47. return adj_mat
  48. def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  49. _check_tensor(adj_mat)
  50. _check_sparse(adj_mat)
  51. _check_square(adj_mat)
  52. adj_mat = add_eye_sparse(adj_mat)
  53. adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat)
  54. return adj_mat
  55. def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
  56. _check_tensor(adj_mat)
  57. _check_dense(adj_mat)
  58. _check_square(adj_mat)
  59. adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
  60. device=adj_mat.device)
  61. adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)
  62. return adj_mat
  63. def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
  64. _check_tensor(adj_mat)
  65. _check_square(adj_mat)
  66. if adj_mat.is_sparse:
  67. return norm_adj_mat_one_node_type_sparse(adj_mat)
  68. else:
  69. return norm_adj_mat_one_node_type_dense(adj_mat)
  70. def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  71. _check_tensor(adj_mat)
  72. _check_sparse(adj_mat)
  73. _check_2d(adj_mat)
  74. adj_mat = adj_mat.coalesce()
  75. indices = adj_mat.indices()
  76. values = adj_mat.values()
  77. degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
  78. degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
  79. degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
  80. degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
  81. values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
  82. adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
  83. return adj_mat
  84. def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
  85. _check_tensor(adj_mat)
  86. _check_dense(adj_mat)
  87. _check_2d(adj_mat)
  88. degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
  89. degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)
  90. degrees_row = torch.sqrt(degrees_row)
  91. degrees_col = torch.sqrt(degrees_col)
  92. adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row
  93. adj_mat = adj_mat / degrees_col
  94. return adj_mat
  95. def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
  96. _check_tensor(adj_mat)
  97. _check_2d(adj_mat)
  98. if adj_mat.is_sparse:
  99. return norm_adj_mat_two_node_types_sparse(adj_mat)
  100. else:
  101. return norm_adj_mat_two_node_types_dense(adj_mat)