IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

172 lines
6.0KB

  1. from typing import List, \
  2. Union, \
  3. Callable
  4. from .data import Data
  5. from .trainprep import PreparedData
  6. import torch
  7. from .weights import init_glorot
  8. def _cat(matrices: List[torch.Tensor]):
  9. if len(matrices) == 0:
  10. raise ValueError('Empty list passed to _cat()')
  11. n = sum(a.is_sparse for a in matrices)
  12. if n != 0 and n != len(matrices):
  13. raise ValueError('All matrices must have the same layout (dense or sparse)')
  14. if not all(a.shape[1:] == matrices[0].shape[1:]):
  15. raise ValueError('All matrices must have the same dimensions apart from dimension 0')
  16. if not matrices[0].is_sparse:
  17. return torch.cat(matrices)
  18. total_rows = sum(a.shape[0] for a in matrices)
  19. indices = []
  20. values = []
  21. row_offset = 0
  22. for a in matrices:
  23. ind = a._indices().clone()
  24. val = a._values()
  25. ind[0] += row_offset
  26. ind = ind.transpose(0, 1)
  27. indices.append(ind)
  28. values.append(val)
  29. row_offset += a.shape[0]
  30. indices = torch.cat(indices).transpose(0, 1)
  31. values = torch.cat(values)
  32. res = _sparse_coo_tensor(indices, values)
  33. return res
  34. class FastGraphConv(torch.nn.Module):
  35. def __init__(self,
  36. in_channels: int,
  37. out_channels: int,
  38. adjacency_matrix: List[torch.Tensor],
  39. **kwargs):
  40. self.in_channels = in_channels
  41. self.out_channels = out_channels
  42. self.weight = torch.cat([
  43. init_glorot(in_channels, out_channels) \
  44. for _ in adjacency_matrix
  45. ], dim=1)
  46. self.adjacency_matrix = _cat(adjacency_matrix)
  47. def forward(self, x):
  48. x = torch.sparse.mm(x, self.weight) \
  49. if x.is_sparse \
  50. else torch.mm(x, self.weight)
  51. x = torch.sparse.mm(self.adjacency_matrix, x) \
  52. if self.adjacency_matrix.is_sparse \
  53. else torch.mm(self.adjacency_matrix, x)
  54. return x
  55. class FastConvLayer(torch.nn.Module):
  56. adjacency_matrix: List[torch.Tensor]
  57. adjacency_matrix_backward: List[torch.Tensor]
  58. weight: List[torch.Tensor]
  59. weight_backward: List[torch.Tensor]
  60. def __init__(self,
  61. input_dim: List[int],
  62. output_dim: List[int],
  63. data: Union[Data, PreparedData],
  64. keep_prob: float = 1.,
  65. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x
  66. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  67. **kwargs):
  68. super().__init__(**kwargs)
  69. self._check_params(input_dim, output_dim, data, keep_prob,
  70. rel_activation, layer_activation)
  71. self.input_dim = input_dim
  72. self.output_dim = output_dim
  73. self.data = data
  74. self.keep_prob = keep_prob
  75. self.rel_activation = rel_activation
  76. self.layer_activation = layer_activation
  77. self.adjacency_matrix = None
  78. self.adjacency_matrix_backward = None
  79. self.weight = None
  80. self.weight_backward = None
  81. self.build()
  82. def build(self):
  83. self.adjacency_matrix = []
  84. self.adjacency_matrix_backward = []
  85. self.weight = []
  86. self.weight_backward = []
  87. for fam in self.data.relation_families:
  88. adj_mat = [ rel.adjacency_matrix \
  89. for rel in fam.relation_types \
  90. if rel.adjacency_matrix is not None ]
  91. adj_mat_back = [ rel.adjacency_matrix_backward \
  92. for rel in fam.relation_types \
  93. if rel.adjacency_matrix_backward is not None ]
  94. weight = [ init_glorot(self.input_dim[fam.node_type_column],
  95. self.output_dim[fam.node_type_row]) \
  96. for _ in range(len(adj_mat)) ]
  97. weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
  98. self.output_dim[fam.node_type_row]) \
  99. for _ in range(len(adj_mat_back)) ]
  100. adj_mat = torch.cat(adj_mat) \
  101. if len(adj_mat) > 0 \
  102. else None
  103. adj_mat_back = torch.cat(adj_mat_back) \
  104. if len(adj_mat_back) > 0 \
  105. else None
  106. self.adjacency_matrix.append(adj_mat)
  107. self.adjacency_matrix_backward.append(adj_mat_back)
  108. self.weight.append(weight)
  109. self.weight_back.append(weight_back)
  110. def forward(self, prev_layer_repr):
  111. for i, fam in enumerate(self.data.relation_families):
  112. repr_row = prev_layer_repr[fam.node_type_row]
  113. repr_column = prev_layer_repr[fam.node_type_column]
  114. adj_mat = self.adjacency_matrix[i]
  115. adj_mat_back = self.adjacency_matrix_backward[i]
  116. if adj_mat is not None:
  117. x = dropout(repr_column, keep_prob=self.keep_prob)
  118. x = torch.sparse.mm(x, self.weight[i]) \
  119. if x.is_sparse \
  120. else torch.mm(x, self.weight[i])
  121. x = torch.sparse.mm(adj_mat, repr_row) \
  122. if adj_mat.is_sparse \
  123. else torch.mm(adj_mat, repr_row)
  124. x = self.rel_activation(x)
  125. x = x.view(len(fam.relation_types), len(repr_row), -1)
  126. if adj_mat_back is not None:
  127. x = torch.sparse.mm(adj_mat_back, repr_row) \
  128. if adj_mat_back.is_sparse \
  129. else torch.mm(adj_mat_back, repr_row)
  130. @staticmethod
  131. def _check_params(input_dim, output_dim, data, keep_prob,
  132. rel_activation, layer_activation):
  133. if not isinstance(input_dim, list):
  134. raise ValueError('input_dim must be a list')
  135. if not output_dim:
  136. raise ValueError('output_dim must be specified')
  137. if not isinstance(output_dim, list):
  138. output_dim = [output_dim] * len(data.node_types)
  139. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  140. raise ValueError('data must be of type Data or PreparedData')