IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

167 lines
5.6KB

  1. from .fastmodel import FastModel
  2. from .trainprep import PreparedData
  3. import torch
  4. from typing import Callable
  5. from types import FunctionType
  6. import time
  7. import random
  8. class FastBatcher(object):
  9. def __init__(self, prep_d: PreparedData, batch_size: int,
  10. shuffle: bool, generator: torch.Generator,
  11. part_type: str) -> None:
  12. if not isinstance(prep_d, PreparedData):
  13. raise TypeError('prep_d must be an instance of PreparedData')
  14. if not isinstance(generator, torch.Generator):
  15. raise TypeError('generator must be an instance of torch.Generator')
  16. if part_type not in ['train', 'val', 'test']:
  17. raise ValueError('part_type must be set to train, val or test')
  18. self.prep_d = prep_d
  19. self.batch_size = int(batch_size)
  20. self.shuffle = bool(shuffle)
  21. self.generator = generator
  22. self.part_type = part_type
  23. self.edges = None
  24. self.targets = None
  25. self.build()
  26. def build(self):
  27. self.edges = []
  28. self.targets = []
  29. for fam in self.prep_d.relation_families:
  30. edges = []
  31. targets = []
  32. for i, rel in enumerate(fam.relation_types):
  33. edges_pos = getattr(rel.edges_pos, self.part_type)
  34. edges_neg = getattr(rel.edges_neg, self.part_type)
  35. edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
  36. edges_back_neg = getattr(rel.edges_back_neg, self.part_type)
  37. e = torch.cat([ edges_pos,
  38. torch.cat([edges_back_pos[:, 1], edges_back_pos[:, 0]], dim=1) ])
  39. e = torch.cat([torch.ones(len(e), 1, dtype=torch.long) * i , e ], dim=1)
  40. t = torch.ones(len(e))
  41. edges.append(e)
  42. targets.append(t)
  43. e = torch.cat([ edges_neg,
  44. torch.cat([edges_back_neg[:, 1], edges_back_neg[:, 0]], dim=1) ])
  45. e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
  46. t = torch.zeros(len(e))
  47. edges.append(e)
  48. targets.append(t)
  49. edges = torch.cat(edges)
  50. targets = torch.cat(targets)
  51. self.edges.append(edges)
  52. self.targets.append(targets)
  53. # print(self.edges)
  54. # print(self.targets)
  55. if self.shuffle:
  56. self.shuffle_families()
  57. def shuffle_families(self):
  58. for i in range(len(self.edges)):
  59. edges = self.edges[i]
  60. targets = self.targets[i]
  61. order = torch.randperm(len(edges), generator=self.generator)
  62. self.edges[i] = edges[order]
  63. self.targets[i] = targets[order]
  64. def __iter__(self):
  65. offsets = [ 0 for _ in self.edges ]
  66. while True:
  67. choice = [ i for i in range(len(offsets)) \
  68. if offsets[i] < len(self.edges[i]) ]
  69. if len(choice) == 0:
  70. break
  71. fam_idx = torch.randint(len(choice), (1,), generator=self.generator).item()
  72. ofs = offsets[fam_idx]
  73. edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
  74. targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
  75. offsets[fam_idx] += self.batch_size
  76. yield (fam_idx, edges, targets)
  77. class FastLoop(object):
  78. def __init__(
  79. self,
  80. model: FastModel,
  81. lr: float = 0.001,
  82. loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
  83. torch.nn.functional.binary_cross_entropy_with_logits,
  84. batch_size: int = 100,
  85. shuffle: bool = True,
  86. generator: torch.Generator = None) -> None:
  87. self._check_params(model, loss, generator)
  88. self.model = model
  89. self.lr = float(lr)
  90. self.loss = loss
  91. self.batch_size = int(batch_size)
  92. self.shuffle = bool(shuffle)
  93. self.generator = generator or torch.default_generator
  94. self.opt = None
  95. self.build()
  96. def _check_params(self, model, loss, generator):
  97. if not isinstance(model, FastModel):
  98. raise TypeError('model must be an instance of FastModel')
  99. if not isinstance(loss, FunctionType):
  100. raise TypeError('loss must be a function')
  101. if generator is not None and not isinstance(generator, torch.Generator):
  102. raise TypeError('generator must be an instance of torch.Generator')
  103. def build(self) -> None:
  104. opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
  105. self.opt = opt
  106. def run_epoch(self):
  107. prep_d = self.model.prep_d
  108. batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
  109. shuffle = self.shuffle, generator=self.generator)
  110. # pred = self.model(None)
  111. # n = len(list(iter(batch)))
  112. loss_sum = 0
  113. for fam_idx, edges, targets in batcher:
  114. self.opt.zero_grad()
  115. pred = self.model(None)
  116. # process pred, get input and targets
  117. input = pred[fam_idx][edges[:, 0], edges[:, 1]]
  118. loss = self.loss(input, targets)
  119. loss.backward()
  120. self.opt.step()
  121. loss_sum += loss.detach().cpu().item()
  122. return loss_sum
  123. def train(self, max_epochs):
  124. best_loss = None
  125. best_epoch = None
  126. for i in range(max_epochs):
  127. loss = self.run_epoch()
  128. if best_loss is None or loss < best_loss:
  129. best_loss = loss
  130. best_epoch = i
  131. return loss, best_loss, best_epoch