IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

342 lines
12KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import numpy as np
  6. import torch
  7. import torch.utils.data
  8. from typing import List, \
  9. Union, \
  10. Tuple
  11. from .data import Data, \
  12. EdgeType
  13. from .cumcount import cumcount
  14. import time
  15. import multiprocessing
  16. import multiprocessing.pool
  17. from itertools import product, \
  18. repeat
  19. from functools import reduce
  20. def fixed_unigram_candidate_sampler(
  21. true_classes: torch.Tensor,
  22. num_repeats: torch.Tensor,
  23. unigrams: torch.Tensor,
  24. distortion: float = 1.) -> torch.Tensor:
  25. assert isinstance(true_classes, torch.Tensor)
  26. assert isinstance(num_repeats, torch.Tensor)
  27. assert isinstance(unigrams, torch.Tensor)
  28. distortion = float(distortion)
  29. if len(true_classes.shape) != 2:
  30. raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
  31. if len(num_repeats.shape) != 1:
  32. raise ValueError('num_repeats must be 1D')
  33. if torch.any((unigrams > 0).sum() - \
  34. (true_classes >= 0).sum(dim=1) < \
  35. num_repeats):
  36. raise ValueError('Not enough classes to choose from')
  37. true_class_count = true_classes.shape[1] - (true_classes == -1).sum(dim=1)
  38. true_classes = torch.cat([
  39. true_classes,
  40. torch.full(( len(true_classes), torch.max(num_repeats) ), -1,
  41. dtype=true_classes.dtype)
  42. ], dim=1)
  43. indices = torch.repeat_interleave(torch.arange(len(true_classes)), num_repeats)
  44. indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
  45. indices.view(-1, 1) ], dim=1)
  46. result = torch.zeros(len(indices), dtype=torch.long)
  47. while len(indices) > 0:
  48. print(len(indices))
  49. candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
  50. candidates = torch.tensor(list(candidates)).view(-1, 1)
  51. inner_order = torch.argsort(candidates[:, 0])
  52. indices_np = indices[inner_order].detach().cpu().numpy()
  53. outer_order = np.argsort(indices_np[:, 1], kind='stable')
  54. outer_order = torch.tensor(outer_order, device=inner_order.device)
  55. candidates = candidates[inner_order][outer_order]
  56. indices = indices[inner_order][outer_order]
  57. mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool)
  58. # can_cum = cumcount(candidates[:, 0])
  59. can_diff = torch.cat([ torch.tensor([1]), candidates[1:, 0] - candidates[:-1, 0] ])
  60. ind_cum = cumcount(indices[:, 1])
  61. repeated = (can_diff == 0) & (ind_cum > 0)
  62. # TODO: this is wrong, still requires work
  63. mask = mask | repeated
  64. updated = indices[~mask]
  65. if len(updated) > 0:
  66. ofs = true_class_count[updated[:, 1]] + \
  67. cumcount(updated[:, 1])
  68. true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
  69. true_class_count[updated[:, 1]] = ofs + 1
  70. result[indices[:, 0]] = candidates.transpose(0, 1)
  71. indices = indices[mask]
  72. return result
  73. def fixed_unigram_candidate_sampler_slow(
  74. true_classes: torch.Tensor,
  75. num_repeats: torch.Tensor,
  76. unigrams: torch.Tensor,
  77. distortion: float = 1.) -> torch.Tensor:
  78. assert isinstance(true_classes, torch.Tensor)
  79. assert isinstance(num_repeats, torch.Tensor)
  80. assert isinstance(unigrams, torch.Tensor)
  81. distortion = float(distortion)
  82. if len(true_classes.shape) != 2:
  83. raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
  84. if len(num_repeats.shape) != 1:
  85. raise ValueError('num_repeats must be 1D')
  86. if torch.any((unigrams > 0).sum() - \
  87. (true_classes >= 0).sum(dim=1) < \
  88. num_repeats):
  89. raise ValueError('Not enough classes to choose from')
  90. res = []
  91. if distortion != 1.:
  92. unigrams = unigrams.to(torch.float64)
  93. unigrams = unigrams ** distortion
  94. def fun(i):
  95. if i and i % 100 == 0:
  96. print(i)
  97. if num_repeats[i] == 0:
  98. return []
  99. pos = torch.flatten(true_classes[i, :])
  100. pos = pos[pos >= 0]
  101. w = unigrams.clone().detach()
  102. w[pos] = 0
  103. sampler = torch.utils.data.WeightedRandomSampler(w,
  104. num_repeats[i].item(), replacement=False)
  105. res = list(sampler)
  106. return res
  107. with multiprocessing.pool.ThreadPool() as p:
  108. res = p.map(fun, range(len(num_repeats)))
  109. res = reduce(list.__add__, res, [])
  110. return torch.tensor(res)
