If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

191 lines
5.9KB

  1. import torch
  2. from typing import List, \
  3. Set
  4. import time
  5. def _equal(x: torch.Tensor, y: torch.Tensor):
  6. if x.is_sparse ^ y.is_sparse:
  7. raise ValueError('Cannot mix sparse and dense tensors')
  8. if not x.is_sparse:
  9. return (x == y)
  10. return ((x - y).coalesce().values() == 0)
  11. def _sparse_coo_tensor(indices, values, size):
  12. ctor = { torch.float32: torch.sparse.FloatTensor,
  13. torch.float32: torch.sparse.DoubleTensor,
  14. torch.uint8: torch.sparse.ByteTensor,
  15. torch.long: torch.sparse.LongTensor,
  16. torch.int: torch.sparse.IntTensor,
  17. torch.short: torch.sparse.ShortTensor,
  18. torch.bool: torch.sparse.ByteTensor }[values.dtype]
  19. return ctor(indices, values, size)
  20. def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
  21. if len(adjacency_matrices) == 0:
  22. raise ValueError('adjacency_matrices must be non-empty')
  23. if not all([x.is_sparse for x in adjacency_matrices]):
  24. raise ValueError('All adjacency matrices must be sparse')
  25. indices = [ x.indices() for x in adjacency_matrices ]
  26. indices = torch.cat(indices, dim=1)
  27. values = torch.ones(indices.shape[1])
  28. res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
  29. res = res.coalesce()
  30. indices = res.indices()
  31. res = _sparse_coo_tensor(indices,
  32. torch.ones(indices.shape[1], dtype=torch.uint8))
  33. return res
  34. def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
  35. rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:
  36. if not adjacency_matrix.is_sparse:
  37. raise ValueError('adjacency_matrix must be sparse')
  38. if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
  39. raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')
  40. t = time.time()
  41. rows = [ rows + row_vertex_count * i \
  42. for i in range(num_relation_types) ]
  43. print('rows took:', time.time() - t)
  44. t = time.time()
  45. rows = torch.cat(rows)
  46. print('cat took:', time.time() - t)
  47. # print('rows:', rows)
  48. # rows = set(rows.tolist())
  49. # print('rows:', rows)
  50. t = time.time()
  51. adj_mat = adjacency_matrix.coalesce()
  52. indices = adj_mat.indices()
  53. values = adj_mat.values()
  54. print('indices[0]:', indices[0])
  55. # print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
  56. lookup = torch.zeros(row_vertex_count * num_relation_types,
  57. dtype=torch.uint8, device=adj_mat.device)
  58. lookup[rows] = 1
  59. values = values * lookup[indices[0]]
  60. mask = torch.nonzero(values > 0, as_tuple=True)[0]
  61. indices = indices[:, mask]
  62. values = values[mask]
  63. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  64. # res = res.coalesce()
  65. print('res:', res)
  66. print('"index_select()" took:', time.time() - t)
  67. return res
  68. selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
  69. # print('selection:', selection)
  70. selection = torch.nonzero(selection, as_tuple=True)[0]
  71. # print('selection:', selection)
  72. indices = indices[:, selection]
  73. values = values[selection]
  74. print('"index_select()" took:', time.time() - t)
  75. t = time.time()
  76. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  77. print('_sparse_coo_tensor() took:', time.time() - t)
  78. return res
  79. # t = time.time()
  80. # adj_mat = torch.index_select(adjacency_matrix, 0, rows)
  81. # print('index_select took:', time.time() - t)
  82. t = time.time()
  83. adj_mat = adj_mat.coalesce()
  84. print('coalesce() took:', time.time() - t)
  85. indices = adj_mat.indices()
  86. # print('indices:', indices)
  87. values = adj_mat.values()
  88. t = time.time()
  89. indices[0] = rows[indices[0]]
  90. print('Lookup took:', time.time() - t)
  91. t = time.time()
  92. adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  93. print('_sparse_coo_tensor() took:', time.time() - t)
  94. return adj_mat
  95. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  96. if len(matrices) == 0:
  97. raise ValueError('The list of matrices must be non-empty')
  98. if not all(m.is_sparse for m in matrices):
  99. raise ValueError('All matrices must be sparse')
  100. if not all(len(m.shape) == 2 for m in matrices):
  101. raise ValueError('All matrices must be 2D')
  102. indices = []
  103. values = []
  104. row_offset = 0
  105. col_offset = 0
  106. for m in matrices:
  107. ind = m._indices().clone()
  108. ind[0] += row_offset
  109. ind[1] += col_offset
  110. indices.append(ind)
  111. values.append(m._values())
  112. row_offset += m.shape[0]
  113. col_offset += m.shape[1]
  114. indices = torch.cat(indices, dim=1)
  115. values = torch.cat(values)
  116. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  117. def _cat(matrices: List[torch.Tensor]):
  118. if len(matrices) == 0:
  119. raise ValueError('Empty list passed to _cat()')
  120. n = sum(a.is_sparse for a in matrices)
  121. if n != 0 and n != len(matrices):
  122. raise ValueError('All matrices must have the same layout (dense or sparse)')
  123. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  124. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  125. if not matrices[0].is_sparse:
  126. return torch.cat(matrices)
  127. total_rows = sum(a.shape[0] for a in matrices)
  128. indices = []
  129. values = []
  130. row_offset = 0
  131. for a in matrices:
  132. ind = a._indices().clone()
  133. val = a._values()
  134. ind[0] += row_offset
  135. ind = ind.transpose(0, 1)
  136. indices.append(ind)
  137. values.append(val)
  138. row_offset += a.shape[0]
  139. indices = torch.cat(indices).transpose(0, 1)
  140. values = torch.cat(values)
  141. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  142. return res