#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
"""
This module implements the basic convolutional blocks of Decagon.
Just as a quick reminder, the basic convolution formula here is:

y = A * (x * W)

where:

W is a weight matrix
A is an adjacency matrix
x is a matrix of latent representations of a particular type of neighbors.

As we have x here twice, a trick is obviously necessary for this to work.
A must be previously normalized with:

c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)

or

c_{r}^{i} = 1/|N_{r}^{i}|

Let's work through this step by step to convince ourselves that the
formula is correct.
x = [
    [0, 1, 0, 1],
    [1, 1, 1, 0],
    [0, 0, 0, 1]
]

W = [
    [0, 1],
    [1, 0],
    [0.5, 0.5],
    [0.25, 0.75]
]

A = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0]
]

so the graph looks like this:

(0) -- (1) -- (2)

and therefore the representations in the next layer should be:
h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W +
    c_{r}^{1} * h_{1}^{k}
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}

In the actual Decagon code we can see that the latter part, which propagates
the old representation directly, is gone. I will try to do the same for now.
So we only have to take care of:

h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W
If A is square, Decagon's EdgeMinibatchIterator preprocesses it as follows
(a runnable sketch of this preprocessing, normalize_adjacency_sketch(), is
provided right after the imports below):

A = A + eye(len(A))
rowsum = A.sum(1)
deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
A = dot(A, deg_mat_inv_sqrt)
A = A.transpose()
A = A.dot(deg_mat_inv_sqrt)

Let's see what this gives in our case:
A = A + eye(len(A))
[
    [1, 1, 0],
    [1, 1, 1],
    [0, 1, 1]
]

rowsum = A.sum(1)
[2, 3, 2]

deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
[
    [1./sqrt(2), 0, 0],
    [0, 1./sqrt(3), 0],
    [0, 0, 1./sqrt(2)]
]

A = dot(A, deg_mat_inv_sqrt)
[
    [ 1/sqrt(2), 1/sqrt(3), 0 ],
    [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
    [ 0, 1/sqrt(3), 1/sqrt(2) ]
]

A = A.transpose()
[
    [ 1/sqrt(2), 1/sqrt(2), 0 ],
    [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
    [ 0, 1/sqrt(2), 1/sqrt(2) ]
]

A = A.dot(deg_mat_inv_sqrt)
[
    [ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ],
    [ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ],
    [ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ]
]

thus:
[
    [0.5       , 0.40824829, 0.        ],
    [0.40824829, 0.33333333, 0.40824829],
    [0.        , 0.40824829, 0.5       ]
]

This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula.

Then, we get back to the main calculation:

y = x * W
y = A * y

y = x * W
[
    [ 1.25, 0.75 ],
    [ 1.5 , 1.5  ],
    [ 0.25, 0.75 ]
]

y = A * y
[
    0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
    0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ],
    0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
]

that is:
[
    [1.23737243, 0.98737244],
    [1.11237243, 1.11237243],
    [0.73737244, 0.98737244]
]

All checks out nicely, good.
"""

import torch
from .dropout import dropout_sparse, dropout
from .weights import init_glorot
from typing import List, Callable
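

# Illustrative sketch only (an addition for clarity, not part of the original
# Decagon code and not used by the layers below): the symmetric normalization
# walked through step by step in the module docstring, written out for a
# square, dense torch.Tensor adjacency matrix.
def normalize_adjacency_sketch(adjacency_matrix: torch.Tensor) -> torch.Tensor:
    # A = A + I
    adj = adjacency_matrix + torch.eye(len(adjacency_matrix),
        dtype=adjacency_matrix.dtype)
    # degree of each node, i.e. |N_{r}^{i}| including the added self-loop
    rowsum = adj.sum(1)
    # D^{-1/2}, so that entries end up as 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
    deg_mat_inv_sqrt = torch.diag(rowsum.pow(-0.5))
    # A = (A . D^{-1/2})^T . D^{-1/2}, exactly as in the docstring walkthrough
    adj = torch.mm(adj, deg_mat_inv_sqrt)
    adj = adj.transpose(0, 1)
    adj = torch.mm(adj, deg_mat_inv_sqrt)
    return adj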


class SparseGraphConv(torch.nn.Module):
    """Convolution layer for sparse inputs."""
    def __init__(self, in_channels: int, out_channels: int,
            adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.sparse.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x)
        return x
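
# A minimal usage sketch (hypothetical values; assumes the adjacency matrix has
# already been normalized as described in the module docstring and that both
# x and the adjacency matrix are torch sparse tensors):
#
#   adj = torch.eye(3).to_sparse()
#   x = torch.rand(3, 4).to_sparse()
#   conv = SparseGraphConv(4, 2, adj)
#   y = conv(x)  # dense tensor of shape [3, 2]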


class SparseDropoutGraphConvActivation(torch.nn.Module):
    def __init__(self, input_dim: int, output_dim: int,
            adjacency_matrix: torch.Tensor, keep_prob: float=1.,
            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
            **kwargs) -> None:
        super().__init__(**kwargs)
        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout_sparse(x, self.keep_prob)
        x = self.sparse_graph_conv(x)
        x = self.activation(x)
        return x


class SparseMultiDGCA(torch.nn.Module):
    def __init__(self, input_dim: List[int], output_dim: int,
            adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
            **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.sparse_dgca = None
        self.build()

    def build(self):
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        self.sparse_dgca = []
        for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
            self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim,
                self.output_dim, adj_mat, self.keep_prob, self.activation))

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.sparse_dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
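
# A minimal usage sketch (hypothetical values; one sparse latent-representation
# matrix and one normalized sparse adjacency matrix per neighbor type):
#
#   adjs = [ torch.eye(3).to_sparse(), torch.eye(3).to_sparse() ]
#   xs = [ torch.rand(3, 4).to_sparse(), torch.rand(3, 6).to_sparse() ]
#   multi = SparseMultiDGCA([ 4, 6 ], 2, adjs, keep_prob=0.9)
#   out = multi(xs)  # dense tensor with L2-normalized rows, shape [3, 2]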


class GraphConv(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int,
            adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.mm(x, self.weight)
        x = torch.mm(self.adjacency_matrix, x)
        return x


class DropoutGraphConvActivation(torch.nn.Module):
    def __init__(self, input_dim: int, output_dim: int,
            adjacency_matrix: torch.Tensor, keep_prob: float=1.,
            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
            **kwargs) -> None:
        super().__init__(**kwargs)
        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout(x, keep_prob=self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x


class MultiDGCA(torch.nn.Module):
    def __init__(self, input_dim: List[int], output_dim: int,
            adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
            **kwargs) -> None:
        super().__init__(**kwargs)
        self.output_dim = output_dim
        if len(input_dim) != len(adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        # one dense dropout-conv-activation block per neighbor type, mirroring
        # the sparse variant above
        self.dgca = [ DropoutGraphConvActivation(dim, output_dim, adj_mat,
            keep_prob, activation)
            for dim, adj_mat in zip(input_dim, adjacency_matrices) ]

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
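

# A minimal sanity check (illustrative addition, not part of the original
# Decagon code): reproduce the worked example from the module docstring with
# the dense GraphConv layer and the normalize_adjacency_sketch() helper above
# (itself an illustrative addition). Overwriting the weight assumes it can be
# replaced with a torch.nn.Parameter of the same shape.
if __name__ == '__main__':
    x = torch.tensor([[0, 1, 0, 1],
        [1, 1, 1, 0],
        [0, 0, 0, 1]], dtype=torch.float32)
    w = torch.tensor([[0, 1],
        [1, 0],
        [0.5, 0.5],
        [0.25, 0.75]], dtype=torch.float32)
    adj = torch.tensor([[0, 1, 0],
        [1, 0, 1],
        [0, 1, 0]], dtype=torch.float32)
    conv = GraphConv(4, 2, normalize_adjacency_sketch(adj))
    conv.weight = torch.nn.Parameter(w)
    print(conv(x))
    # expected, up to floating point error:
    # [[1.23737243, 0.98737244],
    #  [1.11237243, 1.11237243],
    #  [0.73737244, 0.98737244]]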