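"""Fast graph convolution layers.

Instead of one small matrix product per relation type, the classes below
stack the per-relation adjacency matrices along dimension 0 and the
per-relation weight matrices along dimension 1, so that a single
(optionally sparse) matrix product covers a whole relation family at
once."""
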
from typing import List, Union, Callable

from .data import Data
from .trainprep import PreparedData
from .dropout import dropout  # dropout() is applied in FastConvLayer.forward()
import torch
from .weights import init_glorot


def _cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('Empty list passed to _cat()')

    n = sum(a.is_sparse for a in matrices)
    if n != 0 and n != len(matrices):
        raise ValueError('All matrices must have the same layout (dense or sparse)')

    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
        raise ValueError('All matrices must have the same dimensions apart from dimension 0')

    if not matrices[0].is_sparse:
        return torch.cat(matrices)

    # Sparse case: concatenate along dimension 0 by shifting the row
    # indices of each matrix by the number of rows that precede it.
    total_rows = sum(a.shape[0] for a in matrices)
    indices = []
    values = []
    row_offset = 0

    for a in matrices:
        ind = a._indices().clone()
        val = a._values()
        ind[0] += row_offset
        ind = ind.transpose(0, 1)
        indices.append(ind)
        values.append(val)
        row_offset += a.shape[0]

    indices = torch.cat(indices).transpose(0, 1)
    values = torch.cat(values)

    res = torch.sparse_coo_tensor(indices, values,
        size=(total_rows,) + tuple(matrices[0].shape[1:]))
    return res
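
# Illustrative usage of _cat() (shapes and values are made up):
#
#     a = torch.sparse_coo_tensor([[0], [1]], [1.], size=(2, 3))
#     b = torch.sparse_coo_tensor([[1], [2]], [2.], size=(2, 3))
#     ab = _cat([a, b])  # sparse COO, shape (4, 3); b's row 1 lands on row 3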


class FastGraphConv(torch.nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
        adjacency_matrix: List[torch.Tensor],
        **kwargs):

        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        # One Glorot-initialized block per relation, stacked side by side
        # so that a single matrix product covers all relations at once.
        self.weight = torch.nn.Parameter(torch.cat([
            init_glorot(in_channels, out_channels)
            for _ in adjacency_matrix
        ], dim=1))
        # All adjacency matrices stacked along dimension 0.
        self.adjacency_matrix = _cat(adjacency_matrix)

    def forward(self, x):
        x = torch.sparse.mm(x, self.weight) \
            if x.is_sparse \
            else torch.mm(x, self.weight)
        x = torch.sparse.mm(self.adjacency_matrix, x) \
            if self.adjacency_matrix.is_sparse \
            else torch.mm(self.adjacency_matrix, x)
        return x
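
# Illustrative usage of FastGraphConv (dimensions are made up):
#
#     adj = [ torch.eye(10).to_sparse() for _ in range(3) ]  # 3 relations
#     conv = FastGraphConv(16, 32, adj)
#     out = conv(torch.rand(10, 16))  # shape (3 * 10, 3 * 32)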


class FastConvLayer(torch.nn.Module):
    adjacency_matrix: List[torch.Tensor]
    adjacency_matrix_backward: List[torch.Tensor]
    weight: List[torch.Tensor]
    weight_backward: List[torch.Tensor]

    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)
        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.adjacency_matrix = None
        self.adjacency_matrix_backward = None
        self.weight = None
        self.weight_backward = None
        self.build()

    def build(self):
        self.adjacency_matrix = []
        self.adjacency_matrix_backward = []
        self.weight = []
        self.weight_backward = []

        for fam in self.data.relation_families:
            adj_mat = [ rel.adjacency_matrix \
                for rel in fam.relation_types \
                if rel.adjacency_matrix is not None ]
            adj_mat_back = [ rel.adjacency_matrix_backward \
                for rel in fam.relation_types \
                if rel.adjacency_matrix_backward is not None ]

            weight = [ init_glorot(self.input_dim[fam.node_type_column],
                self.output_dim[fam.node_type_row]) \
                for _ in range(len(adj_mat)) ]
            # The backward direction maps row-node inputs to column-node
            # outputs, hence the transposed choice of dimensions.
            weight_back = [ init_glorot(self.input_dim[fam.node_type_row],
                self.output_dim[fam.node_type_column]) \
                for _ in range(len(adj_mat_back)) ]

            # Stack the per-relation matrices into single block tensors:
            # adjacency matrices along dimension 0 (_cat() also handles
            # sparse tensors), weights side by side along dimension 1.
            adj_mat = _cat(adj_mat) \
                if len(adj_mat) > 0 \
                else None
            adj_mat_back = _cat(adj_mat_back) \
                if len(adj_mat_back) > 0 \
                else None
            weight = torch.cat(weight, dim=1) \
                if len(weight) > 0 \
                else None
            weight_back = torch.cat(weight_back, dim=1) \
                if len(weight_back) > 0 \
                else None

            self.adjacency_matrix.append(adj_mat)
            self.adjacency_matrix_backward.append(adj_mat_back)
            self.weight.append(weight)
            self.weight_backward.append(weight_back)
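
    # After build(), entry i of each list describes relation family i:
    # self.weight[i] has shape (input_dim, num_rels * output_dim) and
    # self.adjacency_matrix[i] has shape (num_rels * num_rows, num_columns),
    # where num_rels counts the relation types of the family that carry
    # the corresponding adjacency matrix.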

    def forward(self, prev_layer_repr):
        output = []
        for i, fam in enumerate(self.data.relation_families):
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]

            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]

            x_row = None
            if adj_mat is not None:
                x = dropout(repr_column, keep_prob=self.keep_prob)
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                x = torch.sparse.mm(adj_mat, x) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, x)
                x = self.rel_activation(x)
                x_row = x.view(len(fam.relation_types), len(repr_row), -1)

            x_column = None
            if adj_mat_back is not None:
                # Backward direction: the mirror image of the forward
                # direction, mapping row-node representations onto the
                # column nodes.
                x = dropout(repr_row, keep_prob=self.keep_prob)
                x = torch.sparse.mm(x, self.weight_backward[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight_backward[i])
                x = torch.sparse.mm(adj_mat_back, x) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, x)
                x = self.rel_activation(x)
                x_column = x.view(len(fam.relation_types), len(repr_column), -1)

            output.append((x_row, x_column))
        return output
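
    # Shape walk-through for one family (illustrative numbers): with 3
    # relation types, 5 row nodes, 7 column nodes, input_dim 16 and
    # output_dim 32, repr_column is (7, 16), self.weight[i] is (16, 96)
    # and adj_mat is (15, 7); x is then (7, 96) after the weight product,
    # (15, 96) after the adjacency product, and x_row is (3, 5, 96).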

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            raise ValueError('output_dim must be a list')

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')