IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

103 lines
3.3KB

  1. from .declayer import Predictions
  2. import torch
  3. from dataclasses import dataclass
  4. from .trainprep import PreparedData
  5. from typing import Tuple
  6. @dataclass
  7. class FlatPredictions(object):
  8. predictions: torch.Tensor
  9. truth: torch.Tensor
  10. part_type: str
  11. def flatten_predictions(pred: Predictions, part_type: str = 'train'):
  12. if not isinstance(pred, Predictions):
  13. raise TypeError('pred must be an instance of Predictions')
  14. if part_type not in ['train', 'val', 'test']:
  15. raise ValueError('part_type must be set to train, val or test')
  16. edge_types = [('edges_pos', 1), ('edges_neg', 0),
  17. ('edges_back_pos', 1), ('edges_back_neg', 0)]
  18. input = []
  19. target = []
  20. for fam in pred.relation_families:
  21. for rel in fam.relation_types:
  22. for (et, tgt) in edge_types:
  23. edge_pred = getattr(getattr(rel, et), part_type)
  24. input.append(edge_pred)
  25. target.append(torch.ones_like(edge_pred) * tgt)
  26. input = torch.cat(input)
  27. target = torch.cat(target)
  28. return FlatPredictions(input, target, part_type)
  29. @dataclass
  30. class BatchIndices(object):
  31. indices: torch.Tensor
  32. part_type: str
  33. def gather_batch_indices(pred: FlatPredictions,
  34. indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]:
  35. if not isinstance(pred, FlatPredictions):
  36. raise TypeError('pred must be an instance of FlatPredictions')
  37. if not isinstance(indices, BatchIndices):
  38. raise TypeError('indices must be an instance of BatchIndices')
  39. if pred.part_type != indices.part_type:
  40. raise ValueError('part_type must be the same in pred and indices')
  41. return (pred.predictions[indices.indices],
  42. pred.truth[indices.indices])
  43. class PredictionsBatch(object):
  44. def __init__(self, prep_d: PreparedData, part_type: str = 'train',
  45. batch_size: int = 100, shuffle: bool = False,
  46. generator: torch.Generator = None) -> None:
  47. if not isinstance(prep_d, PreparedData):
  48. raise TypeError('prep_d must be an instance of PreparedData')
  49. if part_type not in ['train', 'val', 'test']:
  50. raise ValueError('part_type must be set to train, val or test')
  51. batch_size = int(batch_size)
  52. shuffle = bool(shuffle)
  53. if generator is not None and not isinstance(generator, torch.Generator):
  54. raise TypeError('generator must be an instance of torch.Generator')
  55. self.prep_d = prep_d
  56. self.part_type = part_type
  57. self.batch_size = batch_size
  58. self.shuffle = shuffle
  59. self.generator = generator or torch.default_generator
  60. count = 0
  61. for fam in prep_d.relation_families:
  62. for rel in fam.relation_types:
  63. for et in ['edges_pos', 'edges_neg',
  64. 'edges_back_pos', 'edges_back_neg']:
  65. count += len(getattr(getattr(rel, et), part_type))
  66. self.total_edge_count = count
  67. def __iter__(self):
  68. values = torch.arange(self.total_edge_count)
  69. if self.shuffle:
  70. perm = torch.randperm(len(values))
  71. values = values[perm]
  72. for i in range(0, len(values), self.batch_size):
  73. yield BatchIndices(values[i:i+self.batch_size], self.part_type)