If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

175 lines
5.4KB

  1. import torch
  2. from typing import List, \
  3. Set
  4. import time
  5. def _equal(x: torch.Tensor, y: torch.Tensor):
  6. if x.is_sparse ^ y.is_sparse:
  7. raise ValueError('Cannot mix sparse and dense tensors')
  8. if not x.is_sparse:
  9. return (x == y)
  10. return ((x - y).coalesce().values() == 0)
  11. def _sparse_coo_tensor(indices, values, size):
  12. ctor = { torch.float32: torch.sparse.FloatTensor,
  13. torch.float32: torch.sparse.DoubleTensor,
  14. torch.uint8: torch.sparse.ByteTensor,
  15. torch.long: torch.sparse.LongTensor,
  16. torch.int: torch.sparse.IntTensor,
  17. torch.short: torch.sparse.ShortTensor,
  18. torch.bool: torch.sparse.ByteTensor }[values.dtype]
  19. return ctor(indices, values, size)
  20. def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
  21. if len(adjacency_matrices) == 0:
  22. raise ValueError('adjacency_matrices must be non-empty')
  23. if not all([x.is_sparse for x in adjacency_matrices]):
  24. raise ValueError('All adjacency matrices must be sparse')
  25. indices = [ x.indices() for x in adjacency_matrices ]
  26. indices = torch.cat(indices, dim=1)
  27. values = torch.ones(indices.shape[1])
  28. res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
  29. res = res.coalesce()
  30. indices = res.indices()
  31. res = _sparse_coo_tensor(indices,
  32. torch.ones(indices.shape[1], dtype=torch.uint8))
  33. return res
  34. def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
  35. rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:
  36. if not adjacency_matrix.is_sparse:
  37. raise ValueError('adjacency_matrix must be sparse')
  38. if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
  39. raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')
  40. t = time.time()
  41. rows = [ rows + row_vertex_count * i \
  42. for i in range(num_relation_types) ]
  43. print('rows took:', time.time() - t)
  44. t = time.time()
  45. rows = torch.cat(rows)
  46. print('cat took:', time.time() - t)
  47. # print('rows:', rows)
  48. rows = set(rows.tolist())
  49. # print('rows:', rows)
  50. t = time.time()
  51. adj_mat = adjacency_matrix.coalesce()
  52. indices = adj_mat.indices()
  53. values = adj_mat.values()
  54. print('indices[0]:', indices[0])
  55. print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
  56. selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
  57. # print('selection:', selection)
  58. selection = torch.nonzero(selection, as_tuple=True)[0]
  59. # print('selection:', selection)
  60. indices = indices[:, selection]
  61. values = values[selection]
  62. print('"index_select()" took:', time.time() - t)
  63. t = time.time()
  64. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  65. print('_sparse_coo_tensor() took:', time.time() - t)
  66. return res
  67. # t = time.time()
  68. # adj_mat = torch.index_select(adjacency_matrix, 0, rows)
  69. # print('index_select took:', time.time() - t)
  70. t = time.time()
  71. adj_mat = adj_mat.coalesce()
  72. print('coalesce() took:', time.time() - t)
  73. indices = adj_mat.indices()
  74. # print('indices:', indices)
  75. values = adj_mat.values()
  76. t = time.time()
  77. indices[0] = rows[indices[0]]
  78. print('Lookup took:', time.time() - t)
  79. t = time.time()
  80. adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  81. print('_sparse_coo_tensor() took:', time.time() - t)
  82. return adj_mat
  83. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  84. if len(matrices) == 0:
  85. raise ValueError('The list of matrices must be non-empty')
  86. if not all(m.is_sparse for m in matrices):
  87. raise ValueError('All matrices must be sparse')
  88. if not all(len(m.shape) == 2 for m in matrices):
  89. raise ValueError('All matrices must be 2D')
  90. indices = []
  91. values = []
  92. row_offset = 0
  93. col_offset = 0
  94. for m in matrices:
  95. ind = m._indices().clone()
  96. ind[0] += row_offset
  97. ind[1] += col_offset
  98. indices.append(ind)
  99. values.append(m._values())
  100. row_offset += m.shape[0]
  101. col_offset += m.shape[1]
  102. indices = torch.cat(indices, dim=1)
  103. values = torch.cat(values)
  104. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  105. def _cat(matrices: List[torch.Tensor]):
  106. if len(matrices) == 0:
  107. raise ValueError('Empty list passed to _cat()')
  108. n = sum(a.is_sparse for a in matrices)
  109. if n != 0 and n != len(matrices):
  110. raise ValueError('All matrices must have the same layout (dense or sparse)')
  111. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  112. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  113. if not matrices[0].is_sparse:
  114. return torch.cat(matrices)
  115. total_rows = sum(a.shape[0] for a in matrices)
  116. indices = []
  117. values = []
  118. row_offset = 0
  119. for a in matrices:
  120. ind = a._indices().clone()
  121. val = a._values()
  122. ind[0] += row_offset
  123. ind = ind.transpose(0, 1)
  124. indices.append(ind)
  125. values.append(val)
  126. row_offset += a.shape[0]
  127. indices = torch.cat(indices).transpose(0, 1)
  128. values = torch.cat(values)
  129. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  130. return res