IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

105 lines
3.7KB

  1. #
  2. # This module implements a single layer of the Decagon
  3. # model. This is going to be already quite complex, as
  4. # we will be using all the graph convolutional building
  5. # blocks.
  6. #
  7. # h_{i}^{(k+1)} = ϕ(∑_{r} ∑_{j∈N_{r}^{i}} c_{r}^{ij} W_{r}^{(k)} h_{j}^{(k)} \
  8. #     + c_{r}^{i} h_{i}^{(k)})
  9. #
  10. # N_{r}^{i} - set of neighbors of node i under relation r
  11. # W_{r}^{(k)} - relation-type-specific weight matrix
  12. # h_{i}^{(k)} - hidden state of node i in layer k
  13. # h_{i}^{(k)} ∈ R^{d(k)}, where d(k) is the dimensionality
  14. # of the representation in the k-th layer
  15. # ϕ - activation function
  16. # c_{r}^{ij} - normalization constants
  17. # c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
  18. # c_{r}^{i} - normalization constants
  19. # c_{r}^{i} = 1/|N_{r}^{i}|
  20. #
  21. import torch
  22. from .convolve import SparseMultiDGCA
  23. class InputLayer(torch.nn.Module):
  24. def __init__(self, data, dimensionality=None, **kwargs):
  25. super().__init__(**kwargs)
  26. self.data = data
  27. dimensionality = dimensionality or \
  28. list(map(lambda a: a.count, data.node_types))
  29. if not isinstance(dimensionality, list):
  30. dimensionality = [dimensionality,] * len(self.data.node_types)
  31. self.dimensionality = dimensionality
  32. self.node_reps = None
  33. self.build()
  34. def build(self):
  35. self.node_reps = []
  36. for i, nt in enumerate(self.data.node_types):
  37. reps = torch.rand(nt.count, self.dimensionality[i])
  38. reps = torch.nn.Parameter(reps)
  39. self.register_parameter('node_reps[%d]' % i, reps)
  40. self.node_reps.append(reps)
  41. def forward(self):
  42. return self.node_reps
  43. def __repr__(self):
  44. s = ''
  45. s += 'GNN input layer with dimensionality: %d\n' % self.dimensionality
  46. s += ' # of node types: %d\n' % len(self.data.node_types)
  47. for nt in self.data.node_types:
  48. s += ' - %s (%d)\n' % (nt.name, nt.count)
  49. return s.strip()
  50. class DecagonLayer(torch.nn.Module):
  51. def __init__(self, data,
  52. input_dim, output_dim,
  53. keep_prob=1.,
  54. rel_activation=lambda x: x,
  55. layer_activation=torch.nn.functional.relu,
  56. **kwargs):
  57. super().__init__(**kwargs)
  58. self.data = data
  59. self.input_dim = input_dim
  60. self.output_dim = output_dim
  61. self.keep_prob = keep_prob
  62. self.rel_activation = rel_activation
  63. self.layer_activation = layer_activation
  64. self.convolutions = None
  65. self.build()
  66. def build(self):
  67. self.convolutions = {}
  68. for key in self.data.relation_types.keys():
  69. adjacency_matrices = \
  70. self.data.get_adjacency_matrices(*key)
  71. self.convolutions[key] = SparseMultiDGCA(self.input_dim,
  72. self.output_dim, adjacency_matrices,
  73. self.keep_prob, self.rel_activation)
  74. # for node_type_row, node_type_col in enumerate(self.data.node_
  75. # if rt.node_type_row == i or rt.node_type_col == i:
  76. def __call__(self, prev_layer_repr):
  77. new_layer_repr = []
  78. for i, nt in enumerate(self.data.node_types):
  79. new_repr = []
  80. for key in self.data.relation_types.keys():
  81. nt_row, nt_col = key
  82. if nt_row != i and nt_col != i:
  83. continue
  84. if nt_row == i:
  85. x = prev_layer_repr[nt_col]
  86. else:
  87. x = prev_layer_repr[nt_row]
  88. conv = self.convolutions[key]
  89. new_repr.append(conv(x))
  90. new_repr = sum(new_repr)
  91. new_layer_repr.append(new_repr)
  92. return new_layer_repr