layer.py

#
# This module implements a single layer of the Decagon
# model. This is already going to be quite complex, as
# we will be using all of the graph convolutional building
# blocks.
#
# h_{i}^(k+1) = ϕ(∑_r ∑_{j∈N_{r}^{i}} c_{r}^{ij} * \
#     W_{r}^(k) h_{j}^(k) + c_{r}^{i} h_{i}^(k))
#
# N_{r}^{i} - set of neighbors of node i under relation r
# W_{r}^(k) - relation-type-specific weight matrix
# h_{i}^(k) - hidden state of node i in layer k;
#     h_{i}^(k) ∈ R^{d(k)}, where d(k) is the dimensionality
#     of the representation in the k-th layer
# ϕ - activation function
# c_{r}^{ij}, c_{r}^{i} - normalization constants:
#     c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
#     c_{r}^{i} = 1/|N_{r}^{i}|
#
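
# As a worked example of the constants above: if node i has 4 neighbors
# and node j has 9 neighbors under relation r, then
# c_{r}^{ij} = 1/sqrt(4*9) = 1/6 and c_{r}^{i} = 1/4. Given a dense 0/1
# adjacency matrix `adj` of shape [num_rows, num_cols] (an assumed name,
# for illustration only), they could be computed for all pairs at once:
#
#     deg_row = adj.sum(1)    # |N_{r}^{i}| for every row node i
#     deg_col = adj.sum(0)    # |N_{r}^{j}| for every column node j
#     c_ij = 1.0 / torch.sqrt(deg_row[:, None] * deg_col[None, :])
#     c_i = 1.0 / deg_row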

import torch
from .convolve import DropoutGraphConvActivation
from .data import Data
from typing import List, Union, Callable
from collections import defaultdict


class Layer(torch.nn.Module):
    # Common base class for all layers: stores the per-node-type output
    # dimensionality and whether the produced representations are sparse.
    def __init__(self,
        output_dim: Union[int, List[int]],
        is_sparse: bool,
        **kwargs) -> None:

        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.is_sparse = is_sparse


class InputLayer(Layer):
    # Dense input layer: learns a free, real-valued representation for
    # every node; output_dim defaults to the node count of each node type.
    def __init__(self, data: Data, output_dim: Union[int, List[int]] = None, **kwargs) -> None:
        output_dim = output_dim or \
            list(map(lambda a: a.count, data.node_types))
        if not isinstance(output_dim, list):
            output_dim = [ output_dim, ] * len(data.node_types)

        super().__init__(output_dim, is_sparse=False, **kwargs)
        self.data = data
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = []
        for i, nt in enumerate(self.data.node_types):
            reps = torch.rand(nt.count, self.output_dim[i])
            reps = torch.nn.Parameter(reps)
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'GNN input layer with output_dim: %s\n' % self.output_dim
        s += ' # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += ' - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()
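
# A minimal usage sketch, assuming `d` is a Data instance with its node
# types already registered (hypothetical setup, shown for illustration):
#
#     layer = InputLayer(d, output_dim=32)
#     node_reps = layer()    # one trainable [count, 32] tensor per node type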


class OneHotInputLayer(Layer):
    def __init__(self, data: Data, **kwargs) -> None:
        output_dim = [ a.count for a in data.node_types ]
        super().__init__(output_dim, is_sparse=True, **kwargs)
        self.data = data
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = []
        for i, nt in enumerate(self.data.node_types):
            reps = torch.eye(nt.count).to_sparse()
            reps = torch.nn.Parameter(reps)
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'One-hot GNN input layer\n'
        s += ' # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += ' - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()
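
# Note that one-hot input features are equivalent to learning a free
# embedding per node: multiplying the sparse identity matrix by the first
# layer's weight matrix simply selects one row of that matrix per node.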


class DecagonLayer(Layer):
    def __init__(self,
        data: Data,
        previous_layer: Layer,
        output_dim: Union[int, List[int]],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs) -> None:

        if not isinstance(output_dim, list):
            output_dim = [ output_dim ] * len(data.node_types)

        super().__init__(output_dim, is_sparse=False, **kwargs)
        self.data = data
        self.previous_layer = previous_layer
        self.input_dim = previous_layer.output_dim
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.next_layer_repr = None
        self.build()
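
    # Note: keep_prob is the probability of keeping a unit during dropout,
    # i.e. 1 - the dropout rate.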
    def build(self):
        self.next_layer_repr = defaultdict(list)
        for (nt_row, nt_col), relation_types in self.data.relation_types.items():
            for rel in relation_types:
                # Convolution that aggregates column-type neighbors into
                # row-type representations.
                conv = DropoutGraphConvActivation(self.input_dim[nt_col],
                    self.output_dim[nt_row], rel.adjacency_matrix,
                    self.keep_prob, self.rel_activation)
                self.next_layer_repr[nt_row].append((conv, nt_col))

                if nt_row == nt_col:
                    continue

                # For relations between two distinct node types, also
                # convolve in the reverse direction using the transposed
                # adjacency matrix.
                conv = DropoutGraphConvActivation(self.input_dim[nt_row],
                    self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
                    self.keep_prob, self.rel_activation)
                self.next_layer_repr[nt_col].append((conv, nt_row))

    def forward(self):
        prev_layer_repr = self.previous_layer()
        next_layer_repr = [ None ] * len(self.data.node_types)

        for i in range(len(self.data.node_types)):
            # Convolve the previous-layer representations of all
            # neighboring node types, sum the per-relation results,
            # apply the layer activation ϕ and L2-normalize the rows.
            next_layer_repr[i] = [
                conv(prev_layer_repr[neighbor_type])
                for (conv, neighbor_type) in self.next_layer_repr[i]
            ]
            next_layer_repr[i] = sum(next_layer_repr[i])
            next_layer_repr[i] = self.layer_activation(next_layer_repr[i])
            next_layer_repr[i] = torch.nn.functional.normalize(next_layer_repr[i], p=2, dim=1)

        return next_layer_repr
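
# An end-to-end sketch of stacking these layers; the Data setup below is
# hypothetical and depends on the API in data.py:
#
#     d = Data()    # ... populate with node types and relation types ...
#     input_layer = OneHotInputLayer(d)
#     hidden_layer = DecagonLayer(d, input_layer, output_dim=64)
#     node_reps = hidden_layer()    # one [count, 64] tensor per node type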