IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

208 lines
7.2KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import numpy as np
  6. import torch
  7. import torch.utils.data
  8. from typing import List, \
  9. Union, \
  10. Tuple
  11. from .data import Data, \
  12. EdgeType
  13. from .cumcount import cumcount
  14. import time
  def fixed_unigram_candidate_sampler(
          true_classes: torch.Tensor,
          num_repeats: torch.Tensor,
          unigrams: torch.Tensor,
          distortion: float = 1.) -> torch.Tensor:
      """Draw negative classes by rejection sampling from a unigram distribution.

      For every row ``i`` of ``true_classes``, ``num_repeats[i]`` classes are
      drawn with probability proportional to ``unigrams ** distortion``. A draw
      that equals any entry of ``true_classes[i]`` is rejected and redrawn.

      Args:
          true_classes: 2D integer tensor of shape (num_rows, num_true). Entries
              below zero are treated as padding by the size pre-check; they can
              never match a drawn class because drawn classes are >= 0.
          num_repeats: 1D integer tensor with the number of samples to draw for
              each row of ``true_classes``.
          unigrams: 1D tensor of non-negative sampling weights, one per class.
          distortion: exponent applied to ``unigrams`` before sampling.

      Returns:
          1D ``torch.long`` tensor of length ``num_repeats.sum()``; the samples
          for row 0 come first, then those for row 1, and so on.

      Raises:
          ValueError: if ``true_classes`` is not 2D, ``num_repeats`` is not 1D,
              or some row has fewer non-true classes than requested samples.

      Note:
          The size pre-check counts classes, not weights. Classes with zero
          weight are never drawn, so a row whose only admissible classes have
          zero weight would keep being redrawn.
      """
      if len(true_classes.shape) != 2:
          raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
      if len(num_repeats.shape) != 1:
          raise ValueError('num_repeats must be 1D')
      if torch.any(len(unigrams) - (true_classes >= 0).sum(dim=1) < num_repeats):
          raise ValueError('Not enough classes to choose from')

      # torch.multinomial needs floating-point weights; float64 matches what
      # WeightedRandomSampler used internally.
      weights = unigrams.to(torch.float64)
      if distortion != 1.:
          weights = weights ** distortion

      num_rows = true_classes.shape[0]
      # rows[k] is the row of true_classes that output slot k belongs to.
      rows = torch.repeat_interleave(torch.arange(num_rows), num_repeats)
      num_samples = len(rows)
      result = torch.zeros(num_samples, dtype=torch.long)

      # Output slots that still need an accepted sample.
      pending = torch.arange(num_samples)
      while len(pending) > 0:
          candidates = torch.multinomial(weights, len(pending),
              replacement=True).to(result.device)
          result[pending] = candidates
          # A candidate is rejected when it equals any true class of its row.
          rejected = (candidates.view(-1, 1) ==
              true_classes[rows[pending], :]).any(dim=1)
          pending = pending[rejected]

      return result
  def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
          Tuple[torch.Tensor, torch.Tensor]:
      """Return the edge list and per-column degrees of an adjacency matrix.

      Args:
          adj_mat: 2D adjacency matrix, either dense or sparse.

      Returns:
          A pair ``(edges_pos, degrees)``. ``edges_pos`` has one ``(row, col)``
          pair per row. For sparse input ``degrees`` is an int64 count of stored
          entries per column; for dense input it is the column-wise sum of
          ``adj_mat``.
      """
      if not adj_mat.is_sparse:
          # Dense case: column sums and the coordinates of non-zero entries.
          column_sums = adj_mat.sum(0)
          nonzero_coords = torch.nonzero(adj_mat, as_tuple=False)
          return nonzero_coords, column_sums

      # Sparse case: merge duplicates first so every edge is counted once.
      coalesced = adj_mat.coalesce()
      coords = coalesced.indices()
      one_per_edge = torch.ones(coords.shape[1], dtype=torch.int64,
          device=coalesced.device)
      column_counts = torch.zeros(coalesced.shape[1], dtype=torch.int64,
          device=coalesced.device)
      column_counts = column_counts.index_add(0, coords[1], one_per_edge)
      return coords.transpose(0, 1), column_counts
  def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
      """Collect, for every row of an adjacency matrix, the columns it links to.

      Args:
          adj_mat: 2D adjacency matrix. Sparse input is coalesced here because
              ``.indices()`` is only available on coalesced tensors; dense input
              is converted to sparse first.

      Returns:
          A pair ``(true_classes, row_count)``, both CPU ``torch.long`` tensors.
          ``true_classes`` has shape ``(num_rows, max(row_count))`` and holds
          the column indices of each row's stored entries, padded with -1.
          ``row_count`` holds the number of stored entries per row.
      """
      if not adj_mat.is_sparse:
          adj_mat = adj_mat.to_sparse()
      # The counters below are CPU tensors, so bring the indices to the CPU
      # once instead of mixing devices in index_add and the final assignment.
      indices = adj_mat.coalesce().indices().cpu()

      num_rows = adj_mat.shape[0]
      row_count = torch.zeros(num_rows, dtype=torch.long)
      row_count = row_count.index_add(0, indices[0],
          torch.ones(indices.shape[1], dtype=torch.long))

      # torch.max raises on an empty tensor; a matrix without rows has no
      # true classes at all.
      max_true_classes = int(row_count.max()) if num_rows > 0 else 0
      true_classes = torch.full((num_rows, max_true_classes),
          -1, dtype=torch.long)

      # cumcount is a project helper; presumably it returns, for each entry,
      # its running occurrence index within its row - verify against
      # .cumcount. That index is used as the column slot in true_classes.
      cc = torch.tensor(cumcount(indices[0].numpy()))
      true_classes[indices[0], cc] = indices[1]

      return true_classes, row_count
  def negative_sample_adj_mat(adj_mat: torch.Tensor,
          remove_diagonal: bool=False) -> torch.Tensor:
      """Build a sparse adjacency matrix of sampled negative edges.

      For every stored entry of ``adj_mat`` one negative neighbour is drawn for
      the same row, with probability proportional to ``degree ** 0.75``. Columns
      the row is already linked to are rejected by the sampler.

      Args:
          adj_mat: adjacency matrix of positive edges. A sparse matrix is
              coalesced here before use.
          remove_diagonal: when True, each row's own index is added to its
              excluded classes so that no self-edge ``(i, i)`` is sampled.

      Returns:
          A coalesced sparse tensor with the shape, dtype and device of
          ``adj_mat`` whose stored values are all 1. Duplicate negative edges
          are merged, so it may hold fewer entries than ``adj_mat``.

      Raises:
          ValueError: if ``adj_mat`` is not a ``torch.Tensor``.
      """
      if not isinstance(adj_mat, torch.Tensor):
          raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__)

      # get_true_classes reads adj_mat.indices(), which requires a coalesced
      # tensor; coalescing inside get_edges_and_degrees does not carry over.
      if adj_mat.is_sparse:
          adj_mat = adj_mat.coalesce()

      _, degrees = get_edges_and_degrees(adj_mat)
      true_classes, row_count = get_true_classes(adj_mat)
      if remove_diagonal:
          # Treat each vertex's own index as one more class to reject.
          true_classes = torch.cat([ torch.arange(len(adj_mat)).view(-1, 1),
              true_classes ], dim=1)

      neg_neighbors = fixed_unigram_candidate_sampler(
          true_classes, row_count, degrees, 0.75).to(adj_mat.device)

      # One source vertex per positive edge; keep both index columns on the
      # same device before concatenating them.
      pos_vertices = torch.repeat_interleave(torch.arange(len(adj_mat)),
          row_count).to(adj_mat.device)
      edges_neg = torch.cat([ pos_vertices.view(-1, 1),
          neg_neighbors.view(-1, 1) ], 1)

      # First pass: coalescing merges duplicate edges (summing their values).
      adj_mat_neg = torch.sparse_coo_tensor(indices = edges_neg.transpose(0, 1),
          values=torch.ones(len(edges_neg)), size=adj_mat.shape,
          dtype=adj_mat.dtype, device=adj_mat.device)
      adj_mat_neg = adj_mat_neg.coalesce()

      # Second pass: rebuild from the unique indices so every value is 1.
      indices = adj_mat_neg.indices()
      adj_mat_neg = torch.sparse_coo_tensor(indices,
          torch.ones(indices.shape[1]), adj_mat.shape,
          dtype=adj_mat.dtype, device=adj_mat.device)
      adj_mat_neg = adj_mat_neg.coalesce()

      return adj_mat_neg
  def negative_sample_data(data: Data) -> Data:
      """Build a ``Data`` object holding negative samples of every edge type.

      The result is created with ``target_value=0`` and receives the same
      vertex types (name and count) as ``data``. Each adjacency matrix of each
      edge type is replaced by the output of ``negative_sample_adj_mat``.

      Args:
          data: the positive-example dataset.

      Returns:
          A new ``Data`` instance with negatively sampled adjacency matrices.
      """
      res = Data(target_value=0)
      for vt in data.vertex_types:
          res.add_vertex_type(vt.name, vt.count)

      for _, et in data.edge_types.items():
          # Self-edges are only possible when rows and columns index the same
          # vertex type; the flag is the same for all matrices of this type.
          remove_diagonal = et.vertex_type_row == et.vertex_type_column
          adjacency_matrices_neg = [
              negative_sample_adj_mat(adj_mat, remove_diagonal)
              for adj_mat in et.adjacency_matrices
          ]
          res.add_edge_type(et.name,
              et.vertex_type_row, et.vertex_type_column,
              adjacency_matrices_neg, et.decoder_factory)

      return res
  154. def merge_data(pos_data: Data, neg_data: Data) -> Data:
  155. assert isinstance(pos_data, Data)
  156. assert isinstance(neg_data, Data)
  157. res = PosNegData()
  158. for vt in pos_data.vertex_types:
  159. res.add_vertex_type(vt.name, vt.count)
  160. for key, pos_et in pos_data.edge_types.items():
  161. neg_et = neg_data.edge_types[key]
  162. res.add_edge_type(pos_et.name,
  163. pos_et.vertex_type_row, pos_et.vertex_type_column,
  164. pos_et.adjacency_matrices, neg_et.adjacency_matrices,
  165. pos_et.decoder_factory)