IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

272 lines
8.2KB

  1. import torch
  2. from typing import List, \
  3. Set
  4. import time
  5. def _diag(x: torch.Tensor, make_sparse: bool=False):
  6. if len(x.shape) < 1 or len(x.shape) > 2:
  7. raise ValueError('Matrix or vector expected')
  8. if not x.is_sparse and not make_sparse:
  9. return torch.diag(x)
  10. if len(x.shape) == 1:
  11. indices = torch.arange(len(x)).view(1, -1)
  12. indices = torch.cat([ indices, indices ])
  13. return _sparse_coo_tensor(indices, x.to_dense(), (len(x),) * 2)
  14. values = x.values()
  15. indices = x.indices()
  16. mask = torch.nonzero(indices[0] == indices[1], as_tuple=True)[0]
  17. indices = torch.flatten(indices[0, mask])
  18. order = torch.argsort(indices)
  19. values = values[mask][order]
  20. res = torch.zeros(min(x.shape[0], x.shape[1]), dtype=values.dtype)
  21. res[indices] = values
  22. return res
  23. def _equal(x: torch.Tensor, y: torch.Tensor):
  24. if x.is_sparse ^ y.is_sparse:
  25. raise ValueError('Cannot mix sparse and dense tensors')
  26. if not x.is_sparse:
  27. return (x == y)
  28. return ((x - y).coalesce().values() == 0)
  29. def _sparse_coo_tensor(indices, values, size):
  30. ctor = { torch.float32: torch.sparse.FloatTensor,
  31. torch.float32: torch.sparse.DoubleTensor,
  32. torch.uint8: torch.sparse.ByteTensor,
  33. torch.long: torch.sparse.LongTensor,
  34. torch.int: torch.sparse.IntTensor,
  35. torch.short: torch.sparse.ShortTensor,
  36. torch.bool: torch.sparse.ByteTensor }[values.dtype]
  37. return ctor(indices, values, size)
  38. def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
  39. if len(adjacency_matrices) == 0:
  40. raise ValueError('adjacency_matrices must be non-empty')
  41. if not all([x.is_sparse for x in adjacency_matrices]):
  42. raise ValueError('All adjacency matrices must be sparse')
  43. indices = [ x.indices() for x in adjacency_matrices ]
  44. indices = torch.cat(indices, dim=1)
  45. values = torch.ones(indices.shape[1])
  46. res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
  47. res = res.coalesce()
  48. indices = res.indices()
  49. res = _sparse_coo_tensor(indices,
  50. torch.ones(indices.shape[1], dtype=torch.uint8),
  51. adjacency_matrices[0].shape)
  52. res = res.coalesce()
  53. return res
  54. def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
  55. rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:
  56. if not adjacency_matrix.is_sparse:
  57. raise ValueError('adjacency_matrix must be sparse')
  58. if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
  59. raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')
  60. t = time.time()
  61. rows = [ rows + row_vertex_count * i \
  62. for i in range(num_relation_types) ]
  63. print('rows took:', time.time() - t)
  64. t = time.time()
  65. rows = torch.cat(rows)
  66. print('cat took:', time.time() - t)
  67. # print('rows:', rows)
  68. # rows = set(rows.tolist())
  69. # print('rows:', rows)
  70. t = time.time()
  71. adj_mat = adjacency_matrix.coalesce()
  72. indices = adj_mat.indices()
  73. values = adj_mat.values()
  74. print('indices[0]:', indices[0])
  75. # print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
  76. lookup = torch.zeros(row_vertex_count * num_relation_types,
  77. dtype=torch.uint8, device=adj_mat.device)
  78. lookup[rows] = 1
  79. values = values * lookup[indices[0]]
  80. mask = torch.nonzero(values > 0, as_tuple=True)[0]
  81. indices = indices[:, mask]
  82. values = values[mask]
  83. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  84. # res = res.coalesce()
  85. print('res:', res)
  86. print('"index_select()" took:', time.time() - t)
  87. return res
  88. selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
  89. # print('selection:', selection)
  90. selection = torch.nonzero(selection, as_tuple=True)[0]
  91. # print('selection:', selection)
  92. indices = indices[:, selection]
  93. values = values[selection]
  94. print('"index_select()" took:', time.time() - t)
  95. t = time.time()
  96. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  97. print('_sparse_coo_tensor() took:', time.time() - t)
  98. return res
  99. # t = time.time()
  100. # adj_mat = torch.index_select(adjacency_matrix, 0, rows)
  101. # print('index_select took:', time.time() - t)
  102. t = time.time()
  103. adj_mat = adj_mat.coalesce()
  104. print('coalesce() took:', time.time() - t)
  105. indices = adj_mat.indices()
  106. # print('indices:', indices)
  107. values = adj_mat.values()
  108. t = time.time()
  109. indices[0] = rows[indices[0]]
  110. print('Lookup took:', time.time() - t)
  111. t = time.time()
  112. adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  113. print('_sparse_coo_tensor() took:', time.time() - t)
  114. return adj_mat
  115. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  116. if len(matrices) == 0:
  117. raise ValueError('The list of matrices must be non-empty')
  118. if not all(m.is_sparse for m in matrices):
  119. raise ValueError('All matrices must be sparse')
  120. if not all(len(m.shape) == 2 for m in matrices):
  121. raise ValueError('All matrices must be 2D')
  122. indices = []
  123. values = []
  124. row_offset = 0
  125. col_offset = 0
  126. for m in matrices:
  127. ind = m._indices().clone()
  128. ind[0] += row_offset
  129. ind[1] += col_offset
  130. indices.append(ind)
  131. values.append(m._values())
  132. row_offset += m.shape[0]
  133. col_offset += m.shape[1]
  134. indices = torch.cat(indices, dim=1)
  135. values = torch.cat(values)
  136. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  137. def _cat(matrices: List[torch.Tensor]):
  138. if len(matrices) == 0:
  139. raise ValueError('Empty list passed to _cat()')
  140. n = sum(a.is_sparse for a in matrices)
  141. if n != 0 and n != len(matrices):
  142. raise ValueError('All matrices must have the same layout (dense or sparse)')
  143. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  144. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  145. if not matrices[0].is_sparse:
  146. return torch.cat(matrices)
  147. total_rows = sum(a.shape[0] for a in matrices)
  148. indices = []
  149. values = []
  150. row_offset = 0
  151. for a in matrices:
  152. ind = a._indices().clone()
  153. val = a._values()
  154. ind[0] += row_offset
  155. ind = ind.transpose(0, 1)
  156. indices.append(ind)
  157. values.append(val)
  158. row_offset += a.shape[0]
  159. indices = torch.cat(indices).transpose(0, 1)
  160. values = torch.cat(values)
  161. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  162. return res
  163. def _mm(a: torch.Tensor, b: torch.Tensor):
  164. if a.is_sparse:
  165. return torch.sparse.mm(a, b)
  166. else:
  167. return torch.mm(a, b)
  168. def _select_rows(a: torch.Tensor, rows: torch.Tensor):
  169. if not a.is_sparse:
  170. return a[rows]
  171. indices = a.indices()
  172. values = a.values()
  173. mask = torch.zeros(a.shape[0])
  174. mask[rows] = 1
  175. if mask.sum() != len(rows):
  176. raise ValueError('Rows must be unique')
  177. mask = mask[indices[0]]
  178. mask = torch.nonzero(mask, as_tuple=True)[0]
  179. new_rows[rows] = torch.arange(len(rows))
  180. new_rows = new_rows[indices[0]]
  181. indices = indices[:, mask]
  182. indices[0] = new_rows
  183. values = values[mask]
  184. res = _sparse_coo_tensor(indices, values,
  185. size=(len(rows), a.shape[1]))
  186. return res
  187. def common_one_hot_encoding(vertex_type_counts: List[int], device=None) -> \
  188. List[torch.Tensor]:
  189. tot = sum(vertex_type_counts)
  190. # indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0)
  191. # print('indices.shape:', indices.shape)
  192. ofs = 0
  193. res = []
  194. for cnt in vertex_type_counts:
  195. ind = torch.cat([
  196. torch.arange(cnt).view(1, -1),
  197. torch.arange(ofs, ofs+cnt).view(1, -1)
  198. ])
  199. val = torch.ones(cnt)
  200. x = _sparse_coo_tensor(ind, val, size=(cnt, tot))
  201. x = x.to(device)
  202. res.append(x)
  203. ofs += cnt
  204. return res