def fixed_unigram_candidate_sampler_old(
        true_classes: torch.Tensor,
        num_repeats: torch.Tensor,
        unigrams: torch.Tensor,
        distortion: float = 1.) -> torch.Tensor:
    """Legacy rejection sampler, superseded by fixed_unigram_candidate_sampler().

    Draws num_repeats[i] candidates for row i from the (optionally
    distorted) unigram distribution, then re-draws any candidate that
    collides with a known positive of its row until none are left.

    NOTE(review): accepted candidates are NOT appended to true_classes (that
    code is commented out below), so the same negative can be drawn more than
    once for a single row — presumably why this variant was retired. Emits
    verbose debug prints. Kept for reference only.
    """
    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')
    if torch.any((unigrams > 0).sum() - \
            (true_classes >= 0).sum(dim=1) < \
            num_repeats):
        raise ValueError('Not enough classes to choose from')
    num_rows = true_classes.shape[0]
    print('true_classes.shape:', true_classes.shape)
    # unigrams = np.array(unigrams)
    if distortion != 1.:
        unigrams = unigrams.to(torch.float64) ** distortion
    print('unigrams:', unigrams)
    # indices[:, 0] — slot in the result vector; indices[:, 1] — source row.
    indices = torch.arange(num_rows)
    indices = torch.repeat_interleave(indices, num_repeats)
    indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
        indices.view(-1, 1) ], dim=1)
    num_samples = len(indices)
    result = torch.zeros(num_samples, dtype=torch.long)
    print('num_rows:', num_rows, 'num_samples:', num_samples)
    while len(indices) > 0:
        print('len(indices):', len(indices))
        print('indices:', indices)
        sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
        candidates = torch.tensor(list(sampler))
        candidates = candidates.view(len(indices), 1)
        print('candidates:', candidates)
        print('true_classes:', true_classes[indices[:, 1], :])
        result[indices[:, 0]] = candidates.transpose(0, 1)
        print('result:', result)
        # mask marks candidates equal to any known positive of their row;
        # those slots are kept in `indices` and re-drawn next round.
        mask = (candidates == true_classes[indices[:, 1], :])
        mask = mask.sum(1).to(torch.bool)
        # append_true_classes = torch.full(( len(true_classes), ), -1)
        # append_true_classes[~mask] = torch.flatten(candidates)[~mask]
        # true_classes = torch.cat([
        #     append_true_classes.view(-1, 1),
        #     true_classes
        # ], dim=1)
        print('mask:', mask)
        indices = indices[mask]
        # result[indices] = 0
    return result
  159. def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
  160. Tuple[torch.Tensor, torch.Tensor]:
  161. if adj_mat.is_sparse:
  162. adj_mat = adj_mat.coalesce()
  163. degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
  164. device=adj_mat.device)
  165. degrees = degrees.index_add(0, adj_mat.indices()[1],
  166. torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
  167. device=adj_mat.device))
  168. edges_pos = adj_mat.indices().transpose(0, 1)
  169. else:
  170. degrees = adj_mat.sum(0)
  171. edges_pos = torch.nonzero(adj_mat, as_tuple=False)
  172. return edges_pos, degrees
  173. def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  174. indices = adj_mat.indices()
  175. row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
  176. #print('indices[0]:', indices[0], count[indices[0]])
  177. row_count = row_count.index_add(0, indices[0],
  178. torch.ones(indices.shape[1], dtype=torch.long))
  179. #print('count:', count)
  180. max_true_classes = torch.max(row_count).item()
  181. #print('max_true_classes:', max_true_classes)
  182. true_classes = torch.full((adj_mat.shape[0], max_true_classes),
  183. -1, dtype=torch.long)
  184. # inv = torch.unique(indices[0], return_inverse=True)
  185. # indices = indices.copy()
  186. # true_classes[indices[0], 0] = indices[1]
  187. t = time.time()
  188. cc = cumcount(indices[0])
  189. print('cumcount() took:', time.time() - t)
  190. # cc = torch.tensor(cc)
  191. t = time.time()
  192. true_classes[indices[0], cc] = indices[1]
  193. print('assignment took:', time.time() - t)
  194. ''' count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
  195. for i in range(indices.shape[1]):
  196. # print('looping...')
  197. row = indices[0, i]
  198. col = indices[1, i]
  199. #print('row:', row, 'col:', col, 'count[row]:', count[row])
  200. true_classes[row, count[row]] = col
  201. count[row] += 1 '''
  202. # t = time.time()
  203. # true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
  204. # print('repeat_interleave() took:', time.time() - t)
  205. return true_classes, row_count
  206. def negative_sample_adj_mat(adj_mat: torch.Tensor,
  207. remove_diagonal: bool=False) -> torch.Tensor:
  208. if not isinstance(adj_mat, torch.Tensor):
  209. raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__)
  210. edges_pos, degrees = get_edges_and_degrees(adj_mat)
  211. degrees = degrees.to(torch.float32) + 1.0 / torch.numel(adj_mat)
  212. true_classes, row_count = get_true_classes(adj_mat)
  213. if remove_diagonal:
  214. true_classes = torch.cat([ torch.arange(len(adj_mat)).view(-1, 1),
  215. true_classes ], dim=1)
  216. # true_classes = edges_pos[:, 1].view(-1, 1)
  217. # print('true_classes:', true_classes)
  218. neg_neighbors = fixed_unigram_candidate_sampler(
  219. true_classes, row_count, degrees, 0.75).to(adj_mat.device)
  220. print('neg_neighbors:', neg_neighbors)
  221. pos_vertices = torch.repeat_interleave(torch.arange(len(adj_mat)),
  222. row_count)
  223. edges_neg = torch.cat([ pos_vertices.view(-1, 1),
  224. neg_neighbors.view(-1, 1) ], 1)
  225. adj_mat_neg = torch.sparse_coo_tensor(indices = edges_neg.transpose(0, 1),
  226. values=torch.ones(len(edges_neg)), size=adj_mat.shape,
  227. dtype=adj_mat.dtype, device=adj_mat.device)
  228. adj_mat_neg = adj_mat_neg.coalesce()
  229. indices = adj_mat_neg.indices()
  230. adj_mat_neg = torch.sparse_coo_tensor(indices,
  231. torch.ones(indices.shape[1]), adj_mat.shape,
  232. dtype=adj_mat.dtype, device=adj_mat.device)
  233. adj_mat_neg = adj_mat_neg.coalesce()
  234. return adj_mat_neg
  235. def negative_sample_data(data: Data) -> Data:
  236. new_edge_types = {}
  237. res = Data(target_value=0)
  238. for vt in data.vertex_types:
  239. res.add_vertex_type(vt.name, vt.count)
  240. for key, et in data.edge_types.items():
  241. print('key:', key)
  242. adjacency_matrices_neg = []
  243. for adj_mat in et.adjacency_matrices:
  244. remove_diagonal = True \
  245. if et.vertex_type_row == et.vertex_type_column \
  246. else False
  247. adj_mat_neg = negative_sample_adj_mat(adj_mat, remove_diagonal)
  248. adjacency_matrices_neg.append(adj_mat_neg)
  249. res.add_edge_type(et.name,
  250. et.vertex_type_row, et.vertex_type_column,
  251. adjacency_matrices_neg, et.decoder_factory)
  252. #new_et = EdgeType(et.name, et.vertex_type_row,
  253. # et.vertex_type_column, adjacency_matrices_neg,
  254. # et.decoder_factory, et.total_connectivity)
  255. #new_edge_types[key] = new_et
  256. #res = Data(data.vertex_types, new_edge_types)
  257. return res
  258. def merge_data(pos_data: Data, neg_data: Data) -> Data:
  259. assert isinstance(pos_data, Data)
  260. assert isinstance(neg_data, Data)
  261. res = PosNegData()
  262. for vt in pos_data.vertex_types:
  263. res.add_vertex_type(vt.name, vt.count)
  264. for key, pos_et in pos_data.edge_types.items():
  265. neg_et = neg_data.edge_types[key]
  266. res.add_edge_type(pos_et.name,
  267. pos_et.vertex_type_row, pos_et.vertex_type_column,
  268. pos_et.adjacency_matrices, neg_et.adjacency_matrices,
  269. pos_et.decoder_factory)