IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

249 lines
7.4KB

  1. import torch
  2. from typing import List, \
  3. Set
  4. import time
  5. def _equal(x: torch.Tensor, y: torch.Tensor):
  6. if x.is_sparse ^ y.is_sparse:
  7. raise ValueError('Cannot mix sparse and dense tensors')
  8. if not x.is_sparse:
  9. return (x == y)
  10. return ((x - y).coalesce().values() == 0)
  11. def _sparse_coo_tensor(indices, values, size):
  12. ctor = { torch.float32: torch.sparse.FloatTensor,
  13. torch.float32: torch.sparse.DoubleTensor,
  14. torch.uint8: torch.sparse.ByteTensor,
  15. torch.long: torch.sparse.LongTensor,
  16. torch.int: torch.sparse.IntTensor,
  17. torch.short: torch.sparse.ShortTensor,
  18. torch.bool: torch.sparse.ByteTensor }[values.dtype]
  19. return ctor(indices, values, size)
  20. def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
  21. if len(adjacency_matrices) == 0:
  22. raise ValueError('adjacency_matrices must be non-empty')
  23. if not all([x.is_sparse for x in adjacency_matrices]):
  24. raise ValueError('All adjacency matrices must be sparse')
  25. indices = [ x.indices() for x in adjacency_matrices ]
  26. indices = torch.cat(indices, dim=1)
  27. values = torch.ones(indices.shape[1])
  28. res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
  29. res = res.coalesce()
  30. indices = res.indices()
  31. res = _sparse_coo_tensor(indices,
  32. torch.ones(indices.shape[1], dtype=torch.uint8),
  33. adjacency_matrices[0].shape)
  34. res = res.coalesce()
  35. return res
  36. def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
  37. rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:
  38. if not adjacency_matrix.is_sparse:
  39. raise ValueError('adjacency_matrix must be sparse')
  40. if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
  41. raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')
  42. t = time.time()
  43. rows = [ rows + row_vertex_count * i \
  44. for i in range(num_relation_types) ]
  45. print('rows took:', time.time() - t)
  46. t = time.time()
  47. rows = torch.cat(rows)
  48. print('cat took:', time.time() - t)
  49. # print('rows:', rows)
  50. # rows = set(rows.tolist())
  51. # print('rows:', rows)
  52. t = time.time()
  53. adj_mat = adjacency_matrix.coalesce()
  54. indices = adj_mat.indices()
  55. values = adj_mat.values()
  56. print('indices[0]:', indices[0])
  57. # print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
  58. lookup = torch.zeros(row_vertex_count * num_relation_types,
  59. dtype=torch.uint8, device=adj_mat.device)
  60. lookup[rows] = 1
  61. values = values * lookup[indices[0]]
  62. mask = torch.nonzero(values > 0, as_tuple=True)[0]
  63. indices = indices[:, mask]
  64. values = values[mask]
  65. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  66. # res = res.coalesce()
  67. print('res:', res)
  68. print('"index_select()" took:', time.time() - t)
  69. return res
  70. selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
  71. # print('selection:', selection)
  72. selection = torch.nonzero(selection, as_tuple=True)[0]
  73. # print('selection:', selection)
  74. indices = indices[:, selection]
  75. values = values[selection]
  76. print('"index_select()" took:', time.time() - t)
  77. t = time.time()
  78. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  79. print('_sparse_coo_tensor() took:', time.time() - t)
  80. return res
  81. # t = time.time()
  82. # adj_mat = torch.index_select(adjacency_matrix, 0, rows)
  83. # print('index_select took:', time.time() - t)
  84. t = time.time()
  85. adj_mat = adj_mat.coalesce()
  86. print('coalesce() took:', time.time() - t)
  87. indices = adj_mat.indices()
  88. # print('indices:', indices)
  89. values = adj_mat.values()
  90. t = time.time()
  91. indices[0] = rows[indices[0]]
  92. print('Lookup took:', time.time() - t)
  93. t = time.time()
  94. adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  95. print('_sparse_coo_tensor() took:', time.time() - t)
  96. return adj_mat
  97. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  98. if len(matrices) == 0:
  99. raise ValueError('The list of matrices must be non-empty')
  100. if not all(m.is_sparse for m in matrices):
  101. raise ValueError('All matrices must be sparse')
  102. if not all(len(m.shape) == 2 for m in matrices):
  103. raise ValueError('All matrices must be 2D')
  104. indices = []
  105. values = []
  106. row_offset = 0
  107. col_offset = 0
  108. for m in matrices:
  109. ind = m._indices().clone()
  110. ind[0] += row_offset
  111. ind[1] += col_offset
  112. indices.append(ind)
  113. values.append(m._values())
  114. row_offset += m.shape[0]
  115. col_offset += m.shape[1]
  116. indices = torch.cat(indices, dim=1)
  117. values = torch.cat(values)
  118. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  119. def _cat(matrices: List[torch.Tensor]):
  120. if len(matrices) == 0:
  121. raise ValueError('Empty list passed to _cat()')
  122. n = sum(a.is_sparse for a in matrices)
  123. if n != 0 and n != len(matrices):
  124. raise ValueError('All matrices must have the same layout (dense or sparse)')
  125. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  126. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  127. if not matrices[0].is_sparse:
  128. return torch.cat(matrices)
  129. total_rows = sum(a.shape[0] for a in matrices)
  130. indices = []
  131. values = []
  132. row_offset = 0
  133. for a in matrices:
  134. ind = a._indices().clone()
  135. val = a._values()
  136. ind[0] += row_offset
  137. ind = ind.transpose(0, 1)
  138. indices.append(ind)
  139. values.append(val)
  140. row_offset += a.shape[0]
  141. indices = torch.cat(indices).transpose(0, 1)
  142. values = torch.cat(values)
  143. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  144. return res
  145. def _mm(a: torch.Tensor, b: torch.Tensor):
  146. if a.is_sparse:
  147. return torch.sparse.mm(a, b)
  148. else:
  149. return torch.mm(a, b)
  150. def _select_rows(a: torch.Tensor, rows: torch.Tensor):
  151. if not a.is_sparse:
  152. return a[rows]
  153. indices = a.indices()
  154. values = a.values()
  155. mask = torch.zeros(a.shape[0])
  156. mask[rows] = 1
  157. if mask.sum() != len(rows):
  158. raise ValueError('Rows must be unique')
  159. mask = mask[indices[0]]
  160. mask = torch.nonzero(mask, as_tuple=True)[0]
  161. new_rows[rows] = torch.arange(len(rows))
  162. new_rows = new_rows[indices[0]]
  163. indices = indices[:, mask]
  164. indices[0] = new_rows
  165. values = values[mask]
  166. res = _sparse_coo_tensor(indices, values,
  167. size=(len(rows), a.shape[1]))
  168. return res
  169. def common_one_hot_encoding(vertex_type_counts: List[int], device=None) -> \
  170. List[torch.Tensor]:
  171. tot = sum(vertex_type_counts)
  172. # indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0)
  173. # print('indices.shape:', indices.shape)
  174. ofs = 0
  175. res = []
  176. for cnt in vertex_type_counts:
  177. ind = torch.cat([
  178. torch.arange(cnt).view(1, -1),
  179. torch.arange(ofs, ofs+cnt).view(1, -1)
  180. ])
  181. val = torch.ones(cnt)
  182. x = _sparse_coo_tensor(ind, val, size=(cnt, tot))
  183. x = x.to(device)
  184. res.append(x)
  185. ofs += cnt
  186. return res