IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

249 lignes
7.4KB

  1. import torch
  2. from typing import List, \
  3. Set
  4. import time
  5. def _equal(x: torch.Tensor, y: torch.Tensor):
  6. if x.is_sparse ^ y.is_sparse:
  7. raise ValueError('Cannot mix sparse and dense tensors')
  8. if not x.is_sparse:
  9. return (x == y)
  10. return ((x - y).coalesce().values() == 0)
  11. def _sparse_coo_tensor(indices, values, size):
  12. ctor = { torch.float32: torch.sparse.FloatTensor,
  13. torch.float32: torch.sparse.DoubleTensor,
  14. torch.uint8: torch.sparse.ByteTensor,
  15. torch.long: torch.sparse.LongTensor,
  16. torch.int: torch.sparse.IntTensor,
  17. torch.short: torch.sparse.ShortTensor,
  18. torch.bool: torch.sparse.ByteTensor }[values.dtype]
  19. return ctor(indices, values, size)
  20. def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
  21. if len(adjacency_matrices) == 0:
  22. raise ValueError('adjacency_matrices must be non-empty')
  23. if not all([x.is_sparse for x in adjacency_matrices]):
  24. raise ValueError('All adjacency matrices must be sparse')
  25. indices = [ x.indices() for x in adjacency_matrices ]
  26. indices = torch.cat(indices, dim=1)
  27. values = torch.ones(indices.shape[1])
  28. res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
  29. res = res.coalesce()
  30. indices = res.indices()
  31. res = _sparse_coo_tensor(indices,
  32. torch.ones(indices.shape[1], dtype=torch.uint8),
  33. adjacency_matrices[0].shape)
  34. res = res.coalesce()
  35. return res
  36. def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
  37. rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:
  38. if not adjacency_matrix.is_sparse:
  39. raise ValueError('adjacency_matrix must be sparse')
  40. if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
  41. raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')
  42. t = time.time()
  43. rows = [ rows + row_vertex_count * i \
  44. for i in range(num_relation_types) ]
  45. print('rows took:', time.time() - t)
  46. t = time.time()
  47. rows = torch.cat(rows)
  48. print('cat took:', time.time() - t)
  49. # print('rows:', rows)
  50. # rows = set(rows.tolist())
  51. # print('rows:', rows)
  52. t = time.time()
  53. adj_mat = adjacency_matrix.coalesce()
  54. indices = adj_mat.indices()
  55. values = adj_mat.values()
  56. print('indices[0]:', indices[0])
  57. # print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
  58. lookup = torch.zeros(row_vertex_count * num_relation_types,
  59. dtype=torch.uint8, device=adj_mat.device)
  60. lookup[rows] = 1
  61. values = values * lookup[indices[0]]
  62. mask = torch.nonzero(values > 0, as_tuple=True)[0]
  63. indices = indices[:, mask]
  64. values = values[mask]
  65. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  66. # res = res.coalesce()
  67. print('res:', res)
  68. print('"index_select()" took:', time.time() - t)
  69. return res
  70. selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
  71. # print('selection:', selection)
  72. selection = torch.nonzero(selection, as_tuple=True)[0]
  73. # print('selection:', selection)
  74. indices = indices[:, selection]
  75. values = values[selection]
  76. print('"index_select()" took:', time.time() - t)
  77. t = time.time()
  78. res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  79. print('_sparse_coo_tensor() took:', time.time() - t)
  80. return res
  81. # t = time.time()
  82. # adj_mat = torch.index_select(adjacency_matrix, 0, rows)
  83. # print('index_select took:', time.time() - t)
  84. t = time.time()
  85. adj_mat = adj_mat.coalesce()
  86. print('coalesce() took:', time.time() - t)
  87. indices = adj_mat.indices()
  88. # print('indices:', indices)
  89. values = adj_mat.values()
  90. t = time.time()
  91. indices[0] = rows[indices[0]]
  92. print('Lookup took:', time.time() - t)
  93. t = time.time()
  94. adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
  95. print('_sparse_coo_tensor() took:', time.time() - t)
  96. return adj_mat
  97. def _sparse_diag_cat(matrices: List[torch.Tensor]):
  98. if len(matrices) == 0:
  99. raise ValueError('The list of matrices must be non-empty')
  100. if not all(m.is_sparse for m in matrices):
  101. raise ValueError('All matrices must be sparse')
  102. if not all(len(m.shape) == 2 for m in matrices):
  103. raise ValueError('All matrices must be 2D')
  104. indices = []
  105. values = []
  106. row_offset = 0
  107. col_offset = 0
  108. for m in matrices:
  109. ind = m._indices().clone()
  110. ind[0] += row_offset
  111. ind[1] += col_offset
  112. indices.append(ind)
  113. values.append(m._values())
  114. row_offset += m.shape[0]
  115. col_offset += m.shape[1]
  116. indices = torch.cat(indices, dim=1)
  117. values = torch.cat(values)
  118. return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
  119. def _cat(matrices: List[torch.Tensor]):
  120. if len(matrices) == 0:
  121. raise ValueError('Empty list passed to _cat()')
  122. n = sum(a.is_sparse for a in matrices)
  123. if n != 0 and n != len(matrices):
  124. raise ValueError('All matrices must have the same layout (dense or sparse)')
  125. if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
  126. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  127. if not matrices[0].is_sparse:
  128. return torch.cat(matrices)
  129. total_rows = sum(a.shape[0] for a in matrices)
  130. indices = []
  131. values = []
  132. row_offset = 0
  133. for a in matrices:
  134. ind = a._indices().clone()
  135. val = a._values()
  136. ind[0] += row_offset
  137. ind = ind.transpose(0, 1)
  138. indices.append(ind)
  139. values.append(val)
  140. row_offset += a.shape[0]
  141. indices = torch.cat(indices).transpose(0, 1)
  142. values = torch.cat(values)
  143. res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
  144. return res
  145. def _mm(a: torch.Tensor, b: torch.Tensor):
  146. if a.is_sparse:
  147. return torch.sparse.mm(a, b)
  148. else:
  149. return torch.mm(a, b)
  150. def _select_rows(a: torch.Tensor, rows: torch.Tensor):
  151. if not a.is_sparse:
  152. return a[rows]
  153. indices = a.indices()
  154. values = a.values()
  155. mask = torch.zeros(a.shape[0])
  156. mask[rows] = 1
  157. if mask.sum() != len(rows):
  158. raise ValueError('Rows must be unique')
  159. mask = mask[indices[0]]
  160. mask = torch.nonzero(mask, as_tuple=True)[0]
  161. new_rows[rows] = torch.arange(len(rows))
  162. new_rows = new_rows[indices[0]]
  163. indices = indices[:, mask]
  164. indices[0] = new_rows
  165. values = values[mask]
  166. res = _sparse_coo_tensor(indices, values,
  167. size=(len(rows), a.shape[1]))
  168. return res
  169. def common_one_hot_encoding(vertex_type_counts: List[int], device=None) -> \
  170. List[torch.Tensor]:
  171. tot = sum(vertex_type_counts)
  172. # indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0)
  173. # print('indices.shape:', indices.shape)
  174. ofs = 0
  175. res = []
  176. for cnt in vertex_type_counts:
  177. ind = torch.cat([
  178. torch.arange(cnt).view(1, -1),
  179. torch.arange(ofs, ofs+cnt).view(1, -1)
  180. ])
  181. val = torch.ones(cnt)
  182. x = _sparse_coo_tensor(ind, val, size=(cnt, tot))
  183. x = x.to(device)
  184. res.append(x)
  185. ofs += cnt
  186. return res