IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

convlayer.py 6.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import torch
  2. from .convolve import DropoutGraphConvActivation
  3. from .data import Data
  4. from .trainprep import PreparedData
  5. from typing import List, \
  6. Union, \
  7. Callable
  8. from collections import defaultdict
  9. from dataclasses import dataclass
  10. class Convolutions(torch.nn.Module):
  11. node_type_column: int
  12. convolutions: torch.nn.ModuleList # [DropoutGraphConvActivation]
  13. def __init__(self, node_type_column: int,
  14. convolutions: torch.nn.ModuleList, **kwargs):
  15. super().__init__(**kwargs)
  16. self.node_type_column = node_type_column
  17. self.convolutions = convolutions
  18. class DecagonLayer(torch.nn.Module):
  19. def __init__(self,
  20. input_dim: List[int],
  21. output_dim: List[int],
  22. data: Union[Data, PreparedData],
  23. keep_prob: float = 1.,
  24. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  25. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  26. **kwargs):
  27. super().__init__(**kwargs)
  28. if not isinstance(input_dim, list):
  29. raise ValueError('input_dim must be a list')
  30. if not output_dim:
  31. raise ValueError('output_dim must be specified')
  32. if not isinstance(output_dim, list):
  33. output_dim = [output_dim] * len(data.node_types)
  34. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  35. raise ValueError('data must be of type Data or PreparedData')
  36. self.input_dim = input_dim
  37. self.output_dim = output_dim
  38. self.data = data
  39. self.keep_prob = float(keep_prob)
  40. self.rel_activation = rel_activation
  41. self.layer_activation = layer_activation
  42. self.is_sparse = False
  43. self.next_layer_repr = None
  44. self.build()
  45. def build_fam_one_node_type(self, fam):
  46. convolutions = torch.nn.ModuleList()
  47. for r in fam.relation_types:
  48. conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
  49. self.output_dim[fam.node_type_row], r.adjacency_matrix,
  50. self.keep_prob, self.rel_activation)
  51. convolutions.append(conv)
  52. self.next_layer_repr[fam.node_type_row].append(
  53. Convolutions(fam.node_type_column, convolutions))
  54. # def build_fam_two_node_types_sym(self, fam) -> None:
  55. # convolutions_row = torch.nn.ModuleList()
  56. # convolutions_column = torch.nn.ModuleList()
  57. #
  58. # if self.input_dim[fam.node_type_column] != \
  59. # self.input_dim[fam.node_type_row]:
  60. # raise ValueError('input_dim for row and column must be equal for a symmetric family')
  61. #
  62. # if self.output_dim[fam.node_type_column] != \
  63. # self.output_dim[fam.node_type_row]:
  64. # raise ValueError('output_dim for row and column must be equal for a symmetric family')
  65. #
  66. # for r in fam.relation_types:
  67. # assert r.adjacency_matrix is not None and \
  68. # r.adjacency_matrix_backward is not None
  69. # conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
  70. # self.output_dim[fam.node_type_row], r.adjacency_matrix,
  71. # self.keep_prob, self.rel_activation)
  72. # convolutions_row.append(conv)
  73. # convolutions_column.append(conv.clone(r.adjacency_matrix_backward))
  74. #
  75. # self.next_layer_repr[fam.node_type_row].append(
  76. # Convolutions(fam.node_type_column, convolutions_row))
  77. #
  78. # self.next_layer_repr[fam.node_type_column].append(
  79. # Convolutions(fam.node_type_row, convolutions_column))
  80. def build_fam_two_node_types(self, fam) -> None:
  81. convolutions_row = torch.nn.ModuleList()
  82. convolutions_column = torch.nn.ModuleList()
  83. for r in fam.relation_types:
  84. if r.adjacency_matrix is not None:
  85. conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
  86. self.output_dim[fam.node_type_row], r.adjacency_matrix,
  87. self.keep_prob, self.rel_activation)
  88. convolutions_row.append(conv)
  89. if r.adjacency_matrix_backward is not None:
  90. conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row],
  91. self.output_dim[fam.node_type_column], r.adjacency_matrix_backward,
  92. self.keep_prob, self.rel_activation)
  93. convolutions_column.append(conv)
  94. self.next_layer_repr[fam.node_type_row].append(
  95. Convolutions(fam.node_type_column, convolutions_row))
  96. self.next_layer_repr[fam.node_type_column].append(
  97. Convolutions(fam.node_type_row, convolutions_column))
  98. # def build_fam_two_node_types(self, fam) -> None:
  99. # if fam.is_symmetric:
  100. # self.build_fam_two_node_types_sym(fam)
  101. # else:
  102. # self.build_fam_two_node_types_asym(fam)
  103. def build_family(self, fam) -> None:
  104. if fam.node_type_row == fam.node_type_column:
  105. self.build_fam_one_node_type(fam)
  106. else:
  107. self.build_fam_two_node_types(fam)
  108. def build(self):
  109. self.next_layer_repr = torch.nn.ModuleList([
  110. torch.nn.ModuleList() for _ in range(len(self.data.node_types)) ])
  111. for fam in self.data.relation_families:
  112. self.build_family(fam)
  113. def __call__(self, prev_layer_repr):
  114. next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
  115. n = len(self.data.node_types)
  116. for node_type_row in range(n):
  117. for convolutions in self.next_layer_repr[node_type_row]:
  118. repr_ = [ conv(prev_layer_repr[convolutions.node_type_column]) \
  119. for conv in convolutions.convolutions ]
  120. repr_ = sum(repr_)
  121. repr_ = torch.nn.functional.normalize(repr_, p=2, dim=1)
  122. next_layer_repr[node_type_row].append(repr_)
  123. if len(next_layer_repr[node_type_row]) == 0:
  124. next_layer_repr[node_type_row] = torch.zeros(self.output_dim[node_type_row])
  125. else:
  126. next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row])
  127. next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row])
  128. return next_layer_repr