from icosagon.data import Data
from icosagon.trainprep import PreparedData
from icosagon.decode import DEDICOMDecoder, \
    DistMultDecoder, \
    BilinearDecoder, \
    InnerProductDecoder
from icosagon.dropout import dropout
import torch
from typing import List, \
    Callable, \
    Union

'''
Let's say that I have dense latent representations row and col.
Then let's take a relation matrix rel from a list of relations REL.
A single computation currently looks like this:

(((row * rel) * glob) * rel) * col

Shouldn't this then basically work:

prod1 = torch.matmul(row, REL)
prod2 = torch.matmul(prod1, glob)
prod3 = torch.matmul(prod2, REL)
res = torch.matmul(prod3, col)
res = activation(res)

res should then have shape: (num_relations, num_rows, num_columns)
'''
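
# A quick sanity sketch of the idea above (illustrative only; the sizes are
# made up, and col is transposed so the shapes line up, as in forward()):
#
#   R, d, n_rows, n_cols = 3, 16, 10, 8
#   row = torch.rand(n_rows, d)
#   col = torch.rand(n_cols, d)
#   REL = torch.rand(R, d, d)          # all relation matrices stacked
#   glob = torch.rand(d, d)
#   prod1 = torch.matmul(row, REL)     # broadcasts to (R, n_rows, d)
#   prod2 = torch.matmul(prod1, glob)  # (R, n_rows, d)
#   prod3 = torch.matmul(prod2, REL)   # (R, n_rows, d)
#   res = torch.matmul(prod3, col.transpose(0, 1))  # (R, n_rows, n_cols)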

def convert_decoder(dec):
    # Re-express a decoder as a pair (global_interaction, local_variation),
    # with local_variation stacked into a single tensor of shape
    # (num_relation_types, input_dim, input_dim).
    if isinstance(dec, DEDICOMDecoder):
        global_interaction = dec.global_interaction
        local_variation = map(torch.diag, dec.local_variation)
    elif isinstance(dec, DistMultDecoder):
        global_interaction = torch.eye(dec.input_dim, dec.input_dim)
        local_variation = map(torch.diag, dec.relation)
    elif isinstance(dec, BilinearDecoder):
        global_interaction = torch.eye(dec.input_dim, dec.input_dim)
        local_variation = dec.relation
    elif isinstance(dec, InnerProductDecoder):
        global_interaction = torch.eye(dec.input_dim, dec.input_dim)
        local_variation = torch.eye(dec.input_dim, dec.input_dim)
        local_variation = [ local_variation ] * dec.num_relation_types
    else:
        raise TypeError('Unknown decoder type in convert_decoder()')

    if not isinstance(local_variation, torch.Tensor):
        # Stack the per-relation matrices along a new leading dimension.
        local_variation = map(lambda a: a.view(1, *a.shape), local_variation)
        local_variation = torch.cat(list(local_variation))

    return (global_interaction, local_variation)
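
# Example (hedged: the constructor arguments mirror the call in build()
# below and are an assumption about the decoder signatures, not a
# documented API):
#
#   dec = DEDICOMDecoder(32, 4, 1., torch.sigmoid)
#   glob, rel = convert_decoder(dec)
#   # glob.shape == (32, 32), rel.shape == (4, 32, 32)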

class BulkDecodeLayer(torch.nn.Module):
    def __init__(self,
        input_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
        **kwargs) -> None:

        super().__init__(**kwargs)
        self._check_params(input_dim, data)

        self.input_dim = input_dim[0]
        self.data = data
        self.keep_prob = keep_prob
        self.activation = activation

        self.decoders = None
        self.global_interaction = None
        self.local_variation = None
        self.build()

    def build(self) -> None:
        self.decoders = torch.nn.ModuleList()
        self.global_interaction = torch.nn.ParameterList()
        self.local_variation = torch.nn.ParameterList()
        for fam in self.data.relation_families:
            # One decoder per relation family; its parameters are then
            # re-expressed in the stacked form expected by forward().
            dec = fam.decoder_class(self.input_dim,
                len(fam.relation_types),
                self.keep_prob,
                self.activation)
            self.decoders.append(dec)
            global_interaction, local_variation = convert_decoder(dec)
            self.global_interaction.append(torch.nn.Parameter(global_interaction))
            self.local_variation.append(torch.nn.Parameter(local_variation))

    def forward(self, last_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
        res = []
        for i, fam in enumerate(self.data.relation_families):
            repr_row = last_layer_repr[fam.node_type_row]
            repr_column = last_layer_repr[fam.node_type_column]
            repr_row = dropout(repr_row, keep_prob=self.keep_prob)
            repr_column = dropout(repr_column, keep_prob=self.keep_prob)
            # R = len(fam.relation_types); broadcasting over the leading
            # dimension decodes all relations of the family in one pass.
            prod_1 = torch.matmul(repr_row, self.local_variation[i])    # (R, n_rows, d)
            prod_2 = torch.matmul(prod_1, self.global_interaction[i])   # (R, n_rows, d)
            prod_3 = torch.matmul(prod_2, self.local_variation[i])      # (R, n_rows, d)
            pred = torch.matmul(prod_3, repr_column.transpose(0, 1))    # (R, n_rows, n_cols)
            # Apply the activation stored in __init__, as in the recipe above.
            pred = self.activation(pred)
            res.append(pred)
        return res

    @staticmethod
    def _check_params(input_dim, data):
        # Validate data first so that a wrong type raises the intended
        # TypeError rather than an AttributeError on data.node_types.
        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of Data or PreparedData')
        if not isinstance(input_dim, list):
            raise TypeError('input_dim must be a list')
        if len(input_dim) != len(data.node_types):
            raise ValueError('input_dim must have length equal to num_node_types')
        if not all([ a == input_dim[0] for a in input_dim ]):
            raise ValueError('All elements of input_dim must have the same value')
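

if __name__ == '__main__':
    # Standalone sanity check (a minimal sketch with made-up sizes): the
    # stacked computation used in forward() should agree with decoding each
    # relation separately.
    R, d, n, m = 4, 16, 10, 8
    row = torch.rand(n, d)
    col = torch.rand(m, d)
    rel = torch.rand(R, d, d)
    glob = torch.rand(d, d)
    bulk = torch.matmul(
        torch.matmul(torch.matmul(torch.matmul(row, rel), glob), rel),
        col.transpose(0, 1))
    single = torch.stack([
        torch.matmul(
            torch.matmul(torch.matmul(torch.matmul(row, rel[i]), glob), rel[i]),
            col.transpose(0, 1))
        for i in range(R) ])
    assert bulk.shape == (R, n, m)
    assert torch.allclose(bulk, single, rtol=1e-3, atol=1e-3)
    print('bulk decode matches per-relation decode:', tuple(bulk.shape))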