#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

from .layer import Layer
import torch
from ..convolve import DropoutGraphConvActivation
from ..data import Data
from typing import List, Union, Callable
from collections import defaultdict


class DecagonLayer(Layer):
    def __init__(self,
        data: Data,
        previous_layer: Layer,
        output_dim: Union[int, List[int]],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        # A single output_dim is broadcast to all node types.
        if not isinstance(output_dim, list):
            output_dim = [ output_dim ] * len(data.node_types)

        super().__init__(output_dim, is_sparse=False, **kwargs)
        self.data = data
        self.previous_layer = previous_layer
        self.input_dim = previous_layer.output_dim
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.next_layer_repr = None
        self.build()

    def build(self):
        # next_layer_repr maps a node type index to a list of
        # (convolutions towards a neighbor type, neighbor type index) pairs.
        self.next_layer_repr = defaultdict(list)

        for (nt_row, nt_col), relation_types in self.data.relation_types.items():
            row_convs = []
            col_convs = []

            for rel in relation_types:
                # Convolution propagating nt_col representations to nt_row nodes.
                conv = DropoutGraphConvActivation(self.input_dim[nt_col],
                    self.output_dim[nt_row], rel.adjacency_matrix,
                    self.keep_prob, self.rel_activation)
                row_convs.append(conv)

                if nt_row == nt_col:
                    continue

                # For relations between two different node types, also propagate
                # in the opposite direction using the transposed adjacency matrix.
                conv = DropoutGraphConvActivation(self.input_dim[nt_row],
                    self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
                    self.keep_prob, self.rel_activation)
                col_convs.append(conv)

            self.next_layer_repr[nt_row].append((row_convs, nt_col))

            if nt_row == nt_col:
                continue

            self.next_layer_repr[nt_col].append((col_convs, nt_row))

    def __call__(self):
        prev_layer_repr = self.previous_layer()
        next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
        print('next_layer_repr:', next_layer_repr)

        for i in range(len(self.data.node_types)):
            for convs, neighbor_type in self.next_layer_repr[i]:
                # Apply every relation-specific convolution towards this
                # neighbor type, sum the results and L2-normalize them.
                convs = [ conv(prev_layer_repr[neighbor_type])
                    for conv in convs ]
                convs = sum(convs)
                convs = torch.nn.functional.normalize(convs, p=2, dim=1)
                next_layer_repr[i].append(convs)
            # Sum over neighbor types, then apply the layer non-linearity.
            next_layer_repr[i] = sum(next_layer_repr[i])
            next_layer_repr[i] = self.layer_activation(next_layer_repr[i])

        print('next_layer_repr:', next_layer_repr)
        return next_layer_repr
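

# A standalone sketch of the aggregation rule used in __call__ above, with plain
# dense tensors standing in for DropoutGraphConvActivation. The shapes, the random
# 0/1 adjacency matrices and the two-relation setup are illustrative assumptions,
# not part of the layer itself; the block is kept under a __main__ guard so that
# importing the module is unaffected.
if __name__ == '__main__':
    torch.manual_seed(0)

    n_nodes, in_dim, out_dim = 5, 8, 4
    x_neighbors = torch.randn(n_nodes, in_dim)   # previous-layer repr of one neighbor type

    # Two relation types towards the same neighbor type, each with its own
    # adjacency matrix and weight matrix (identity rel_activation assumed).
    adjs = [ torch.rand(n_nodes, n_nodes).round() for _ in range(2) ]
    weights = [ torch.randn(in_dim, out_dim) for _ in range(2) ]

    # Per-relation graph convolutions, summed over relation types ...
    per_relation = [ adj @ x_neighbors @ w for adj, w in zip(adjs, weights) ]
    summed = sum(per_relation)

    # ... L2-normalized per neighbor type, then (after summing over neighbor
    # types) passed through the layer activation.
    normalized = torch.nn.functional.normalize(summed, p=2, dim=1)
    out = torch.nn.functional.relu(normalized)
    print(out.shape)   # torch.Size([5, 4])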