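"""Fast multi-relational graph convolutions.

FastGraphConv performs one graph convolution per relation type in a single
pass by stacking the relation adjacency matrices into one block-diagonal
sparse matrix. FastConvLayer instantiates such convolutions for every
relation family in a Data/PreparedData object and combines their outputs
into per-node-type representations for the next layer.
"""
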
from typing import List, \
    Union, \
    Callable
from .data import Data, \
    RelationFamily
from .trainprep import PreparedData, \
    PreparedRelationFamily
import torch
from .weights import init_glorot
from .normalize import _sparse_coo_tensor
from .dropout import dropout  # assumption: package-local dropout(x, keep_prob) helper used in forward()
import types
from .util import _sparse_diag_cat, \
    _cat

class FastGraphConv(torch.nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
        adjacency_matrices: List[torch.Tensor],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        in_channels = int(in_channels)
        out_channels = int(out_channels)
        if not isinstance(adjacency_matrices, list):
            raise TypeError('adjacency_matrices must be a list')
        if len(adjacency_matrices) == 0:
            raise ValueError('adjacency_matrices must not be empty')
        if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
            raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
        if not all(m.is_sparse for m in adjacency_matrices):
            raise ValueError('adjacency_matrices elements must be sparse')
        keep_prob = float(keep_prob)
        if not isinstance(activation, types.FunctionType):
            raise TypeError('activation must be a function')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.keep_prob = keep_prob
        self.activation = activation

        self.num_row_nodes = len(adjacency_matrices[0])
        self.num_relation_types = len(adjacency_matrices)

        # Stack all per-relation adjacency matrices into a single block-diagonal
        # sparse matrix, so one sparse mm in forward() covers every relation type.
        self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)

        # One Glorot-initialized (in_channels, out_channels) weight block per
        # relation type, concatenated column-wise. Wrapped in a Parameter so
        # the weights are registered with the module and visible to optimizers.
        self.weights = torch.nn.Parameter(torch.cat([
            init_glorot(in_channels, out_channels)
            for _ in range(self.num_relation_types)
        ], dim=1))

    def forward(self, x) -> torch.Tensor:
        # keep_prob is the probability of *keeping* a unit, so dropout only
        # runs when it is below 1.
        if self.keep_prob < 1.:
            x = dropout(x, self.keep_prob)

        # Project the input once against all per-relation weight blocks.
        res = torch.sparse.mm(x, self.weights) \
            if x.is_sparse \
            else torch.mm(x, self.weights)

        # Rearrange the per-relation projections from side-by-side columns into
        # stacked rows, matching the block-diagonal adjacency matrix layout.
        res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
        res = torch.cat(res)

        # Aggregate neighbor features for all relation types in one product.
        res = torch.sparse.mm(self.adjacency_matrices, res) \
            if self.adjacency_matrices.is_sparse \
            else torch.mm(self.adjacency_matrices, res)

        res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)
        if self.activation is not None:
            res = self.activation(res)
        return res
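
# --- Usage sketch (illustration only; not part of the original module) ---
# A minimal, hypothetical example of driving FastGraphConv directly; the tiny
# 3-node graph and all values below are made up:
#
#     adj = [
#         torch.eye(3).to_sparse(),      # relation type 0
#         torch.ones(3, 3).to_sparse(),  # relation type 1
#     ]
#     conv = FastGraphConv(4, 2, adj)
#     out = conv(torch.rand(3, 4))
#     # out.shape == (2, 3, 2), i.e.
#     # (num_relation_types, num_row_nodes, out_channels)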

class FastConvLayer(torch.nn.Module):
    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.is_sparse = False
        self.next_layer_repr = None
        self.build()

    def build(self):
        self.next_layer_repr = torch.nn.ModuleList([
            torch.nn.ModuleList()
            for _ in range(len(self.data.node_types))
        ])
        for fam in self.data.relation_families:
            self.build_family(fam)

    def build_family(self, fam) -> None:
        if fam.node_type_row == fam.node_type_column:
            self.build_fam_one_node_type(fam)
        else:
            self.build_fam_two_node_types(fam)

    def build_fam_one_node_type(self, fam) -> None:
        # Within a single node type, one convolution handles all of the
        # family's relation types at once.
        adjacency_matrices = [
            r.adjacency_matrix
            for r in fam.relation_types
        ]
        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)
        conv.input_node_type = fam.node_type_column
        self.next_layer_repr[fam.node_type_row].append(conv)

    def build_fam_two_node_types(self, fam) -> None:
        # Between two different node types, messages flow in both directions:
        # a forward convolution (column -> row) and a backward convolution
        # (row -> column), each over the adjacency matrices that exist.
        adjacency_matrices = [
            r.adjacency_matrix
            for r in fam.relation_types
            if r.adjacency_matrix is not None
        ]

        adjacency_matrices_backward = [
            r.adjacency_matrix_backward
            for r in fam.relation_types
            if r.adjacency_matrix_backward is not None
        ]

        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)

        conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
            self.output_dim[fam.node_type_column],
            adjacency_matrices_backward,
            self.keep_prob,
            self.rel_activation)

        conv.input_node_type = fam.node_type_column
        conv_backward.input_node_type = fam.node_type_row

        self.next_layer_repr[fam.node_type_row].append(conv)
        self.next_layer_repr[fam.node_type_column].append(conv_backward)

    def forward(self, prev_layer_repr):
        next_layer_repr = [ []
            for _ in range(len(self.data.node_types)) ]

        for output_node_type in range(len(self.data.node_types)):
            # Each convolution contributes an L2-normalized representation,
            # summed first over its relation types, then across convolutions.
            for conv in self.next_layer_repr[output_node_type]:
                rep = conv(prev_layer_repr[conv.input_node_type])
                rep = torch.sum(rep, dim=0)
                rep = torch.nn.functional.normalize(rep, p=2, dim=1)
                next_layer_repr[output_node_type].append(rep)

            if len(next_layer_repr[output_node_type]) == 0:
                # Node types with no incoming relations get a zero representation.
                next_layer_repr[output_node_type] = \
                    torch.zeros(self.data.node_types[output_node_type].count,
                        self.output_dim[output_node_type])
            else:
                next_layer_repr[output_node_type] = \
                    sum(next_layer_repr[output_node_type])

            next_layer_repr[output_node_type] = \
                self.layer_activation(next_layer_repr[output_node_type])

        return next_layer_repr

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            # Note: this broadcast only affects the local check; the constructor
            # stores output_dim exactly as passed, so callers should supply a
            # list with one entry per node type.
            output_dim = [output_dim] * len(data.node_types)

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')
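
# --- Usage sketch (illustration only; not part of the original module) ---
# A hedged example of wiring up FastConvLayer. The Data-building calls below
# (add_node_type / relation-family registration) are hypothetical names shown
# for illustration; consult data.py for the actual construction API:
#
#     d = Data()
#     d.add_node_type('Gene', 1000)
#     d.add_node_type('Drug', 100)
#     ...  # register relation families / adjacency matrices here
#     layer = FastConvLayer(input_dim=[32, 32], output_dim=[64, 64], data=d)
#     out = layer([torch.rand(1000, 32), torch.rand(100, 32)])
#     # out[i] has shape (node_types[i].count, output_dim[i])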