IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

196 lines
7.2KB

import types
from typing import Callable, List, Union

import torch

from .data import Data, RelationFamily
from .normalize import _sparse_coo_tensor
from .trainprep import PreparedData, PreparedRelationFamily
from .util import _sparse_diag_cat, _cat
from .weights import init_glorot
  14. class FastGraphConv(torch.nn.Module):
  15. def __init__(self,
  16. in_channels: int,
  17. out_channels: int,
  18. adjacency_matrices: List[torch.Tensor],
  19. keep_prob: float = 1.,
  20. activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  21. **kwargs) -> None:
  22. super().__init__(**kwargs)
  23. in_channels = int(in_channels)
  24. out_channels = int(out_channels)
  25. if not isinstance(adjacency_matrices, list):
  26. raise TypeError('adjacency_matrices must be a list')
  27. if len(adjacency_matrices) == 0:
  28. raise ValueError('adjacency_matrices must not be empty')
  29. if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
  30. raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
  31. if not all(m.is_sparse for m in adjacency_matrices):
  32. raise ValueError('adjacency_matrices elements must be sparse')
  33. keep_prob = float(keep_prob)
  34. if not isinstance(activation, types.FunctionType):
  35. raise TypeError('activation must be a function')
  36. self.in_channels = in_channels
  37. self.out_channels = out_channels
  38. self.adjacency_matrices = adjacency_matrices
  39. self.keep_prob = keep_prob
  40. self.activation = activation
  41. self.num_row_nodes = len(adjacency_matrices[0])
  42. self.num_relation_types = len(adjacency_matrices)
  43. self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)
  44. self.weights = torch.cat([
  45. init_glorot(in_channels, out_channels) \
  46. for _ in range(self.num_relation_types)
  47. ], dim=1)
  48. def forward(self, x) -> torch.Tensor:
  49. if self.keep_prob < 1.:
  50. x = dropout(x, self.keep_prob)
  51. res = torch.sparse.mm(x, self.weights) \
  52. if x.is_sparse \
  53. else torch.mm(x, self.weights)
  54. res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
  55. res = torch.cat(res)
  56. res = torch.sparse.mm(self.adjacency_matrices, res) \
  57. if self.adjacency_matrices.is_sparse \
  58. else torch.mm(self.adjacency_matrices, res)
  59. res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)
  60. if self.activation is not None:
  61. res = self.activation(res)
  62. return res
  63. class FastConvLayer(torch.nn.Module):
  64. def __init__(self,
  65. input_dim: List[int],
  66. output_dim: List[int],
  67. data: Union[Data, PreparedData],
  68. keep_prob: float = 1.,
  69. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  70. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  71. **kwargs):
  72. super().__init__(**kwargs)
  73. self._check_params(input_dim, output_dim, data, keep_prob,
  74. rel_activation, layer_activation)
  75. self.input_dim = input_dim
  76. self.output_dim = output_dim
  77. self.data = data
  78. self.keep_prob = keep_prob
  79. self.rel_activation = rel_activation
  80. self.layer_activation = layer_activation
  81. self.is_sparse = False
  82. self.next_layer_repr = None
  83. self.build()
  84. def build(self):
  85. self.next_layer_repr = torch.nn.ModuleList([
  86. torch.nn.ModuleList() \
  87. for _ in range(len(self.data.node_types))
  88. ])
  89. for fam in self.data.relation_families:
  90. self.build_family(fam)
  91. def build_family(self, fam) -> None:
  92. if fam.node_type_row == fam.node_type_column:
  93. self.build_fam_one_node_type(fam)
  94. else:
  95. self.build_fam_two_node_types(fam)
  96. def build_fam_one_node_type(self, fam) -> None:
  97. adjacency_matrices = [
  98. r.adjacency_matrix \
  99. for r in fam.relation_types
  100. ]
  101. conv = FastGraphConv(self.input_dim[fam.node_type_column],
  102. self.output_dim[fam.node_type_row],
  103. adjacency_matrices,
  104. self.keep_prob,
  105. self.rel_activation)
  106. conv.input_node_type = fam.node_type_column
  107. self.next_layer_repr[fam.node_type_row].append(conv)
  108. def build_fam_two_node_types(self, fam) -> None:
  109. adjacency_matrices = [
  110. r.adjacency_matrix \
  111. for r in fam.relation_types \
  112. if r.adjacency_matrix is not None
  113. ]
  114. adjacency_matrices_backward = [
  115. r.adjacency_matrix_backward \
  116. for r in fam.relation_types \
  117. if r.adjacency_matrix_backward is not None
  118. ]
  119. conv = FastGraphConv(self.input_dim[fam.node_type_column],
  120. self.output_dim[fam.node_type_row],
  121. adjacency_matrices,
  122. self.keep_prob,
  123. self.rel_activation)
  124. conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
  125. self.output_dim[fam.node_type_column],
  126. adjacency_matrices_backward,
  127. self.keep_prob,
  128. self.rel_activation)
  129. conv.input_node_type = fam.node_type_column
  130. conv_backward.input_node_type = fam.node_type_row
  131. self.next_layer_repr[fam.node_type_row].append(conv)
  132. self.next_layer_repr[fam.node_type_column].append(conv_backward)
  133. def forward(self, prev_layer_repr):
  134. next_layer_repr = [ [] \
  135. for _ in range(len(self.data.node_types)) ]
  136. for output_node_type in range(len(self.data.node_types)):
  137. for conv in self.next_layer_repr[output_node_type]:
  138. rep = conv(prev_layer_repr[conv.input_node_type])
  139. rep = torch.sum(rep, dim=0)
  140. rep = torch.nn.functional.normalize(rep, p=2, dim=1)
  141. next_layer_repr[output_node_type].append(rep)
  142. if len(next_layer_repr[output_node_type]) == 0:
  143. next_layer_repr[output_node_type] = \
  144. torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
  145. else:
  146. next_layer_repr[output_node_type] = \
  147. sum(next_layer_repr[output_node_type])
  148. next_layer_repr[output_node_type] = \
  149. self.layer_activation(next_layer_repr[output_node_type])
  150. return next_layer_repr
  151. @staticmethod
  152. def _check_params(input_dim, output_dim, data, keep_prob,
  153. rel_activation, layer_activation):
  154. if not isinstance(input_dim, list):
  155. raise ValueError('input_dim must be a list')
  156. if not output_dim:
  157. raise ValueError('output_dim must be specified')
  158. if not isinstance(output_dim, list):
  159. output_dim = [output_dim] * len(data.node_types)
  160. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  161. raise ValueError('data must be of type Data or PreparedData')