If you would like to get an account, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

117 lines
4.1KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import numpy as np
  6. import torch
  7. import torch.utils.data
  8. from typing import List, \
  9. Union, \
  10. Tuple
  11. from .data import Data, \
  12. EdgeType
  13. def fixed_unigram_candidate_sampler(
  14. true_classes: Union[np.array, torch.Tensor],
  15. unigrams: List[Union[int, float]],
  16. distortion: float = 1.):
  17. if isinstance(true_classes, torch.Tensor):
  18. true_classes = true_classes.detach().cpu().numpy()
  19. if isinstance(unigrams, torch.Tensor):
  20. unigrams = unigrams.detach().cpu().numpy()
  21. if len(true_classes.shape) != 2:
  22. raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
  23. num_samples = true_classes.shape[0]
  24. unigrams = np.array(unigrams)
  25. if distortion != 1.:
  26. unigrams = unigrams.astype(np.float64) ** distortion
  27. # print('unigrams:', unigrams)
  28. indices = np.arange(num_samples)
  29. result = np.zeros(num_samples, dtype=np.int64)
  30. while len(indices) > 0:
  31. # print('len(indices):', len(indices))
  32. sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
  33. candidates = np.array(list(sampler))
  34. candidates = np.reshape(candidates, (len(indices), 1))
  35. # print('candidates:', candidates)
  36. # print('true_classes:', true_classes[indices, :])
  37. result[indices] = candidates.T
  38. # print('result:', result)
  39. mask = (candidates == true_classes[indices, :])
  40. mask = mask.sum(1).astype(np.bool)
  41. # print('mask:', mask)
  42. indices = indices[mask]
  43. # result[indices] = 0
  44. return torch.tensor(result)
  45. def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
  46. Tuple[torch.Tensor, torch.Tensor]:
  47. if adj_mat.is_sparse:
  48. adj_mat = adj_mat.coalesce()
  49. degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
  50. device=adj_mat.device)
  51. degrees = degrees.index_add(0, adj_mat.indices()[1],
  52. torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
  53. device=adj_mat.device))
  54. edges_pos = adj_mat.indices().transpose(0, 1)
  55. else:
  56. degrees = adj_mat.sum(0)
  57. edges_pos = torch.nonzero(adj_mat, as_tuple=False)
  58. return edges_pos, degrees
  59. def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
  60. if not isinstance(adj_mat, torch.Tensor):
  61. raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__)
  62. edges_pos, degrees = get_edges_and_degrees(adj_mat)
  63. neg_neighbors = fixed_unigram_candidate_sampler(
  64. edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
  65. edges_neg = torch.cat([ edges_pos[:, 0].view(-1, 1),
  66. neg_neighbors.view(-1, 1) ], 1)
  67. adj_mat_neg = torch.sparse_coo_tensor(indices = edges_neg.transpose(0, 1),
  68. values=torch.ones(len(edges_neg)), size=adj_mat.shape,
  69. dtype=adj_mat.dtype, device=adj_mat.device)
  70. adj_mat_neg = adj_mat_neg.coalesce()
  71. indices = adj_mat_neg.indices()
  72. adj_mat_neg = torch.sparse_coo_tensor(indices,
  73. torch.ones(indices.shape[1]), adj_mat.shape,
  74. dtype=adj_mat.dtype, device=adj_mat.device)
  75. adj_mat_neg = adj_mat_neg.coalesce()
  76. return adj_mat_neg
  77. def negative_sample_data(data: Data) -> Data:
  78. new_edge_types = {}
  79. res = Data()
  80. for vt in data.vertex_types:
  81. res.add_vertex_type(vt.name, vt.count)
  82. for key, et in data.edge_types.items():
  83. adjacency_matrices_neg = []
  84. for adj_mat in et.adjacency_matrices:
  85. adj_mat_neg = negative_sample_adj_mat(adj_mat)
  86. adjacency_matrices_neg.append(adj_mat_neg)
  87. res.add_edge_type(et.name,
  88. et.vertex_type_row, et.vertex_type_column,
  89. adjacency_matrices_neg, et.decoder_factory)
  90. #new_et = EdgeType(et.name, et.vertex_type_row,
  91. # et.vertex_type_column, adjacency_matrices_neg,
  92. # et.decoder_factory, et.total_connectivity)
  93. #new_edge_types[key] = new_et
  94. #res = Data(data.vertex_types, new_edge_types)
  95. return res