

import torch
from .weights import init_glorot
from .dropout import dropout


class DEDICOMDecoder(torch.nn.Module):
    """DEDICOM Tensor Factorization Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        # Shared global interaction matrix, plus one diagonal of local
        # variation (stored as a flat vector) per relation type.
        self.global_interaction = init_glorot(input_dim, input_dim)
        self.local_variation = [
            torch.flatten(init_glorot(input_dim, 1))
            for _ in range(num_relation_types)
        ]

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            # Dropout keeps each unit with probability 1 - drop_prob.
            inputs_row = dropout(inputs_row, 1.-self.drop_prob)
            inputs_col = dropout(inputs_col, 1.-self.drop_prob)
            # Score for relation k: act(Z_row * D_k * W * D_k * Z_col^T),
            # where D_k is diagonal and W is the global interaction matrix.
            relation = torch.diag(self.local_variation[k])
            product1 = torch.mm(inputs_row, relation)
            product2 = torch.mm(product1, self.global_interaction)
            product3 = torch.mm(product2, relation)
            rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
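
# A minimal usage sketch for DEDICOMDecoder (illustrative only: the decoder
# instance, tensors and dimensions below are made-up examples, not taken from
# the rest of the repository). Each call returns one score matrix per
# relation type:
#
#     dec = DEDICOMDecoder(input_dim=64, num_relation_types=3, drop_prob=0.5)
#     z_row = torch.rand(10, 64)    # embeddings of 10 "row" nodes
#     z_col = torch.rand(20, 64)    # embeddings of 20 "column" nodes
#     scores = dec(z_row, z_col)    # list of 3 tensors, each of shape (10, 20)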

class DistMultDecoder(torch.nn.Module):
    """DistMult Tensor Factorization Decoder model layer for link prediction."""

    def __init__(self, input_dim, num_relation_types, drop_prob=0.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.drop_prob = drop_prob
        self.activation = activation

        # One diagonal relation matrix (stored as a flat vector) per
        # relation type; no shared global interaction matrix.
        self.relation = [
            torch.flatten(init_glorot(input_dim, 1))
            for _ in range(num_relation_types)
        ]

    def forward(self, inputs_row, inputs_col):
        outputs = []
        for k in range(self.num_relation_types):
            # Dropout keeps each unit with probability 1 - drop_prob.
            inputs_row = dropout(inputs_row, 1.-self.drop_prob)
            inputs_col = dropout(inputs_col, 1.-self.drop_prob)
            # Score for relation k: act(Z_row * diag(r_k) * Z_col^T).
            relation = torch.diag(self.relation[k])
            intermediate_product = torch.mm(inputs_row, relation)
            rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
            outputs.append(self.activation(rec))
        return outputs
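
# A minimal usage sketch for DistMultDecoder, mirroring the DEDICOM example
# above (the instance and tensor shapes are illustrative assumptions, not
# prescribed by the repository). Compared to DEDICOMDecoder, the score for
# relation k uses only the per-relation diagonal matrix diag(r_k), with no
# shared global interaction matrix:
#
#     dec = DistMultDecoder(input_dim=64, num_relation_types=3, drop_prob=0.5)
#     z_row = torch.rand(10, 64)
#     z_col = torch.rand(20, 64)
#     scores = dec(z_row, z_col)    # list of 3 tensors, each of shape (10, 20)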