IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

139 行
4.5KB

  1. import torch
  2. from typing import List
  3. from .trainprep import PreparedData
  4. from dataclasses import dataclass
  5. import random
  6. from collections import defaultdict
  7. @dataclass
  8. class TrainingBatch(object):
  9. relation_family_index: int
  10. relation_type_index: int
  11. node_type_row: int
  12. node_type_column: int
  13. edges: torch.Tensor
  14. class FastBatcher(object):
  15. def __init__(self,
  16. prep_d: PreparedData,
  17. batch_size: int) -> None:
  18. if not isinstance(prep_d, PreparedData):
  19. raise TypeError('prep_d must be an instance of PreparedData')
  20. self.prep_d = prep_d
  21. self.batch_size = int(batch_size)
  22. self.edges = None
  23. self.build()
  24. def build(self):
  25. self.edges = []
  26. for fam_idx, fam in enumerate(self.prep_d.relation_families):
  27. edges = []
  28. targets = []
  29. edges_back = []
  30. targets_back = []
  31. for rel_idx, rel in enumerate(fam.relation_types):
  32. edges.append(rel.edges_pos.train)
  33. edges.append(rel.edges_neg.train)
  34. targets.append(torch.ones(len(rel.edges_pos.train)))
  35. targets.append(torch.zeros(len(rel.edges_neg.train)))
  36. edges_back.append(rel.edges_back_pos.train)
  37. edges_back.append(rel.edges_back_neg.train)
  38. targets_back.apend(torch.zeros(len(rel.edges_back_pos.train)))
  39. targets_back.apend(torch.zeros(len(rel.edges_back_neg.train)))
  40. edges = torch.cat(edges)
  41. targets = torch.cat(targets)
  42. edges_back = torch.cat(edges_back)
  43. targets_back = torch.cat(targets_back)
  44. order = torch.randperm(len(edges))
  45. edges = edges[order]
  46. targets = targets[order]
  47. order_back = torch.randperm(len(edges_back))
  48. edges_back = edges_back[order_back]
  49. targets_back = targets_back[order_back]
  50. self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': False,
  51. 'edges': edges, 'targets': targets, 'ofs': 0})
  52. self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': True,
  53. 'edges': edges_back, 'targets': targets_back, 'ofs': 0})
  54. def __iter__(self):
  55. while True:
  56. edges = [ e for e in self.edges \
  57. if e['ofs'] < len(e['edges']) ]
  58. # TODO: need to finish this
  59. def __iter_old__(self):
  60. edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']
  61. offsets = {}
  62. orders = {}
  63. done = {}
  64. for fam_idx, fam in enumerate(self.prep_d.relation_families):
  65. for rel_idx, rel in enumerate(fam.relation_types):
  66. for et in edge_types:
  67. done[fam_idx, rel_idx, et] = False
  68. while True:
  69. fam_idx = torch.randint(0, len(self.prep_d.relation_families), (1,)).item()
  70. fam = self.prep_d.relation_families[fam_idx]
  71. rel_idx = torch.randint(0, len(fam.relation_types), (1,)).item()
  72. rel = fam.relation_types[rel_idx]
  73. et = random.choice(edge_types)
  74. edges = getattr(rel, et).train
  75. key = (fam_idx, rel_idx, et)
  76. if key not in orders:
  77. orders[key] = torch.randperm(len(edges))
  78. offsets[key] = 0
  79. ord = orders[key]
  80. ofs = offsets[key]
  81. nt_row = rel.node_type_row
  82. nt_col = rel.node_type_column
  83. if 'back' in et:
  84. nt_row, nt_col = nt_col, nt_row
  85. if ofs < len(edges):
  86. offsets[key] += self.batch_size
  87. ord = ord[ofs:ofs+self.batch_size]
  88. edges = edges[ord]
  89. yield TrainingBatch(fam_idx, rel_idx, nt_row, nt_column, edges)
  90. else:
  91. done[key] = True
  92. for fam in self.prep_d.relation_families:
  93. edges = []
  94. for rel in fam.relation_types:
  95. edges.append(rel.edges_pos.train)
  96. edges.append(rel.edges_back_pos.train)
  97. edges.append(rel.edges_neg.train)
  98. edges.append(rel.edges_back_neg.train)
  99. edges = torch.cat(e)
  100. class FastDecLayer(torch.nn.Module):
  101. def __init__(self, **kwargs):
  102. super().__init__(**kwargs)
  103. def forward(self,
  104. last_layer_repr: List[torch.Tensor],
  105. training_batch: TrainingBatch):