If you would like to get an account, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

130 lines
4.5KB

import types
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

import torch

from .data import Data, EdgeType
from .util import _sparse_coo_tensor
from .weights import init_glorot
  11. @dataclass
  12. class TrainingBatch(object):
  13. vertex_type_row: int
  14. vertex_type_column: int
  15. relation_type_index: int
  16. edges: torch.Tensor
  17. class Model(torch.nn.Module):
  18. def __init__(self, data: Data, layer_dimensions: List[int],
  19. keep_prob: float,
  20. conv_activation: Callable[[torch.Tensor], torch.Tensor],
  21. dec_activation: Callable[[torch.Tensor], torch.Tensor],
  22. **kwargs) -> None:
  23. super().__init__(**kwargs)
  24. if not isinstance(data, Data):
  25. raise TypeError('data must be an instance of Data')
  26. if not isinstance(conv_activation, types.FunctionType):
  27. raise TypeError('conv_activation must be a function')
  28. if not isinstance(dec_activation, types.FunctionType):
  29. raise TypeError('dec_activation must be a function')
  30. self.data = data
  31. self.layer_dimensions = list(layer_dimensions)
  32. self.keep_prob = float(keep_prob)
  33. self.conv_activation = conv_activation
  34. self.dec_activation = dec_activation
  35. self.conv_weights = None
  36. self.dec_weights = None
  37. self.build()
  38. def build(self) -> None:
  39. self.conv_weights = torch.nn.ParameterDict()
  40. for i in range(len(self.layer_dimensions) - 1):
  41. in_dimension = self.layer_dimensions[i]
  42. out_dimension = self.layer_dimensions[i + 1]
  43. for _, et in self.data.edge_types.items():
  44. weight = init_glorot(in_dimension, out_dimension)
  45. self.conv_weights[et.vertex_type_row, et.vertex_type_column, i] = \
  46. torch.nn.Parameter(weight)
  47. self.dec_weights = torch.nn.ParameterDict()
  48. for _, et in self.data.edge_types.items():
  49. global_interaction, local_variation = \
  50. et.decoder_factory(self.layer_dimensions[-1],
  51. len(et.adjacency_matrices))
  52. self.dec_weights[et.vertex_type_row, et.vertex_type_column] = \
  53. torch.nn.ParameterList([
  54. torch.nn.Parameter(global_interaction),
  55. torch.nn.Parameter(local_variation)
  56. ])
  57. def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor,
  58. rows: torch.Tensor) -> torch.Tensor:
  59. adj_mat = adjacency_matrix.coalesce()
  60. adj_mat = torch.index_select(adj_mat, 0, rows)
  61. adj_mat = adj_mat.coalesce()
  62. indices = adj_mat.indices()
  63. indices[0] = rows
  64. adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape)
  65. def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
  66. batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
  67. col = batch.vertex_type_column
  68. rows = batch.edges[:, 0]
  69. columns = batch.edges[:, 1].sum(dim=0).flatten()
  70. columns = torch.nonzero(columns)
  71. for i in range(len(self.layer_dimensions) - 1):
  72. columns =
  73. def temporary_adjacency_matrices(self, batch: TrainingBatch) ->
  74. Dict[Tuple[int, int], List[List[torch.Tensor]]]:
  75. col = batch.vertex_type_column
  76. batch.edges[:, 1]
  77. res = {}
  78. for _, et in self.data.edge_types.items():
  79. sum_nonzero = _nonzero_sum(et.adjacency_matrices)
  80. res[et.vertex_type_row, et.vertex_type_column] = \
  81. [ self.temporary_adjacency_matrix(adj_mat, batch,
  82. et.total_connectivity) \
  83. for adj_mat in et.adjacency_matrices ]
  84. return res
  85. def forward(self, initial_repr: List[torch.Tensor],
  86. batch: TrainingBatch) -> torch.Tensor:
  87. if not isinstance(initial_repr, list):
  88. raise TypeError('initial_repr must be a list')
  89. if len(initial_repr) != len(self.data.vertex_types):
  90. raise ValueError('initial_repr must contain representations for all vertex types')
  91. if not isinstance(batch, TrainingBatch):
  92. raise TypeError('batch must be an instance of TrainingBatch')
  93. adj_matrices = self.temporary_adjacency_matrices(batch)
  94. row_vertices = initial_repr[batch.vertex_type_row]
  95. column_vertices = initial_repr[batch.vertex_type_column]