IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

normalize.py 4.3KB

4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
4 jaren geleden
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import numpy as np
  6. import scipy.sparse as sp
  7. import torch
  8. def _check_tensor(adj_mat):
  9. if not isinstance(adj_mat, torch.Tensor):
  10. raise ValueError('adj_mat must be a torch.Tensor')
  11. def _check_sparse(adj_mat):
  12. if not adj_mat.is_sparse:
  13. raise ValueError('adj_mat must be sparse')
  14. def _check_dense(adj_mat):
  15. if adj_mat.is_sparse:
  16. raise ValueError('adj_mat must be dense')
  17. def _check_square(adj_mat):
  18. if len(adj_mat.shape) != 2 or \
  19. adj_mat.shape[0] != adj_mat.shape[1]:
  20. raise ValueError('adj_mat must be a square matrix')
  21. def _check_2d(adj_mat):
  22. if len(adj_mat.shape) != 2:
  23. raise ValueError('adj_mat must be a square matrix')
  24. def _sparse_coo_tensor(indices, values, size):
  25. ctor = { torch.float32: torch.sparse.FloatTensor,
  26. torch.float32: torch.sparse.DoubleTensor,
  27. torch.uint8: torch.sparse.ByteTensor,
  28. torch.long: torch.sparse.LongTensor,
  29. torch.int: torch.sparse.IntTensor,
  30. torch.short: torch.sparse.ShortTensor,
  31. torch.bool: torch.sparse.ByteTensor }[values.dtype]
  32. return ctor(indices, values, size)
  33. def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  34. _check_tensor(adj_mat)
  35. _check_sparse(adj_mat)
  36. _check_square(adj_mat)
  37. adj_mat = adj_mat.coalesce()
  38. indices = adj_mat.indices()
  39. values = adj_mat.values()
  40. eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
  41. device=adj_mat.device).view(1, -1)
  42. eye_indices = torch.cat((eye_indices, eye_indices), 0)
  43. eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
  44. device=adj_mat.device)
  45. indices = torch.cat((indices, eye_indices), 1)
  46. values = torch.cat((values, eye_values), 0)
  47. adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
  48. return adj_mat
  49. def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  50. _check_tensor(adj_mat)
  51. _check_sparse(adj_mat)
  52. _check_square(adj_mat)
  53. adj_mat = add_eye_sparse(adj_mat)
  54. adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat)
  55. return adj_mat
  56. def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
  57. _check_tensor(adj_mat)
  58. _check_dense(adj_mat)
  59. _check_square(adj_mat)
  60. adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
  61. device=adj_mat.device)
  62. adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)
  63. return adj_mat
  64. def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
  65. _check_tensor(adj_mat)
  66. _check_square(adj_mat)
  67. if adj_mat.is_sparse:
  68. return norm_adj_mat_one_node_type_sparse(adj_mat)
  69. else:
  70. return norm_adj_mat_one_node_type_dense(adj_mat)
  71. def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
  72. _check_tensor(adj_mat)
  73. _check_sparse(adj_mat)
  74. _check_2d(adj_mat)
  75. adj_mat = adj_mat.coalesce()
  76. indices = adj_mat.indices()
  77. values = adj_mat.values()
  78. degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
  79. degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
  80. degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
  81. degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
  82. values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
  83. adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
  84. return adj_mat
  85. def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
  86. _check_tensor(adj_mat)
  87. _check_dense(adj_mat)
  88. _check_2d(adj_mat)
  89. degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
  90. degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)
  91. degrees_row = torch.sqrt(degrees_row)
  92. degrees_col = torch.sqrt(degrees_col)
  93. adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row
  94. adj_mat = adj_mat / degrees_col
  95. return adj_mat
  96. def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
  97. _check_tensor(adj_mat)
  98. _check_2d(adj_mat)
  99. if adj_mat.is_sparse:
  100. return norm_adj_mat_two_node_types_sparse(adj_mat)
  101. else:
  102. return norm_adj_mat_two_node_types_dense(adj_mat)