IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

61 строка
1.9KB

#
# This module implements a single layer of the Decagon
# model. This is going to be quite complex already, as
# we will be using all of the graph convolutional building
# blocks.
#
# h_i^{(k+1)} = \phi\Big( \sum_r \big( \sum_{j \in N_r^i}
#                 c_r^{ij} W_r^{(k)} h_j^{(k)} + c_r^i h_i^{(k)} \big) \Big)
#
# (c_r^i depends on r, so the self-term c_r^i h_i^{(k)} sits inside
# the outer sum over relation types r.)
#
# N_r^i     - set of neighbors of node i under relation r
# W_r^{(k)} - relation-type-specific weight matrix in layer k
# h_i^{(k)} - hidden state of node i in layer k;
#             h_i^{(k)} \in \mathbb{R}^{d^{(k)}}, where d^{(k)} is the
#             dimensionality of the representation in the k-th layer
# \phi      - activation function
# c_r^{ij}  - normalization constant for neighbor j of node i under r:
#             c_r^{ij} = 1 / \sqrt{|N_r^i| \, |N_r^j|}
# c_r^i     - normalization constant for node i's own term under r:
#             c_r^i = 1 / |N_r^i|
#
  21. import torch
  22. class InputLayer(torch.nn.Module):
  23. def __init__(self, data, dimensionality=32, **kwargs):
  24. super().__init__(**kwargs)
  25. self.data = data
  26. self.dimensionality = dimensionality
  27. self.node_reps = None
  28. self.build()
  29. def build(self):
  30. self.node_reps = []
  31. for i, nt in enumerate(self.data.node_types):
  32. reps = torch.rand(nt.count, self.dimensionality)
  33. reps = torch.nn.Parameter(reps)
  34. self.register_parameter('node_reps[%d]' % i, reps)
  35. self.node_reps.append(reps)
  36. def forward(self):
  37. return self.node_reps
  38. def __repr__(self):
  39. s = ''
  40. s += 'GNN input layer with dimensionality: %d\n' % self.dimensionality
  41. s += ' # of node types: %d\n' % len(self.data.node_types)
  42. for nt in self.data.node_types:
  43. s += ' - %s (%d)\n' % (nt.name, nt.count)
  44. return s.strip()
  45. class DecagonLayer(torch.nn.Module):
  46. def __init__(self, data, **kwargs):
  47. super().__init__(**kwargs)
  48. self.data = data
  49. def __call__(self, previous_layer):
  50. pass