IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
4.4KB

from functools import reduce
from typing import Iterator

import torch

from .data import Data
from .model import TrainingBatch
  5. def _shuffle(x: torch.Tensor) -> torch.Tensor:
  6. order = torch.randperm(len(x))
  7. return x[order]
  8. def _same_data_org(pos_data: Data, neg_data: Data):
  9. if len(pos_data.vertex_types) != len(neg_data.vertex_types):
  10. return False
  11. test = [ pos_data.vertex_types[i].name == neg_data.vertex_types[i].name \
  12. and pos_data.vertex_types[i].count == neg_data.vertex_types[i].count \
  13. for i in range(len(pos_data.vertex_types)) ]
  14. if not all(test):
  15. return False
  16. if not set(pos_data.edge_types.keys()) == \
  17. set(neg_data.edge_types.keys()):
  18. return False
  19. test = [ pos_data.edge_types[i].name == \
  20. neg_data.edge_types[i].name \
  21. and pos_data.edge_types[i].vertex_type_row == \
  22. neg_data.edge_types[i].vertex_type_row \
  23. and pos_data.edge_types[i].vertex_type_column == \
  24. neg_data.edge_types[i].vertex_type_column \
  25. and len(pos_data.edge_types[i].adjacency_matrices) == \
  26. len(neg_data.edge_types[i].adjacency_matrices) \
  27. for i in pos_data.edge_types.keys() ]
  28. if not all(test):
  29. return False
  30. test = [ [ len(pos_data.edge_types[i].adjacency_matrices[k].values()) == \
  31. len(neg_data.edge_types[i].adjacency_matrices[k].values()) \
  32. for k in range(len(pos_data.edge_types[i])) ] \
  33. for i in pos_data.edge_types.keys() ]
  34. test = reduce(list.__add__, test)
  35. if not all(test):
  36. return False
  37. return True
  38. class DualBatcher(object):
  39. def __init__(self, pos_data: Data, neg_data: Data,
  40. batch_size: int=512, shuffle: bool=True) -> None:
  41. if not isinstance(pos_data, Data):
  42. raise TypeError('pos_data must be an instance of Data')
  43. if not isinstance(neg_data, Data):
  44. raise TypeError('neg_data must be an instance of Data')
  45. if not _same_data_org(pos_data, neg_data):
  46. raise ValueError('pos_data and neg_data must have the same organization')
  47. self.pos_data = pos_data
  48. self.neg_data = neg_data
  49. self.batch_size = int(batch_size)
  50. self.shuffle = bool(shuffle)
  51. def __iter__(self):
  52. class Batcher(object):
  53. def __init__(self, data: Data, batch_size: int=512,
  54. shuffle: bool=True) -> None:
  55. if not isinstance(data, Data):
  56. raise TypeError('data must be an instance of Data')
  57. self.data = data
  58. self.batch_size = int(batch_size)
  59. self.shuffle = bool(shuffle)
  60. def __iter__(self) -> TrainingBatch:
  61. edge_types = list(self.data.edge_types.values())
  62. edge_lists = [ [ adj_mat.indices().transpose(0, 1) \
  63. for adj_mat in et.adjacency_matrices ] \
  64. for et in edge_types ]
  65. if self.shuffle:
  66. edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \
  67. for edge_lst in edge_lists ]
  68. offsets = [ [ 0 ] * len(et.adjacency_matrices) \
  69. for et in edge_types ]
  70. while True:
  71. candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \
  72. if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \
  73. if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ]
  74. if len(candidates) == 0:
  75. break
  76. edge_idx = torch.randint(0, len(candidates), (1,)).item()
  77. edge_idx = candidates[edge_idx]
  78. candidates = [ rel_idx \
  79. for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \
  80. if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]
  81. rel_idx = torch.randint(0, len(candidates), (1,)).item()
  82. rel_idx = candidates[rel_idx]
  83. lst = edge_lists[edge_idx][rel_idx]
  84. et = edge_types[edge_idx]
  85. ofs = offsets[edge_idx][rel_idx]
  86. lst = lst[ofs:ofs+self.batch_size]
  87. offsets[edge_idx][rel_idx] += self.batch_size
  88. b = TrainingBatch(et.vertex_type_row, et.vertex_type_column,
  89. rel_idx, lst, torch.full((len(lst),), self.data.target_value,
  90. dtype=torch.float32))
  91. yield b