IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

111 lines
4.2KB

  1. from typing import List, \
  2. Union, \
  3. Callable
  4. from .data import Data
  5. from .trainprep import PreparedData
  6. import torch
  7. from .weights import init_glorot
  8. class FastConvLayer(torch.nn.Module):
  9. adjacency_matrix: List[torch.Tensor]
  10. adjacency_matrix_backward: List[torch.Tensor]
  11. weight: List[torch.Tensor]
  12. weight_backward: List[torch.Tensor]
  13. def __init__(self,
  14. input_dim: List[int],
  15. output_dim: List[int],
  16. data: Union[Data, PreparedData],
  17. keep_prob: float = 1.,
  18. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x
  19. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  20. **kwargs):
  21. super().__init__(**kwargs)
  22. self._check_params(input_dim, output_dim, data, keep_prob,
  23. rel_activation, layer_activation)
  24. self.input_dim = input_dim
  25. self.output_dim = output_dim
  26. self.data = data
  27. self.keep_prob = keep_prob
  28. self.rel_activation = rel_activation
  29. self.layer_activation = layer_activation
  30. self.adjacency_matrix = None
  31. self.adjacency_matrix_backward = None
  32. self.weight = None
  33. self.weight_backward = None
  34. self.build()
  35. def build(self):
  36. self.adjacency_matrix = []
  37. self.adjacency_matrix_backward = []
  38. self.weight = []
  39. self.weight_backward = []
  40. for fam in self.data.relation_families:
  41. adj_mat = [ rel.adjacency_matrix \
  42. for rel in fam.relation_types \
  43. if rel.adjacency_matrix is not None ]
  44. adj_mat_back = [ rel.adjacency_matrix_backward \
  45. for rel in fam.relation_types \
  46. if rel.adjacency_matrix_backward is not None ]
  47. weight = [ init_glorot(self.input_dim[fam.node_type_column],
  48. self.output_dim[fam.node_type_row]) \
  49. for _ in range(len(adj_mat)) ]
  50. weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
  51. self.output_dim[fam.node_type_row]) \
  52. for _ in range(len(adj_mat_back)) ]
  53. adj_mat = torch.cat(adj_mat) \
  54. if len(adj_mat) > 0 \
  55. else None
  56. adj_mat_back = torch.cat(adj_mat_back) \
  57. if len(adj_mat_back) > 0 \
  58. else None
  59. self.adjacency_matrix.append(adj_mat)
  60. self.adjacency_matrix_backward.append(adj_mat_back)
  61. self.weight.append(weight)
  62. self.weight_back.append(weight_back)
    def forward(self, prev_layer_repr):
        """Per-family convolution pass over the previous layer's representations.

        NOTE(review): this method looks unfinished — it accumulates nothing and
        has no ``return`` statement (so it implicitly returns ``None``); the
        backward branch discards the forward branch's result. Confirm against
        the intended design before relying on it.

        :param prev_layer_repr: presumably a list of per-node-type
            representation tensors, indexed by node type — TODO confirm
            against callers.
        """
        for i, fam in enumerate(self.data.relation_families):
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]
            # Concatenated adjacency matrices built in build(); None when the
            # family contributed no matrices.
            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]
            if adj_mat is not None:
                # NOTE(review): `dropout` is not among this file's visible
                # imports — verify it is imported elsewhere in the module.
                x = dropout(repr_column, keep_prob=self.keep_prob)
                # NOTE(review): build() stores self.weight[i] as a *list* of
                # tensors, but it is used here as a single matrix — likely a
                # missing torch.cat over the per-relation weights.
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                # NOTE(review): this assignment overwrites the projected x
                # computed just above, discarding it — almost certainly a bug
                # in this work-in-progress code.
                x = torch.sparse.mm(adj_mat, repr_row) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, repr_row)
                x = self.rel_activation(x)
                # Reshape to (num relation types, num row nodes, features).
                x = x.view(len(fam.relation_types), len(repr_row), -1)
            if adj_mat_back is not None:
                # NOTE(review): overwrites x again; the forward-branch result
                # is lost and this branch's result is itself never used.
                x = torch.sparse.mm(adj_mat_back, repr_row) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, repr_row)
  83. @staticmethod
  84. def _check_params(input_dim, output_dim, data, keep_prob,
  85. rel_activation, layer_activation):
  86. if not isinstance(input_dim, list):
  87. raise ValueError('input_dim must be a list')
  88. if not output_dim:
  89. raise ValueError('output_dim must be specified')
  90. if not isinstance(output_dim, list):
  91. output_dim = [output_dim] * len(data.node_types)
  92. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  93. raise ValueError('data must be of type Data or PreparedData')