#
# This module implements a single layer of the Decagon
# model. It is already quite complex, as it uses all of
# the graph convolutional building blocks.
#
# h_i^(k+1) = φ(∑_r ∑_{j ∈ N_r^i} c_r^{ij} W_r^(k) h_j^(k) + c_r^i h_i^(k))
#
# N_r^i    - set of neighbors of node i under relation r
# W_r^(k)  - relation-type-specific weight matrix
# h_i^(k)  - hidden state of node i in layer k,
#            h_i^(k) ∈ R^{d(k)} where d(k) is the dimensionality
#            of the representation in the k-th layer
# φ        - activation function
# c_r^{ij} - normalization constant,
#            c_r^{ij} = 1/sqrt(|N_r^i| * |N_r^j|)
# c_r^i    - normalization constant,
#            c_r^i = 1/|N_r^i|
#
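# Worked example (illustrative, not from the paper): suppose node i has
# |N_r^i| = 4 neighbors under relation r, and its neighbor j has
# |N_r^j| = 9 neighbors under r. Then:
#
#   c_r^{ij} = 1/sqrt(4 * 9) = 1/6
#   c_r^i    = 1/4
#
# so the message W_r^(k) h_j^(k) from j is scaled by 1/6 and the
# self-connection h_i^(k) by 1/4 before the activation φ is applied.
#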
import torch
from .convolve import DropoutGraphConvActivation
from .data import Data
from typing import List, Union, Callable, Optional
from collections import defaultdict

class Layer(torch.nn.Module):
    """Base class for all layers in the model. Stores the
    (per-node-type) output dimensionality and whether the layer
    produces sparse representations."""

    def __init__(self,
        output_dim: Union[int, List[int]],
        is_sparse: bool,
        **kwargs) -> None:

        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.is_sparse = is_sparse

class InputLayer(Layer):
    """Input layer holding one trainable, randomly initialized
    dense representation matrix per node type."""

    def __init__(self, data: Data,
        output_dim: Optional[Union[int, List[int]]] = None,
        **kwargs) -> None:

        output_dim = output_dim or \
            [ a.count for a in data.node_types ]
        if not isinstance(output_dim, list):
            output_dim = [ output_dim ] * len(data.node_types)

        super().__init__(output_dim, is_sparse=False, **kwargs)
        self.data = data
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = []
        for i, nt in enumerate(self.data.node_types):
            reps = torch.rand(nt.count, self.output_dim[i])
            reps = torch.nn.Parameter(reps)
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'GNN input layer with output_dim: %s\n' % self.output_dim
        s += '  # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += '    - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()

class OneHotInputLayer(Layer):
    """Input layer holding one fixed sparse one-hot (identity)
    representation matrix per node type."""

    def __init__(self, data: Data, **kwargs) -> None:
        output_dim = [ a.count for a in data.node_types ]
        super().__init__(output_dim, is_sparse=True, **kwargs)
        self.data = data
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = []
        for i, nt in enumerate(self.data.node_types):
            reps = torch.eye(nt.count).to_sparse()
            reps = torch.nn.Parameter(reps)
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'One-hot GNN input layer\n'
        s += '  # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += '    - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()
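
# Usage sketch (illustrative; assumes `d` is a populated Data instance
# whose node_types each expose .name and .count, as used above):
#
#   in_layer = OneHotInputLayer(d)
#   print(in_layer)         # summarizes node types and counts
#   node_reps = in_layer()  # list of sparse identity matrices, one per type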

class DecagonLayer(Layer):
    """A single layer of the Decagon model. For every node type, it
    aggregates convolved messages from all relation types that node
    type participates in, implementing the propagation rule given at
    the top of this module."""

    def __init__(self,
        data: Data,
        previous_layer: Layer,
        output_dim: Union[int, List[int]],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs) -> None:

        if not isinstance(output_dim, list):
            output_dim = [ output_dim ] * len(data.node_types)

        super().__init__(output_dim, is_sparse=False, **kwargs)
        self.data = data
        self.previous_layer = previous_layer
        self.input_dim = previous_layer.output_dim
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.next_layer_repr = None
        self.build()

    def build(self) -> None:
        self.next_layer_repr = defaultdict(list)

        for (nt_row, nt_col), relation_types in self.data.relation_types.items():
            row_convs = []
            col_convs = []

            for rel in relation_types:
                # Messages from column-type neighbors to row-type nodes.
                conv = DropoutGraphConvActivation(self.input_dim[nt_col],
                    self.output_dim[nt_row], rel.adjacency_matrix,
                    self.keep_prob, self.rel_activation)
                row_convs.append(conv)

                if nt_row == nt_col:
                    continue

                # For relations between two different node types, also
                # build the reverse direction, using the transposed
                # adjacency matrix.
                conv = DropoutGraphConvActivation(self.input_dim[nt_row],
                    self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
                    self.keep_prob, self.rel_activation)
                col_convs.append(conv)

            self.next_layer_repr[nt_row].append((row_convs, nt_col))

            if nt_row == nt_col:
                continue

            self.next_layer_repr[nt_col].append((col_convs, nt_row))
    def forward(self) -> List[torch.Tensor]:
        prev_layer_repr = self.previous_layer()
        next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]

        for i in range(len(self.data.node_types)):
            # Sum the convolutions within each relation group and
            # L2-normalize the result per node...
            for convs, neighbor_type in self.next_layer_repr[i]:
                convs = [ conv(prev_layer_repr[neighbor_type]) \
                    for conv in convs ]
                convs = sum(convs)
                convs = torch.nn.functional.normalize(convs, p=2, dim=1)
                next_layer_repr[i].append(convs)
            # ...then sum over all neighboring node types and apply
            # the layer activation.
            next_layer_repr[i] = sum(next_layer_repr[i])
            next_layer_repr[i] = self.layer_activation(next_layer_repr[i])

        return next_layer_repr
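
# Usage sketch (illustrative; `d` is a populated Data instance as above,
# and the output dimensions below are arbitrary):
#
#   input_layer = OneHotInputLayer(d)
#   hidden_1 = DecagonLayer(d, input_layer, output_dim=64)
#   hidden_2 = DecagonLayer(d, hidden_1, output_dim=32)
#   node_reprs = hidden_2()  # list of dense matrices, one per node type
#
# Each DecagonLayer pulls its inputs by calling previous_layer(), so
# invoking the last layer runs the whole stack.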