#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from ..weights import init_glorot
from ..dropout import dropout

class DEDICOMDecoder(torch.nn.Module):
    """DEDICOM Tensor Factorization Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        self.global_interaction = init_glorot(input_dim, input_dim)
        self.local_variation = [
            torch.flatten(init_glorot(input_dim, 1)) \
                for _ in range(num_relation_types)
        ]

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            inputs_row = dropout(inputs_row, 1.-self.drop_prob)
            inputs_col = dropout(inputs_col, 1.-self.drop_prob)

            relation = torch.diag(self.local_variation[k])

            product1 = torch.mm(inputs_row, relation)
            product2 = torch.mm(product1, self.global_interaction)
            product3 = torch.mm(product2, relation)

            rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
                inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
            rec = torch.flatten(rec)

            outputs.append(self.activation(rec))

        return outputs
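
# Note on the loop body above (a worked form derived from the code): for relation
# type k, with d_k = self.local_variation[k] and R = self.global_interaction, the
# i-th row embedding u_i is scored against the i-th column embedding v_i as
#
#     score_k(u_i, v_i) = activation( u_i^T @ diag(d_k) @ R @ diag(d_k) @ v_i )
#
# product1..product3 build u_i^T D_k R D_k row by row, and the final bmm takes the
# per-row dot product with the corresponding column embedding.
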
class DistMultDecoder(torch.nn.Module):
    """DistMult Tensor Factorization Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        self.relation = [
            torch.flatten(init_glorot(input_dim, 1)) \
                for _ in range(num_relation_types)
        ]

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            inputs_row = dropout(inputs_row, 1.-self.drop_prob)
            inputs_col = dropout(inputs_col, 1.-self.drop_prob)

            relation = torch.diag(self.relation[k])

            intermediate_product = torch.mm(inputs_row, relation)

            rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
                inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
            rec = torch.flatten(rec)

            outputs.append(self.activation(rec))

        return outputs
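
# Worked form of DistMultDecoder.forward (derived from the code): with
# r_k = self.relation[k], the i-th row/col embedding pair (u_i, v_i) is scored as
#
#     score_k(u_i, v_i) = activation( u_i^T @ diag(r_k) @ v_i )
#
# i.e. the standard DistMult scoring function with a diagonal relation matrix.
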
class BilinearDecoder(torch.nn.Module):
    """Bilinear Tensor Factorization Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        self.relation = [
            init_glorot(input_dim, input_dim) \
                for _ in range(num_relation_types)
        ]

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            inputs_row = dropout(inputs_row, 1.-self.drop_prob)
            inputs_col = dropout(inputs_col, 1.-self.drop_prob)

            intermediate_product = torch.mm(inputs_row, self.relation[k])

            rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
                inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
            rec = torch.flatten(rec)

            outputs.append(self.activation(rec))

        return outputs
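
# Worked form of BilinearDecoder.forward (derived from the code): with a full
# relation matrix M_k = self.relation[k], the i-th pair (u_i, v_i) is scored as
#
#     score_k(u_i, v_i) = activation( u_i^T @ M_k @ v_i )
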
class InnerProductDecoder(torch.nn.Module):
    """Inner Product Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            inputs_row = dropout(inputs_row, 1.-self.drop_prob)
            inputs_col = dropout(inputs_col, 1.-self.drop_prob)

            rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
                inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
            rec = torch.flatten(rec)

            outputs.append(self.activation(rec))

        return outputs
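
# A minimal usage sketch (illustrative only; the decoder classes are the ones
# defined above, while the dimensions, tensors and variable names below are
# assumed purely for the example):
#
#     decoder = InnerProductDecoder(input_dim=32, num_relation_types=3)
#     inputs_row = torch.randn(10, 32)   # embeddings of the "row" nodes
#     inputs_col = torch.randn(10, 32)   # embeddings of the "col" nodes
#     scores = decoder(inputs_row, inputs_col)
#     # `scores` is a list with one tensor of shape (10,) per relation type,
#     # holding the activation (sigmoid by default) of the i-th row/col pair's score.
#
# All four decoders share this interface; they differ only in how the relation
# between a row embedding and a column embedding is parameterized.
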