IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

103 行
3.3KB

  1. from .declayer import Predictions
  2. import torch
  3. from dataclasses import dataclass
  4. from .trainprep import PreparedData
  5. from typing import Tuple
  6. @dataclass
  7. class FlatPredictions(object):
  8. predictions: torch.Tensor
  9. truth: torch.Tensor
  10. part_type: str
  11. def flatten_predictions(pred: Predictions, part_type: str = 'train'):
  12. if not isinstance(pred, Predictions):
  13. raise TypeError('pred must be an instance of Predictions')
  14. if part_type not in ['train', 'val', 'test']:
  15. raise ValueError('part_type must be set to train, val or test')
  16. edge_types = [('edges_pos', 1), ('edges_neg', 0),
  17. ('edges_back_pos', 1), ('edges_back_neg', 0)]
  18. input = []
  19. target = []
  20. for fam in pred.relation_families:
  21. for rel in fam.relation_types:
  22. for (et, tgt) in edge_types:
  23. edge_pred = getattr(getattr(rel, et), part_type)
  24. input.append(edge_pred)
  25. target.append(torch.ones_like(edge_pred) * tgt)
  26. input = torch.cat(input)
  27. target = torch.cat(target)
  28. return FlatPredictions(input, target, part_type)
  29. @dataclass
  30. class BatchIndices(object):
  31. indices: torch.Tensor
  32. part_type: str
  33. def gather_batch_indices(pred: FlatPredictions,
  34. indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]:
  35. if not isinstance(pred, FlatPredictions):
  36. raise TypeError('pred must be an instance of FlatPredictions')
  37. if not isinstance(indices, BatchIndices):
  38. raise TypeError('indices must be an instance of BatchIndices')
  39. if pred.part_type != indices.part_type:
  40. raise ValueError('part_type must be the same in pred and indices')
  41. return (pred.predictions[indices.indices],
  42. pred.truth[indices.indices])
  43. class PredictionsBatch(object):
  44. def __init__(self, prep_d: PreparedData, part_type: str = 'train',
  45. batch_size: int = 100, shuffle: bool = False,
  46. generator: torch.Generator = None) -> None:
  47. if not isinstance(prep_d, PreparedData):
  48. raise TypeError('prep_d must be an instance of PreparedData')
  49. if part_type not in ['train', 'val', 'test']:
  50. raise ValueError('part_type must be set to train, val or test')
  51. batch_size = int(batch_size)
  52. shuffle = bool(shuffle)
  53. if generator is not None and not isinstance(generator, torch.Generator):
  54. raise TypeError('generator must be an instance of torch.Generator')
  55. self.prep_d = prep_d
  56. self.part_type = part_type
  57. self.batch_size = batch_size
  58. self.shuffle = shuffle
  59. self.generator = generator or torch.default_generator
  60. count = 0
  61. for fam in prep_d.relation_families:
  62. for rel in fam.relation_types:
  63. for et in ['edges_pos', 'edges_neg',
  64. 'edges_back_pos', 'edges_back_neg']:
  65. count += len(getattr(getattr(rel, et), part_type))
  66. self.total_edge_count = count
  67. def __iter__(self):
  68. values = torch.arange(self.total_edge_count)
  69. if self.shuffle:
  70. perm = torch.randperm(len(values))
  71. values = values[perm]
  72. for i in range(0, len(values), self.batch_size):
  73. yield BatchIndices(values[i:i+self.batch_size], self.part_type)