IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

120 Zeilen
4.6KB

  1. from icosagon.data import Data
  2. from icosagon.trainprep import PreparedData
  3. from icosagon.decode import DEDICOMDecoder, \
  4. DistMultDecoder, \
  5. BilinearDecoder, \
  6. InnerProductDecoder
  7. from icosagon.dropout import dropout
  8. import torch
  9. from typing import List, \
  10. Callable, \
  11. Union
'''
Let's say that I have dense latent representations row and col.
Then let's take relation matrix rel in a list of relations REL.
A single computation currently looks like this:
(((row * rel) * glob) * rel) * col
where * denotes matrix multiplication and col is used transposed.
If the relation matrices in REL are stacked into a single
(num_relations, dim, dim) tensor, torch.matmul broadcasts over the leading
relation axis, so shouldn't this then basically work:
prod1 = torch.matmul(row, REL)
prod2 = torch.matmul(prod1, glob)
prod3 = torch.matmul(prod2, REL)
res = torch.matmul(prod3, col)
res = activation(res)
res should then have shape: (num_relations, num_rows, num_columns)
'''
  25. def convert_decoder(dec):
  26. if isinstance(dec, DEDICOMDecoder):
  27. global_interaction = dec.global_interaction
  28. local_variation = map(torch.diag, dec.local_variation)
  29. elif isinstance(dec, DistMultDecoder):
  30. global_interaction = torch.eye(dec.input_dim, dec.input_dim)
  31. local_variation = map(torch.diag, dec.relation)
  32. elif isinstance(dec, BilinearDecoder):
  33. global_interaction = torch.eye(dec.input_dim, dec.input_dim)
  34. local_variation = dec.relation
  35. elif isinstance(dec, InnerProductDecoder):
  36. global_interaction = torch.eye(dec.input_dim, dec.input_dim)
  37. local_variation = torch.eye(dec.input_dim, dec.input_dim)
  38. local_variation = [ local_variation ] * dec.num_relation_types
  39. else:
  40. raise TypeError('Unknown decoder type in convert_decoder()')
  41. if not isinstance(local_variation, torch.Tensor):
  42. local_variation = map(lambda a: a.view(1, *a.shape), local_variation)
  43. local_variation = torch.cat(list(local_variation))
  44. return (global_interaction, local_variation)
  45. class BulkDecodeLayer(torch.nn.Module):
  46. def __init__(self,
  47. input_dim: List[int],
  48. data: Union[Data, PreparedData],
  49. keep_prob: float = 1.,
  50. activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
  51. **kwargs) -> None:
  52. super().__init__(**kwargs)
  53. self._check_params(input_dim, data)
  54. self.input_dim = input_dim[0]
  55. self.data = data
  56. self.keep_prob = keep_prob
  57. self.activation = activation
  58. self.decoders = None
  59. self.global_interaction = None
  60. self.local_variation = None
  61. self.build()
  62. def build(self) -> None:
  63. self.decoders = torch.nn.ModuleList()
  64. self.global_interaction = torch.nn.ParameterList()
  65. self.local_variation = torch.nn.ParameterList()
  66. for fam in self.data.relation_families:
  67. dec = fam.decoder_class(self.input_dim,
  68. len(fam.relation_types),
  69. self.keep_prob,
  70. self.activation)
  71. self.decoders.append(dec)
  72. global_interaction, local_variation = convert_decoder(dec)
  73. self.global_interaction.append(torch.nn.Parameter(global_interaction))
  74. self.local_variation.append(torch.nn.Parameter(local_variation))
  75. def forward(self, last_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
  76. res = []
  77. for i, fam in enumerate(self.data.relation_families):
  78. repr_row = last_layer_repr[fam.node_type_row]
  79. repr_column = last_layer_repr[fam.node_type_column]
  80. repr_row = dropout(repr_row, keep_prob=self.keep_prob)
  81. repr_column = dropout(repr_column, keep_prob=self.keep_prob)
  82. prod_1 = torch.matmul(repr_row, self.local_variation[i])
  83. print(f'local_variation[{i}].shape: {self.local_variation[i].shape}')
  84. prod_2 = torch.matmul(prod_1, self.global_interaction[i])
  85. prod_3 = torch.matmul(prod_2, self.local_variation[i])
  86. pred = torch.matmul(prod_3, repr_column.transpose(0, 1))
  87. res.append(pred)
  88. return res
  89. @staticmethod
  90. def _check_params(input_dim, data):
  91. if not isinstance(input_dim, list):
  92. raise TypeError('input_dim must be a list')
  93. if len(input_dim) != len(data.node_types):
  94. raise ValueError('input_dim must have length equal to num_node_types')
  95. if not all([ a == input_dim[0] for a in input_dim ]):
  96. raise ValueError('All elements of input_dim must have the same value')
  97. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  98. raise TypeError('data must be an instance of Data or PreparedData')