IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

125 lines
4.3KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. import torch
  6. from .data import Data
  7. from .trainprep import PreparedData, \
  8. TrainValTest
  9. from typing import Type, \
  10. List, \
  11. Callable, \
  12. Union, \
  13. Dict, \
  14. Tuple
  15. from .decode import DEDICOMDecoder
  16. from dataclasses import dataclass
  17. import time
  18. from .databatch import BatchedDataPointer
  19. @dataclass
  20. class RelationPredictions(object):
  21. edges_pos: TrainValTest
  22. edges_neg: TrainValTest
  23. edges_back_pos: TrainValTest
  24. edges_back_neg: TrainValTest
  25. @dataclass
  26. class RelationFamilyPredictions(object):
  27. relation_types: List[RelationPredictions]
  28. @dataclass
  29. class Predictions(object):
  30. relation_families: List[RelationFamilyPredictions]
  31. class DecodeLayer(torch.nn.Module):
  32. def __init__(self,
  33. input_dim: List[int],
  34. data: PreparedData,
  35. keep_prob: float = 1.,
  36. activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
  37. batched_data_pointer: BatchedDataPointer = None,
  38. **kwargs) -> None:
  39. super().__init__(**kwargs)
  40. if not isinstance(input_dim, list):
  41. raise TypeError('input_dim must be a List')
  42. if len(input_dim) != len(data.node_types):
  43. raise ValueError('input_dim must have length equal to num_node_types')
  44. if not all([ a == input_dim[0] for a in input_dim ]):
  45. raise ValueError('All elements of input_dim must have the same value')
  46. if not isinstance(data, PreparedData):
  47. raise TypeError('data must be an instance of PreparedData')
  48. if batched_data_pointer is not None and \
  49. not isinstance(batched_data_pointer, BatchedDataPointer):
  50. raise TypeError('batched_data_pointer must be an instance of BatchedDataPointer')
  51. # if batched_data_pointer is not None and not batched_data_pointer.compatible_with(data):
  52. # raise ValueError('batched_data_pointer must be compatible with data')
  53. self.input_dim = input_dim[0]
  54. self.output_dim = 1
  55. self.data = data
  56. self.keep_prob = keep_prob
  57. self.activation = activation
  58. self.batched_data_pointer = batched_data_pointer
  59. self.decoders = None
  60. self.build()
  61. def build(self) -> None:
  62. self.decoders = torch.nn.ModuleList()
  63. for fam in self.data.relation_families:
  64. dec = fam.decoder_class(self.input_dim, len(fam.relation_types),
  65. self.keep_prob, self.activation)
  66. self.decoders.append(dec)
  67. def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec):
  68. start_time = time.time()
  69. pred = []
  70. for p in edge_list_attr_names:
  71. tvt = []
  72. for t in ['train', 'val', 'test']:
  73. # print('r:', r)
  74. edges = getattr(getattr(r, p), t)
  75. inputs_row = last_layer_repr[row][edges[:, 0]]
  76. inputs_column = last_layer_repr[column][edges[:, 1]]
  77. tvt.append(dec(inputs_row, inputs_column, k))
  78. tvt = TrainValTest(*tvt)
  79. pred.append(tvt)
  80. # print('DecodeLayer._get_tvt() took:', time.time() - start_time)
  81. return pred
  82. def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]:
  83. t = time.time()
  84. res = []
  85. data = self.batched_data_pointer.batched_data \
  86. if self.batched_data_pointer is not None \
  87. else self.data
  88. for i, fam in enumerate(data.relation_families):
  89. fam_pred = []
  90. for k, r in enumerate(fam.relation_types):
  91. pred = []
  92. pred += self._get_tvt(r, ['edges_pos', 'edges_neg'],
  93. r.node_type_row, r.node_type_column, k, last_layer_repr, self.decoders[i])
  94. pred += self._get_tvt(r, ['edges_back_pos', 'edges_back_neg'],
  95. r.node_type_column, r.node_type_row, k, last_layer_repr, self.decoders[i])
  96. pred = RelationPredictions(*pred)
  97. fam_pred.append(pred)
  98. fam_pred = RelationFamilyPredictions(fam_pred)
  99. res.append(fam_pred)
  100. res = Predictions(res)
  101. # print('DecodeLayer.forward() took', time.time() - t)
  102. return res