IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

95 рядків
3.2KB

  1. import torch
  2. from .convolve import DropoutGraphConvActivation
  3. from .data import Data
  4. from .trainprep import PreparedData
  5. from typing import List, \
  6. Union, \
  7. Callable
  8. from collections import defaultdict
  9. from dataclasses import dataclass
  10. @dataclass
  11. class Convolutions(object):
  12. node_type_column: int
  13. convolutions: List[DropoutGraphConvActivation]
  14. class DecagonLayer(torch.nn.Module):
  15. def __init__(self,
  16. input_dim: List[int],
  17. output_dim: List[int],
  18. data: Union[Data, PreparedData],
  19. keep_prob: float = 1.,
  20. rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
  21. layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
  22. **kwargs):
  23. super().__init__(**kwargs)
  24. if not isinstance(input_dim, list):
  25. raise ValueError('input_dim must be a list')
  26. if not isinstance(output_dim, list):
  27. raise ValueError('output_dim must be a list')
  28. if not isinstance(data, Data) and not isinstance(data, PreparedData):
  29. raise ValueError('data must be of type Data or PreparedData')
  30. self.input_dim = input_dim
  31. self.output_dim = output_dim
  32. self.data = data
  33. self.keep_prob = float(keep_prob)
  34. self.rel_activation = rel_activation
  35. self.layer_activation = layer_activation
  36. self.is_sparse = False
  37. self.next_layer_repr = None
  38. self.build()
  39. def build(self):
  40. n = len(self.data.node_types)
  41. rel_types = self.data.relation_types
  42. self.next_layer_repr = [ [] for _ in range(n) ]
  43. for node_type_row in range(n):
  44. if node_type_row not in rel_types:
  45. continue
  46. for node_type_column in range(n):
  47. if node_type_column not in rel_types[node_type_row]:
  48. continue
  49. rels = rel_types[node_type_row][node_type_column]
  50. if len(rels) == 0:
  51. continue
  52. convolutions = []
  53. for r in rels:
  54. conv = DropoutGraphConvActivation(self.input_dim[node_type_column],
  55. self.output_dim[node_type_row], r.adjacency_matrix,
  56. self.keep_prob, self.rel_activation)
  57. convolutions.append(conv)
  58. self.next_layer_repr[node_type_row].append(
  59. Convolutions(node_type_column, convolutions))
  60. def __call__(self, prev_layer_repr):
  61. next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
  62. n = len(self.data.node_types)
  63. for node_type_row in range(n):
  64. for convolutions in self.next_layer_repr[node_type_row]:
  65. repr_ = [ conv(prev_layer_repr[convolutions.node_type_column]) \
  66. for conv in convolutions.convolutions ]
  67. repr_ = sum(repr_)
  68. repr_ = torch.nn.functional.normalize(repr_, p=2, dim=1)
  69. next_layer_repr[i].append(repr_)
  70. next_layer_repr[i] = sum(next_layer_repr[i])
  71. next_layer_repr[i] = self.layer_activation(next_layer_repr[i])
  72. return next_layer_repr