#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

"""
This module implements the basic convolutional blocks of Decagon.

Just as a quick reminder, the basic convolution formula here is:

    y = A * (x * W)

where:

    W is a weight matrix,
    A is an adjacency matrix,
    x is a matrix of latent representations of a particular type of neighbors.

As we have x here twice, a trick is obviously necessary for this to work.
A must be previously normalized with:

    c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)

or

    c_{r}^{i} = 1/|N_{r}^{i}|

Let's work through this step by step to convince ourselves that the
formula is correct.

    x = [
        [0, 1, 0, 1],
        [1, 1, 1, 0],
        [0, 0, 0, 1]
    ]

    W = [
        [0, 1],
        [1, 0],
        [0.5, 0.5],
        [0.25, 0.75]
    ]

    A = [
        [0, 1, 0],
        [1, 0, 1],
        [0, 1, 0]
    ]

so the graph looks like this:

    (0) -- (1) -- (2)

and therefore the representations in the next layer should be:

    h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
    h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W +
        c_{r}^{1} * h_{1}^{k}
    h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}

In the actual Decagon code we can see that the latter part, propagating the
old representation directly, is gone. I will try to do the same for now.
So we only have to take care of:

    h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
    h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W
    h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W

If A is square, Decagon's EdgeMinibatchIterator preprocesses it as follows:

    A = A + eye(len(A))
    rowsum = A.sum(1)
    deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
    A = dot(A, deg_mat_inv_sqrt)
    A = A.transpose()
    A = A.dot(deg_mat_inv_sqrt)

Let's see what this gives in our case:

    A = A + eye(len(A))

    [
        [1, 1, 0],
        [1, 1, 1],
        [0, 1, 1]
    ]

    rowsum = A.sum(1)

    [2, 3, 2]

    deg_mat_inv_sqrt = diags(power(rowsum, -0.5))

    [
        [1/sqrt(2), 0, 0],
        [0, 1/sqrt(3), 0],
        [0, 0, 1/sqrt(2)]
    ]

    A = dot(A, deg_mat_inv_sqrt)

    [
        [ 1/sqrt(2), 1/sqrt(3), 0 ],
        [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
        [ 0, 1/sqrt(3), 1/sqrt(2) ]
    ]

    A = A.transpose()

    [
        [ 1/sqrt(2), 1/sqrt(2), 0 ],
        [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
        [ 0, 1/sqrt(2), 1/sqrt(2) ]
    ]

    A = A.dot(deg_mat_inv_sqrt)

    [
        [ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ],
        [ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ],
        [ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ]
    ]

thus:

    [
        [0.5,        0.40824829, 0.        ],
        [0.40824829, 0.33333333, 0.40824829],
        [0.,         0.40824829, 0.5       ]
    ]

This checks out with the c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula.
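
The same normalization can be reproduced numerically; a quick sketch using
plain numpy (illustrative only, not part of the actual code path):

    import numpy as np

    A = np.array([
        [0., 1., 0.],
        [1., 0., 1.],
        [0., 1., 0.],
    ])
    A = A + np.eye(len(A))
    deg_mat_inv_sqrt = np.diag(np.power(A.sum(1), -0.5))
    # for a symmetric A this equals the dot/transpose/dot sequence above
    A_norm = deg_mat_inv_sqrt @ A @ deg_mat_inv_sqrt
    # A_norm ~ [[0.5, 0.408, 0.], [0.408, 0.333, 0.408], [0., 0.408, 0.5]]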

Then, we get back to the main calculation:

    y = x * W
    y = A * y

    y = x * W

    [
        [ 1.25, 0.75 ],
        [ 1.5,  1.5  ],
        [ 0.25, 0.75 ]
    ]

    y = A * y

    [
        0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
        0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ],
        0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
    ]

that is:

    [
        [1.23737243, 0.98737244],
        [1.11237243, 1.11237243],
        [0.73737244, 0.98737244]
    ].

All checks out nicely, good.
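
Again, this can be verified with a short numpy sketch (assuming A_norm from
the snippet above; illustrative only):

    x = np.array([
        [0., 1., 0., 1.],
        [1., 1., 1., 0.],
        [0., 0., 0., 1.],
    ])
    W = np.array([
        [0.,   1.  ],
        [1.,   0.  ],
        [0.5,  0.5 ],
        [0.25, 0.75],
    ])
    y = A_norm @ (x @ W)
    # y ~ [[1.2374, 0.9874], [1.1124, 1.1124], [0.7374, 0.9874]]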
"""

import torch
from .dropout import dropout_sparse, dropout
from .weights import init_glorot
from typing import List, Callable


class SparseGraphConv(torch.nn.Module):
    """Convolution layer for sparse inputs."""
    def __init__(self, in_channels: int, out_channels: int,
        adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = A * (x * W), as derived in the module docstring
        x = torch.sparse.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x)
        return x
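

# A minimal usage sketch for SparseGraphConv (illustrative only; the sizes,
# values and the identity adjacency below are made up, and the adjacency is
# assumed to be already normalized as described in the module docstring):
#
#     adj = torch.eye(3).to_sparse()
#     conv = SparseGraphConv(4, 2, adj)
#     x = torch.rand(3, 4).to_sparse()
#     y = conv(x)    # dense tensor of shape (3, 2)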


class SparseDropoutGraphConvActivation(torch.nn.Module):
    """Sparse dropout followed by SparseGraphConv and a non-linearity."""
    def __init__(self, input_dim: int, output_dim: int,
        adjacency_matrix: torch.Tensor, keep_prob: float=1.,
        activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrix = adjacency_matrix
        self.keep_prob = keep_prob
        self.activation = activation
        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout_sparse(x, self.keep_prob)
        x = self.sparse_graph_conv(x)
        x = self.activation(x)
        return x


class SparseMultiDGCA(torch.nn.Module):
    """Sums SparseDropoutGraphConvActivation outputs over all adjacency
    matrices and L2-normalizes the result row-wise."""
    def __init__(self, input_dim: List[int], output_dim: int,
        adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
        activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.sparse_dgca = None
        self.build()

    def build(self):
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        self.sparse_dgca = []
        for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
            self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.sparse_dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
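

# A minimal usage sketch for SparseMultiDGCA (illustrative only; two made-up
# relation types over the same 3 nodes, with adjacency matrices assumed to be
# already normalized):
#
#     adj_1 = torch.eye(3).to_sparse()
#     adj_2 = torch.eye(3).to_sparse()
#     multi = SparseMultiDGCA([4, 4], 2, [adj_1, adj_2], keep_prob=0.5)
#     out = multi([torch.rand(3, 4).to_sparse(), torch.rand(3, 4).to_sparse()])
#     # out has shape (3, 2) and L2-normalized rows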


class DenseGraphConv(torch.nn.Module):
    """Convolution layer for dense inputs."""
    def __init__(self, in_channels: int, out_channels: int,
        adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = A * (x * W), dense version
        x = torch.mm(x, self.weight)
        x = torch.mm(self.adjacency_matrix, x)
        return x


class DenseDropoutGraphConvActivation(torch.nn.Module):
    """Dropout followed by DenseGraphConv and a non-linearity."""
    def __init__(self, input_dim: int, output_dim: int,
        adjacency_matrix: torch.Tensor, keep_prob: float=1.,
        activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout(x, keep_prob=self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x


class DenseMultiDGCA(torch.nn.Module):
    """Dense counterpart of SparseMultiDGCA: sums DenseDropoutGraphConvActivation
    outputs over all adjacency matrices and L2-normalizes the result row-wise."""
    def __init__(self, input_dim: List[int], output_dim: int,
        adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
        activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
        **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.dgca = None
        self.build()

    def build(self):
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        self.dgca = []
        for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
            self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
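

# A minimal end-to-end sketch of the dense path (illustrative only; the node
# count, feature sizes and keep_prob below are made up, while the adjacency
# matrix reuses the normalized example from the module docstring):
#
#     adj = torch.tensor([
#         [0.5,        0.40824829, 0.        ],
#         [0.40824829, 0.33333333, 0.40824829],
#         [0.,         0.40824829, 0.5       ],
#     ])
#     model = DenseMultiDGCA([4], 2, [adj], keep_prob=1.)
#     out = model([torch.rand(3, 4)])    # shape (3, 2), rows L2-normalized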