IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un chiffre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

272 lignes
8.2KB

  1. import torch
  2. from typing import List, \
  3. Set
  4. import time
  5. def _diag(x: torch.Tensor, make_sparse: bool=False):
  6. if len(x.shape) < 1 or len(x.shape) > 2:
  7. raise ValueError('Matrix or vector expected')
  8. if not x.is_sparse and not make_sparse:
  9. return torch.diag(x)
  10. if len(x.shape) == 1:
  11. indices = torch.arange(len(x)).view(1, -1)
  12. indices = torch.cat([ indices, indices ])
  13. return _sparse_coo_tensor(indices, x.to_dense(), (len(x),) * 2)
  14. values = x.values()
  15. indices = x.indices()
  16. mask = torch.nonzero(indices[0] == indices[1], as_tuple=True)[0]
  17. indices = torch.flatten(indices[0, mask])
  18. order = torch.argsort(indices)
  19. values = values[mask][order]
  20. res = torch.zeros(min(x.shape[0], x.shape[1]), dtype=values.dtype)
  21. res[indices] = values
  22. return res
  23. def _equal(x: torch.Tensor, y: torch.Tensor):
  24. if x.is_sparse ^ y.is_sparse:
  25. raise ValueError('Cannot mix sparse and dense tensors')
  26. if not x.is_sparse:
  27. return (x == y)
  28. return ((x - y).coalesce().values() == 0)
  29. def _sparse_coo_tensor(indices, values, size):
  30. ctor = { torch.float32: torch.sparse.FloatTensor,
  31. torch.float32: torch.sparse.DoubleTensor,
  32. torch.uint8: torch.sparse.ByteTensor,
  33. torch.long: torch.sparse.LongTensor,
  34. torch.int: torch.sparse.IntTensor,
  35. torch.short: torch.sparse.ShortTensor,
  36. torch.bool: torch.sparse.ByteTensor }[values.dtype]
  37. return ctor(indices, values, size)
  38. def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
  39. if len(adjacency_matrices) == 0:
  40. raise ValueError('adjacency_matrices must be non-empty')
  41. if not all([x.is_sparse for x in adjacency_matrices]):
  42. raise ValueError('All adjacency matrices must be sparse')
  43. indices = [ x.indices() for x in adjacency_matrices ]
  44. indices = torch.cat(indices, dim=1)
  45. values = torch.ones(indices.shape[1])
  46. res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
  47. res = res.coalesce()
  48. indices = res.indices()
  49. res = _sparse_coo_tensor(indices,
  50. torch.ones(indices.shape[1], dtype=torch.uint8),
  51. adjacency_matrices[0].shape)
  52. res = res.coalesce()
  53. return res
  54. def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
  55. rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:
  56. if not adjacency_matrix.is_sparse:
  57. raise ValueError('adjacency_matrix must be sparse')
  58. if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
  59. raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')
  60. t = time.time()
  61. rows = [ rows + row_vertex_count * i \
  62. for i in range(num_relation_types) ]
  63. print('rows took:', time.time() - t)
  64. t = time.time()
  65. rows = torch.cat(rows)
  66. print('cat took:', time.time() - t)
  67. # print('rows:', rows)
  68. # rows = set(rows.tolist())
  69. # print('rows:', rows)
  70. t = time.time()
  71. adj_mat = adjacency_matrix.coalesce()
  72. indices = adj_mat.indices()
  73. values = adj_mat.values()
  74. print('indices[0]:', indices[0])
  75. # print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
  76. lookup = torch.zeros(row_vertex_count * num_relation_types,
  77. dtype=torch.uint8, device=adj_mat.device)
  78. lookup[rows] = 1
  79. values = values * lookup[indices[0]]
  80. mask = torch.nonzero(values > 0, as_tuple=True)[0]
  81. indices = indices[:, mask]
  82. values = values[mask]
  83. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  84. # res = res.coalesce()
  85. print('res:', res)
  86. print('"index_select()" took:', time.time() - t)
  87. return res
  88. selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
  89. # print('selection:', selection)
  90. selection = torch.nonzero(selection, as_tuple=True)[0]
  91. # print('selection:', selection)
  92. indices = indices[:, selection]
  93. values = values[selection]
  94. print('"index_select()" took:', time.time() - t)
  95. t = time.time()
  96. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  97. print('_sparse_coo_tensor() took:', time.time() - t)
  98. return res
  99. # t = time.time()
  100. # adj_mat = torch.index_select(adjacency_matrix, 0, rows)
  101. # print('index_select took:', time.time() - t)
  102. t = time.time()
  103. adj_mat = adj_mat.coalesce()
  104. print('coalesce() took:', time.time() - t)
  105. indices = adj_mat.indices()
  106. # print('indices:', indices)
  107. values = adj_mat.values()
  108. t = time.time()
  109. indices[0] = rows[indices[0]]
  110. print('Lookup took:', time.time() - t)
  111. t = time.time()
  112. adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  113. print('_sparse_coo_tensor() took:', time.time() - t)
  114. return adj_mat
  115. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  116. if len(matrices) == 0:
  117. raise ValueError('The list of matrices must be non-empty')
  118. if not all(m.is_sparse for m in matrices):
  119. raise ValueError('All matrices must be sparse')
  120. if not all(len(m.shape) == 2 for m in matrices):
  121. raise ValueError('All matrices must be 2D')
  122. indices = []
  123. values = []
  124. row_offset = 0
  125. col_offset = 0
  126. for m in matrices:
  127. ind = m._indices().clone()
  128. ind[0] += row_offset
  129. ind[1] += col_offset
  130. indices.append(ind)
  131. values.append(m._values())
  132. row_offset += m.shape[0]
  133. col_offset += m.shape[1]
  134. indices = torch.cat(indices, dim=1)
  135. values = torch.cat(values)
  136. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  137. def _cat(matrices: List[torch.Tensor]):
  138. if len(matrices) == 0:
  139. raise ValueError('Empty list passed to _cat()')
  140. n = sum(a.is_sparse for a in matrices)
  141. if n != 0 and n != len(matrices):
  142. raise ValueError('All matrices must have the same layout (dense or sparse)')
  143. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  144. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  145. if not matrices[0].is_sparse:
  146. return torch.cat(matrices)
  147. total_rows = sum(a.shape[0] for a in matrices)
  148. indices = []
  149. values = []
  150. row_offset = 0
  151. for a in matrices:
  152. ind = a._indices().clone()
  153. val = a._values()
  154. ind[0] += row_offset
  155. ind = ind.transpose(0, 1)
  156. indices.append(ind)
  157. values.append(val)
  158. row_offset += a.shape[0]
  159. indices = torch.cat(indices).transpose(0, 1)
  160. values = torch.cat(values)
  161. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  162. return res
  163. def _mm(a: torch.Tensor, b: torch.Tensor):
  164. if a.is_sparse:
  165. return torch.sparse.mm(a, b)
  166. else:
  167. return torch.mm(a, b)
  168. def _select_rows(a: torch.Tensor, rows: torch.Tensor):
  169. if not a.is_sparse:
  170. return a[rows]
  171. indices = a.indices()
  172. values = a.values()
  173. mask = torch.zeros(a.shape[0])
  174. mask[rows] = 1
  175. if mask.sum() != len(rows):
  176. raise ValueError('Rows must be unique')
  177. mask = mask[indices[0]]
  178. mask = torch.nonzero(mask, as_tuple=True)[0]
  179. new_rows[rows] = torch.arange(len(rows))
  180. new_rows = new_rows[indices[0]]
  181. indices = indices[:, mask]
  182. indices[0] = new_rows
  183. values = values[mask]
  184. res = _sparse_coo_tensor(indices, values,
  185. size=(len(rows), a.shape[1]))
  186. return res
  187. def common_one_hot_encoding(vertex_type_counts: List[int], device=None) -> \
  188. List[torch.Tensor]:
  189. tot = sum(vertex_type_counts)
  190. # indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0)
  191. # print('indices.shape:', indices.shape)
  192. ofs = 0
  193. res = []
  194. for cnt in vertex_type_counts:
  195. ind = torch.cat([
  196. torch.arange(cnt).view(1, -1),
  197. torch.arange(ofs, ofs+cnt).view(1, -1)
  198. ])
  199. val = torch.ones(cnt)
  200. x = _sparse_coo_tensor(ind, val, size=(cnt, tot))
  201. x = x.to(device)
  202. res.append(x)
  203. ofs += cnt
  204. return res