#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

"""
This module implements the basic convolutional blocks of Decagon.

Just as a quick reminder, the basic convolution formula here is:

    y = A * (x * W)

where:

    W is a weight matrix,
    A is an adjacency matrix,
    x is a matrix of latent representations of a particular type of neighbors.

As we have x here twice, a trick is obviously necessary for this to work.
A must be previously normalized with:

    c_{r}^{ij} = 1 / sqrt(|N_{r}^{i}| * |N_{r}^{j}|)

or

    c_{r}^{i} = 1 / |N_{r}^{i}|

Let's work through this step by step to convince ourselves that the
formula is correct.
    x = [
        [0, 1, 0, 1],
        [1, 1, 1, 0],
        [0, 0, 0, 1]
    ]

    W = [
        [0, 1],
        [1, 0],
        [0.5, 0.5],
        [0.25, 0.75]
    ]

    A = [
        [0, 1, 0],
        [1, 0, 1],
        [0, 1, 0]
    ]

so the graph looks like this:

    (0) -- (1) -- (2)

and therefore the representations in the next layer should be:

    h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
    h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W +
                  c_{r}^{1} * h_{1}^{k}
    h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}

In the actual Decagon code we can see that the latter part, propagating the
old representation directly, is gone. I will try to do the same for now.
So we only have to take care of:

    h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
    h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} * W
    h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W
If A is square, Decagon's EdgeMinibatchIterator preprocesses it as follows:

    A = A + eye(len(A))
    rowsum = A.sum(1)
    deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
    A = dot(A, deg_mat_inv_sqrt)
    A = A.transpose()
    A = A.dot(deg_mat_inv_sqrt)
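
For reference, the same preprocessing can be written as a self-contained
numpy/scipy snippet. This is my reconstruction for illustration, not the
original Decagon code; `eye`, `diags`, `power` and `dot` above are assumed to
refer to the numpy / scipy.sparse functions of those names:

    import numpy as np
    from scipy.sparse import csr_matrix, diags, eye

    A = csr_matrix(np.array([
        [0, 1, 0],
        [1, 0, 1],
        [0, 1, 0]
    ], dtype=np.float64))
    A = A + eye(A.shape[0])
    rowsum = np.array(A.sum(1)).flatten()
    deg_mat_inv_sqrt = diags(np.power(rowsum, -0.5))
    A = A.dot(deg_mat_inv_sqrt).transpose().dot(deg_mat_inv_sqrt)
    # A.toarray() now matches the matrix derived step by step below.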
Let's see what this gives in our case:

    A = A + eye(len(A))

    [
        [1, 1, 0],
        [1, 1, 1],
        [0, 1, 1]
    ]

    rowsum = A.sum(1)

    [2, 3, 2]

    deg_mat_inv_sqrt = diags(power(rowsum, -0.5))

    [
        [1/sqrt(2), 0,         0        ],
        [0,         1/sqrt(3), 0        ],
        [0,         0,         1/sqrt(2)]
    ]

    A = dot(A, deg_mat_inv_sqrt)

    [
        [1/sqrt(2), 1/sqrt(3), 0        ],
        [1/sqrt(2), 1/sqrt(3), 1/sqrt(2)],
        [0,         1/sqrt(3), 1/sqrt(2)]
    ]

    A = A.transpose()

    [
        [1/sqrt(2), 1/sqrt(2), 0        ],
        [1/sqrt(3), 1/sqrt(3), 1/sqrt(3)],
        [0,         1/sqrt(2), 1/sqrt(2)]
    ]

    A = A.dot(deg_mat_inv_sqrt)

    [
        [1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0                    ],
        [1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2)],
        [0,                     1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2)]
    ]

thus:

    [
        [0.5,        0.40824829, 0.        ],
        [0.40824829, 0.33333333, 0.40824829],
        [0.,         0.40824829, 0.5       ]
    ]

This checks out with the c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| * |N_{r}^{j}|) formula.
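
The same matrix can also be obtained directly from the self-loop-augmented
neighborhood sizes; a quick numpy check of the claim above (added here for
illustration, not part of the original derivation):

    import numpy as np

    deg = np.array([2, 3, 2])            # |N_{r}^{i}| including the self-loop
    A_tilde = np.array([
        [1, 1, 0],
        [1, 1, 1],
        [0, 1, 1]
    ])
    A_norm = A_tilde / np.sqrt(np.outer(deg, deg))
    # A_norm reproduces the matrix above: 1/sqrt(|N_{r}^{i}| * |N_{r}^{j}|) on
    # the edges and self-loops, 0 elsewhere.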
Then, we get back to the main calculation:

    y = x * W
    y = A * y

y = x * W:

    [
        [1.25, 0.75],
        [1.5,  1.5 ],
        [0.25, 0.75]
    ]

y = A * y:

    [
        0.5 * [1.25, 0.75] + 0.40824829 * [1.5, 1.5],
        0.40824829 * [1.25, 0.75] + 0.33333333 * [1.5, 1.5] + 0.40824829 * [0.25, 0.75],
        0.40824829 * [1.5, 1.5] + 0.5 * [0.25, 0.75]
    ]

that is:

    [
        [1.23737243, 0.98737244],
        [1.11237243, 1.11237243],
        [0.73737244, 0.98737244]
    ]

All checks out nicely, good.
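
The whole forward pass can be reproduced with a few lines of numpy (again an
illustrative sketch, not code used by the module itself):

    import numpy as np

    x = np.array([
        [0, 1, 0, 1],
        [1, 1, 1, 0],
        [0, 0, 0, 1]
    ], dtype=np.float64)
    W = np.array([
        [0, 1],
        [1, 0],
        [0.5, 0.5],
        [0.25, 0.75]
    ])
    A_norm = np.array([
        [0.5,        0.40824829, 0.        ],
        [0.40824829, 0.33333333, 0.40824829],
        [0.,         0.40824829, 0.5       ]
    ])
    y = A_norm @ (x @ W)
    # y matches the matrix above up to rounding.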
"""

import torch
from .dropout import dropout_sparse, dropout
from .weights import init_glorot
from typing import List, Callable


class SparseGraphConv(torch.nn.Module):
    """Convolution layer for sparse inputs."""

    def __init__(self, in_channels: int, out_channels: int,
                 adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.sparse.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x)
        return x


class SparseDropoutGraphConvActivation(torch.nn.Module):
    """Dropout, followed by a sparse graph convolution and an activation."""

    def __init__(self, input_dim: int, output_dim: int,
                 adjacency_matrix: torch.Tensor, keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrix = adjacency_matrix
        self.keep_prob = keep_prob
        self.activation = activation
        self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout_sparse(x, self.keep_prob)
        x = self.sparse_graph_conv(x)
        x = self.activation(x)
        return x


class SparseMultiDGCA(torch.nn.Module):
    """Sums SparseDropoutGraphConvActivation outputs over multiple adjacency
    matrices and L2-normalizes the result."""

    def __init__(self, input_dim: List[int], output_dim: int,
                 adjacency_matrices: List[torch.Tensor], keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.sparse_dgca = None
        self.build()

    def build(self):
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        self.sparse_dgca = []
        for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
            self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.sparse_dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out


class DenseGraphConv(torch.nn.Module):
    """Convolution layer for dense inputs."""

    def __init__(self, in_channels: int, out_channels: int,
                 adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.mm(x, self.weight)
        x = torch.mm(self.adjacency_matrix, x)
        return x


class DenseDropoutGraphConvActivation(torch.nn.Module):
    """Dropout, followed by a dense graph convolution and an activation."""

    def __init__(self, input_dim: int, output_dim: int,
                 adjacency_matrix: torch.Tensor, keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix)
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout(x, keep_prob=self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x


class DenseMultiDGCA(torch.nn.Module):
    """Sums DenseDropoutGraphConvActivation outputs over multiple adjacency
    matrices and L2-normalizes the result."""

    def __init__(self, input_dim: List[int], output_dim: int,
                 adjacency_matrices: List[torch.Tensor], keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation
        self.dgca = None
        self.build()

    def build(self):
        if len(self.input_dim) != len(self.adjacency_matrices):
            raise ValueError('input_dim must have the same length as adjacency_matrices')
        self.dgca = []
        for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
            self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        if not isinstance(x, list):
            raise ValueError('x must be a list of tensors')
        out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
        for i, f in enumerate(self.dgca):
            out += f(x[i])
        out = torch.nn.functional.normalize(out, p=2, dim=1)
        return out


class GraphConv(torch.nn.Module):
    """Convolution layer for sparse AND dense inputs."""

    def __init__(self, in_channels: int, out_channels: int,
                 adjacency_matrix: torch.Tensor, **kwargs) -> None:
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = init_glorot(in_channels, out_channels)
        self.adjacency_matrix = adjacency_matrix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.sparse.mm(x, self.weight) \
            if x.is_sparse \
            else torch.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x) \
            if self.adjacency_matrix.is_sparse \
            else torch.mm(self.adjacency_matrix, x)
        return x


class DropoutGraphConvActivation(torch.nn.Module):
    """Dropout, graph convolution and an activation for sparse or dense inputs."""

    def __init__(self, input_dim: int, output_dim: int,
                 adjacency_matrix: torch.Tensor, keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.adjacency_matrix = adjacency_matrix
        self.keep_prob = keep_prob
        self.activation = activation
        self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = dropout_sparse(x, self.keep_prob) \
            if x.is_sparse \
            else dropout(x, self.keep_prob)
        x = self.graph_conv(x)
        x = self.activation(x)
        return x
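

# A minimal smoke test sketching how the blocks above fit together. It mirrors
# the worked example from the module docstring: the normalized adjacency matrix
# and the feature matrix are copied from there, while the weights come from
# init_glorot() and are random. This is an illustration added for convenience,
# not part of the original Decagon pipeline; because of the relative imports it
# has to be run as a module (python -m <package>.convolve).
if __name__ == '__main__':
    a_norm = torch.tensor([
        [0.5,        0.40824829, 0.        ],
        [0.40824829, 0.33333333, 0.40824829],
        [0.,         0.40824829, 0.5       ]
    ])
    x = torch.tensor([
        [0., 1., 0., 1.],
        [1., 1., 1., 0.],
        [0., 0., 0., 1.]
    ])

    # Dense path: dropout (keep_prob=1., i.e. disabled) -> convolution -> ReLU.
    dense_conv = DropoutGraphConvActivation(4, 2, a_norm, keep_prob=1.)
    print(dense_conv(x))

    # Sparse path: the same convolution with both inputs in sparse COO format.
    sparse_conv = GraphConv(4, 2, a_norm.to_sparse())
    print(sparse_conv(x.to_sparse()))

    # Multi-relation path: two adjacency matrices (here the same one twice) whose
    # convolved representations are summed and L2-normalized.
    multi = DenseMultiDGCA([4, 4], 2, [a_norm, a_norm], keep_prob=1.)
    print(multi([x, x]))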