If you would like to get an account, please send an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

127 行
4.9KB

  1. import torch
  2. from .convolve import DropoutGraphConvActivation
  3. from .data import Data
  4. from .trainprep import PreparedData
  5. from typing import List, \
  6. Union, \
  7. Callable
  8. from collections import defaultdict
  9. from dataclasses import dataclass
  10. import time
  11. class Convolutions(torch.nn.Module):
  12. node_type_column: int
  13. convolutions: torch.nn.ModuleList # [DropoutGraphConvActivation]
  14. def __init__(self, node_type_column: int,
  15. convolutions: torch.nn.ModuleList, **kwargs):
  16. super().__init__(**kwargs)
  17. self.node_type_column = node_type_column
  18. self.convolutions = convolutions
  19. class DecagonLayer(torch.nn.Module):
  20. def __init__(self,
  21. input_dim: List[int],
  22. output_dim: List[int],
  23. data: Union[Data, PreparedData],
  24. keep_prob: float = 1.,
  25. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  26. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  27. **kwargs):
  28. super().__init__(**kwargs)
  29. if not isinstance(input_dim, list):
  30. raise ValueError('input_dim must be a list')
  31. if not output_dim:
  32. raise ValueError('output_dim must be specified')
  33. if not isinstance(output_dim, list):
  34. output_dim = [output_dim] * len(data.node_types)
  35. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  36. raise ValueError('data must be of type Data or PreparedData')
  37. self.input_dim = input_dim
  38. self.output_dim = output_dim
  39. self.data = data
  40. self.keep_prob = float(keep_prob)
  41. self.rel_activation = rel_activation
  42. self.layer_activation = layer_activation
  43. self.is_sparse = False
  44. self.next_layer_repr = None
  45. self.build()
  46. def build_fam_one_node_type(self, fam):
  47. convolutions = torch.nn.ModuleList()
  48. for r in fam.relation_types:
  49. conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
  50. self.output_dim[fam.node_type_row], r.adjacency_matrix,
  51. self.keep_prob, self.rel_activation)
  52. convolutions.append(conv)
  53. self.next_layer_repr[fam.node_type_row].append(
  54. Convolutions(fam.node_type_column, convolutions))
  55. def build_fam_two_node_types(self, fam) -> None:
  56. convolutions_row = torch.nn.ModuleList()
  57. convolutions_column = torch.nn.ModuleList()
  58. for r in fam.relation_types:
  59. if r.adjacency_matrix is not None:
  60. conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
  61. self.output_dim[fam.node_type_row], r.adjacency_matrix,
  62. self.keep_prob, self.rel_activation)
  63. convolutions_row.append(conv)
  64. if r.adjacency_matrix_backward is not None:
  65. conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row],
  66. self.output_dim[fam.node_type_column], r.adjacency_matrix_backward,
  67. self.keep_prob, self.rel_activation)
  68. convolutions_column.append(conv)
  69. self.next_layer_repr[fam.node_type_row].append(
  70. Convolutions(fam.node_type_column, convolutions_row))
  71. self.next_layer_repr[fam.node_type_column].append(
  72. Convolutions(fam.node_type_row, convolutions_column))
  73. def build_family(self, fam) -> None:
  74. if fam.node_type_row == fam.node_type_column:
  75. self.build_fam_one_node_type(fam)
  76. else:
  77. self.build_fam_two_node_types(fam)
  78. def build(self):
  79. self.next_layer_repr = torch.nn.ModuleList([
  80. torch.nn.ModuleList() for _ in range(len(self.data.node_types)) ])
  81. for fam in self.data.relation_families:
  82. self.build_family(fam)
  83. def __call__(self, prev_layer_repr):
  84. t = time.time()
  85. next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
  86. n = len(self.data.node_types)
  87. for node_type_row in range(n):
  88. for convolutions in self.next_layer_repr[node_type_row]:
  89. repr_ = [ conv(prev_layer_repr[convolutions.node_type_column]) \
  90. for conv in convolutions.convolutions ]
  91. repr_ = sum(repr_)
  92. repr_ = torch.nn.functional.normalize(repr_, p=2, dim=1)
  93. next_layer_repr[node_type_row].append(repr_)
  94. if len(next_layer_repr[node_type_row]) == 0:
  95. next_layer_repr[node_type_row] = torch.zeros(self.output_dim[node_type_row])
  96. else:
  97. next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row])
  98. next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row])
  99. # print('DecagonLayer.forward() took', time.time() - t)
  100. return next_layer_repr