IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

75 рядків
2.8KB

  1. from .layer import Layer
  2. import torch
  3. from ..convolve import DropoutGraphConvActivation
  4. from ..data import Data
  5. from typing import List, \
  6. Union, \
  7. Callable
  8. from collections import defaultdict
  9. class DecagonLayer(Layer):
  10. def __init__(self,
  11. data: Data,
  12. previous_layer: Layer,
  13. output_dim: Union[int, List[int]],
  14. keep_prob: float = 1.,
  15. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  16. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  17. **kwargs):
  18. if not isinstance(output_dim, list):
  19. output_dim = [ output_dim ] * len(data.node_types)
  20. super().__init__(output_dim, is_sparse=False, **kwargs)
  21. self.data = data
  22. self.previous_layer = previous_layer
  23. self.input_dim = previous_layer.output_dim
  24. self.keep_prob = keep_prob
  25. self.rel_activation = rel_activation
  26. self.layer_activation = layer_activation
  27. self.next_layer_repr = None
  28. self.build()
  29. def build(self):
  30. self.next_layer_repr = defaultdict(list)
  31. for (nt_row, nt_col), relation_types in self.data.relation_types.items():
  32. row_convs = []
  33. col_convs = []
  34. for rel in relation_types:
  35. conv = DropoutGraphConvActivation(self.input_dim[nt_col],
  36. self.output_dim[nt_row], rel.adjacency_matrix,
  37. self.keep_prob, self.rel_activation)
  38. row_convs.append(conv)
  39. if nt_row == nt_col:
  40. continue
  41. conv = DropoutGraphConvActivation(self.input_dim[nt_row],
  42. self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
  43. self.keep_prob, self.rel_activation)
  44. col_convs.append(conv)
  45. self.next_layer_repr[nt_row].append((row_convs, nt_col))
  46. if nt_row == nt_col:
  47. continue
  48. self.next_layer_repr[nt_col].append((col_convs, nt_row))
  49. def __call__(self):
  50. prev_layer_repr = self.previous_layer()
  51. next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
  52. print('next_layer_repr:', next_layer_repr)
  53. for i in range(len(self.data.node_types)):
  54. for convs, neighbor_type in self.next_layer_repr[i]:
  55. convs = [ conv(prev_layer_repr[neighbor_type]) \
  56. for conv in convs ]
  57. convs = sum(convs)
  58. convs = torch.nn.functional.normalize(convs, p=2, dim=1)
  59. next_layer_repr[i].append(convs)
  60. next_layer_repr[i] = sum(next_layer_repr[i])
  61. next_layer_repr[i] = self.layer_activation(next_layer_repr[i])
  62. print('next_layer_repr:', next_layer_repr)
  63. return next_layer_repr