#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
"""
This module implements the basic convolutional blocks of Decagon.
Just as a quick reminder, the basic convolution formula here is:

y = A * (x * W)

where:

W is a weight matrix,
A is an adjacency matrix,
x is a matrix of latent representations of a particular type of neighbors.

For this to work, a trick is necessary: A must be previously normalized with:

c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)

or

c_{r}^{i} = 1/|N_{r}^{i}|

Let's work through this step by step to convince ourselves that the
formula is correct.

x = [
    [0, 1, 0, 1],
    [1, 1, 1, 0],
    [0, 0, 0, 1]
]

W = [
    [0,    1   ],
    [1,    0   ],
    [0.5,  0.5 ],
    [0.25, 0.75]
]

A = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0]
]

so the graph looks like this:

(0) -- (1) -- (2)

and therefore the representations in the next layer should be:

h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W +
    c_{r}^{1} * h_{1}^{k}
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}

In the actual Decagon code we can see that the latter part, propagating the
old representation directly, is gone. I will try to do the same for now.

So we only have to take care of:

h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W
h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W

If A is square, Decagon's EdgeMinibatchIterator preprocesses it as follows:

A = A + eye(len(A))
rowsum = A.sum(1)
deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
A = dot(A, deg_mat_inv_sqrt)
A = A.transpose()
A = A.dot(deg_mat_inv_sqrt)

Let's see what this gives in our case:

A = A + eye(len(A))

[
    [1, 1, 0],
    [1, 1, 1],
    [0, 1, 1]
]

rowsum = A.sum(1)

[2, 3, 2]

deg_mat_inv_sqrt = diags(power(rowsum, -0.5))

[
    [1./sqrt(2), 0,          0         ],
    [0,          1./sqrt(3), 0         ],
    [0,          0,          1./sqrt(2)]
]

A = dot(A, deg_mat_inv_sqrt)

[
    [ 1/sqrt(2), 1/sqrt(3), 0         ],
    [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
    [ 0,         1/sqrt(3), 1/sqrt(2) ]
]

A = A.transpose()

[
    [ 1/sqrt(2), 1/sqrt(2), 0         ],
    [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
    [ 0,         1/sqrt(2), 1/sqrt(2) ]
]

A = A.dot(deg_mat_inv_sqrt)

[
    [ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0                     ],
    [ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ],
    [ 0,                     1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ]
]

thus:

[
    [0.5,        0.40824829, 0.        ],
    [0.40824829, 0.33333333, 0.40824829],
    [0.,         0.40824829, 0.5       ]
]

This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula.
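
For reference, the same normalization can be reproduced with a few lines of
dense numpy (a minimal sketch of the preprocessing above; Decagon itself
operates on sparse matrices):

    import numpy as np

    A = np.array([
        [0., 1., 0.],
        [1., 0., 1.],
        [0., 1., 0.]
    ])
    A = A + np.eye(len(A))                    # add self-connections
    deg_inv_sqrt = np.diag(A.sum(1) ** -0.5)  # D^{-1/2}
    A_norm = deg_inv_sqrt @ A @ deg_inv_sqrt  # D^{-1/2} * (A + I) * D^{-1/2}
    # A_norm is approximately:
    # [[0.5,        0.40824829, 0.        ],
    #  [0.40824829, 0.33333333, 0.40824829],
    #  [0.,         0.40824829, 0.5       ]]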

Then, we get back to the main calculation:

y = x * W
y = A * y

y = x * W

[
    [ 1.25, 0.75 ],
    [ 1.5,  1.5  ],
    [ 0.25, 0.75 ]
]

y = A * y

[
    0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
    0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ],
    0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
]

that is:

[
    [1.23737243, 0.98737244],
    [1.11237243, 1.11237243],
    [0.73737244, 0.98737244]
].

All checks out nicely, good.
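
The same two steps can be checked end to end with plain torch (a quick
sketch using the x, W and normalized A from above; in the layers below, W is
initialized with init_glorot rather than given explicitly):

    import torch

    x = torch.tensor([
        [0., 1., 0., 1.],
        [1., 1., 1., 0.],
        [0., 0., 0., 1.]
    ])
    W = torch.tensor([
        [0.,   1.  ],
        [1.,   0.  ],
        [0.5,  0.5 ],
        [0.25, 0.75]
    ])
    A_norm = torch.tensor([
        [0.5,        0.40824829, 0.        ],
        [0.40824829, 0.33333333, 0.40824829],
        [0.,         0.40824829, 0.5       ]
    ])
    y = A_norm @ (x @ W)
    # tensor([[1.2374, 0.9874],
    #         [1.1124, 1.1124],
    #         [0.7374, 0.9874]])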
  118. """

import torch
from .dropout import dropout, dropout_sparse
from .weights import init_glorot
from typing import List, Callable


class SparseGraphConv(torch.nn.Module):
    """Convolution layer for sparse inputs."""

    def __init__(self, in_channels: int, out_channels: int,
            adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.sparse.mm(x, self.weight)            # x * W
        x = torch.sparse.mm(self.adjacency_matrix, x)  # A * (x * W)
        return x
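
# A minimal usage sketch for SparseGraphConv (illustration only, not part of
# the module); `adj` stands for an already-normalized sparse adjacency matrix
# and `feat` for a sparse feature matrix, both made up for the example:
#
#     adj = torch.eye(3).to_sparse()
#     feat = torch.rand(3, 4).to_sparse()
#     conv = SparseGraphConv(4, 2, adj)
#     out = conv(feat)   # dense tensor of shape (3, 2)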


class SparseDropoutGraphConvActivation(torch.nn.Module):
    """Dropout followed by SparseGraphConv and an activation, for sparse inputs."""

    def __init__(self, input_dim: int, output_dim: int,
            adjacency_matrix: torch.Tensor, keep_prob: float=1.,
            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
            **kwargs) -> None:
        super().__init__(**kwargs)
        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout_sparse(x, self.keep_prob)
        x = self.sparse_graph_conv(x)
        x = self.activation(x)
        return x


class SparseMultiDGCA(torch.nn.Module):
    """Multi-relation dropout graph convolution with activation for sparse
    inputs: applies one SparseDropoutGraphConvActivation per adjacency matrix,
    sums the results and L2-normalizes the rows."""

    def __init__(self, input_dim: List[int], output_dim: int,
            adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
            **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.sparse_dgca = None
        self.build()

    def build(self):
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        self.sparse_dgca = []
        for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
            self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim,
                self.output_dim, adj_mat, self.keep_prob, self.activation))

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.sparse_dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
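
# A usage sketch for SparseMultiDGCA (illustration only): two relation types,
# each with its own adjacency matrix; the identity matrices and random
# features below are placeholders, and the adjacency matrices are assumed to
# be normalized already.
#
#     adj_rel_a = torch.eye(3).to_sparse()
#     adj_rel_b = torch.eye(3).to_sparse()
#     layer = SparseMultiDGCA([4, 4], 2, [adj_rel_a, adj_rel_b], keep_prob=0.5)
#     feats = [torch.rand(3, 4).to_sparse(), torch.rand(3, 4).to_sparse()]
#     out = layer(feats)   # dense tensor of shape (3, 2), rows L2-normalized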


class GraphConv(torch.nn.Module):
    """Convolution layer for dense inputs."""

    def __init__(self, in_channels: int, out_channels: int,
            adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.mm(x, self.weight)            # x * W
        x = torch.mm(self.adjacency_matrix, x)  # A * (x * W)
        return x
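
# A dense usage sketch for GraphConv that reproduces the worked example from
# the module docstring (illustration only; the weight is overwritten by hand
# here, whereas normally it comes from init_glorot):
#
#     A_norm = torch.tensor([
#         [0.5,        0.40824829, 0.        ],
#         [0.40824829, 0.33333333, 0.40824829],
#         [0.,         0.40824829, 0.5       ]])
#     conv = GraphConv(4, 2, A_norm)
#     conv.weight = torch.tensor([
#         [0.,   1.  ],
#         [1.,   0.  ],
#         [0.5,  0.5 ],
#         [0.25, 0.75]])
#     x = torch.tensor([
#         [0., 1., 0., 1.],
#         [1., 1., 1., 0.],
#         [0., 0., 0., 1.]])
#     conv(x)
#     # tensor([[1.2374, 0.9874],
#     #         [1.1124, 1.1124],
#     #         [0.7374, 0.9874]])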


class DropoutGraphConvActivation(torch.nn.Module):
    """Dropout followed by GraphConv and an activation, for dense inputs."""

    def __init__(self, input_dim: int, output_dim: int,
            adjacency_matrix: torch.Tensor, keep_prob: float=1.,
            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
            **kwargs) -> None:
        super().__init__(**kwargs)
        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout(x, keep_prob=self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x


class MultiDGCA(torch.nn.Module):
    """Multi-relation dropout graph convolution with activation for dense
    inputs: applies one DropoutGraphConvActivation per adjacency matrix,
    sums the results and L2-normalizes the rows."""

    def __init__(self, input_dim: List[int], output_dim: int,
            adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
            activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
            **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.dgca = None
        self.build()

    def build(self):
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        self.dgca = []
        for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
            self.dgca.append(DropoutGraphConvActivation(input_dim,
                self.output_dim, adj_mat, self.keep_prob, self.activation))

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out
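
# A dense usage sketch for MultiDGCA (illustration only; the identity
# matrices stand in for normalized adjacency matrices of two relation types):
#
#     adjacencies = [torch.eye(3), torch.eye(3)]
#     layer = MultiDGCA([4, 4], 2, adjacencies, keep_prob=0.5)
#     feats = [torch.rand(3, 4), torch.rand(3, 4)]
#     out = layer(feats)   # tensor of shape (3, 2), rows L2-normalized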