IF YOU WOULD LIKE TO GET AN ACCOUNT, please email s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

131 lines
4.5KB

  1. from .data import Data, \
  2. EdgeType
  3. import torch
  4. from dataclasses import dataclass
  5. from .weights import init_glorot
  6. import types
  7. from typing import List, \
  8. Dict, \
  9. Callable, \
  10. Tuple
  11. from .util import _sparse_coo_tensor
  @dataclass
  class TrainingBatch:
      """A minibatch of edges between one pair of vertex types.

      ``Model`` uses ``vertex_type_row`` / ``vertex_type_column`` to index
      its per-vertex-type inputs, and reads ``edges[:, 0]`` and
      ``edges[:, 1]`` as the row and column endpoints. ``edges`` is
      therefore expected to be a 2-D tensor with at least two columns.
      """
      # Index of the vertex type on the row side of the relation.
      vertex_type_row: int
      # Index of the vertex type on the column side of the relation.
      vertex_type_column: int
      # Presumably selects one relation between the two vertex types.
      # It is not read in the visible code -- confirm.
      relation_type_index: int
      # Edge endpoints: column 0 holds row vertices, column 1 holds
      # column vertices (as read by Model).
      edges: torch.Tensor
  18. class Model(torch.nn.Module):
  def __init__(self, data: Data, layer_dimensions: List[int],
               keep_prob: float,
               conv_activation: Callable[[torch.Tensor], torch.Tensor],
               dec_activation: Callable[[torch.Tensor], torch.Tensor],
               **kwargs) -> None:
      """Validate and store the configuration, then allocate all weights.

      Args:
          data: graph description. Its ``edge_types`` drive weight
              allocation in ``build``.
          layer_dimensions: representation sizes, copied into a list. Each
              consecutive pair gives the in/out dimension of one
              convolution layer, and the last entry is passed to each
              edge type's decoder factory.
          keep_prob: stored as ``float``. Presumably a dropout keep
              probability; it is not read in the visible code -- confirm.
          conv_activation: any callable mapping a tensor to a tensor.
              It is only stored here; usage is not visible.
          dec_activation: any callable mapping a tensor to a tensor.
              It is only stored here; usage is not visible.
          **kwargs: forwarded unchanged to ``torch.nn.Module.__init__``.

      Raises:
          TypeError: if ``data`` is not a ``Data`` instance, or if either
              activation is not callable.
      """
      super().__init__(**kwargs)
      if not isinstance(data, Data):
          raise TypeError('data must be an instance of Data')
      # Use callable() rather than an isinstance check on
      # types.FunctionType. The latter rejects builtins such as
      # torch.sigmoid / torch.tanh, functools.partial objects and
      # nn.Module instances, all of which satisfy the Callable annotation.
      if not callable(conv_activation):
          raise TypeError('conv_activation must be a function')
      if not callable(dec_activation):
          raise TypeError('dec_activation must be a function')
      self.data = data
      self.layer_dimensions = list(layer_dimensions)
      self.keep_prob = float(keep_prob)
      self.conv_activation = conv_activation
      self.dec_activation = dec_activation
      # Placeholders; build() replaces them with parameter containers.
      self.conv_weights = None
      self.dec_weights = None
      self.build()
  def build(self) -> None:
      """Allocate the convolution and decoder weights for every edge type.

      ``torch.nn.ParameterDict`` and ``torch.nn.ModuleDict`` only accept
      string keys. Tuple keys raise ``TypeError``, so entries are stored
      under keys built by ``_weight_key``:

      * ``conv_weights['<row>_<col>_<layer>']``: a ``Parameter``
        initialised with ``init_glorot(in_dim, out_dim)`` for that layer.
      * ``dec_weights['<row>_<col>']``: a ``ParameterList`` holding the
        ``(global_interaction, local_variation)`` pair returned by the
        edge type's ``decoder_factory``.

      Here ``<row>`` and ``<col>`` are the edge type's
      ``vertex_type_row`` and ``vertex_type_column``.
      """
      self.conv_weights = torch.nn.ParameterDict()
      dims = self.layer_dimensions
      # Layers form the outer loop (as before), so parameter registration
      # order and the order of init_glorot calls are unchanged.
      for i, (in_dimension, out_dimension) in enumerate(zip(dims, dims[1:])):
          for et in self.data.edge_types.values():
              key = self._weight_key(et.vertex_type_row,
                                     et.vertex_type_column, i)
              self.conv_weights[key] = torch.nn.Parameter(
                  init_glorot(in_dimension, out_dimension))

      # Each decoder entry is a ParameterList, which is a Module rather
      # than a single Parameter, so it belongs in a ModuleDict.
      self.dec_weights = torch.nn.ModuleDict()
      for et in self.data.edge_types.values():
          global_interaction, local_variation = et.decoder_factory(
              self.layer_dimensions[-1], len(et.adjacency_matrices))
          key = self._weight_key(et.vertex_type_row, et.vertex_type_column)
          self.dec_weights[key] = torch.nn.ParameterList([
              torch.nn.Parameter(global_interaction),
              torch.nn.Parameter(local_variation),
          ])

  @staticmethod
  def _weight_key(*indices: int) -> str:
      """Join integer indices with '_' into a string container key."""
      return '_'.join(str(index) for index in indices)
  def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]:
      """Collect the edge set used at each convolution layer.

      The first entry is ``batch.edges`` itself. Each subsequent entry is
      produced by ``_edges_for_rows`` from the ``total_connectivity`` of
      the batch's (row, column) vertex-type pair and the column endpoints
      (``[:, 1]``) of the previous entry.

      Returns:
          A list with ``len(self.layer_dimensions) - 1`` edge tensors.
          It is empty when there are no convolution layers. The original
          version built this list but never returned it.
      """
      edges = []
      cur_edges = batch.edges
      n_layers = len(self.layer_dimensions) - 1
      if n_layers <= 0:
          return edges
      # Loop-invariant: the connectivity depends only on the batch's
      # vertex types, not on the layer.
      # NOTE(review): the rest of the class reads self.data.edge_types;
      # confirm that relation_types is the intended attribute here.
      key = (batch.vertex_type_row, batch.vertex_type_column)
      tot_conn = self.data.relation_types[key].total_connectivity
      for _ in range(n_layers):
          edges.append(cur_edges)
          # NOTE(review): the expansion computed on the last iteration is
          # discarded -- confirm whether it should also be appended.
          cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])
      return edges
  def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
                                 batch: TrainingBatch,
                                 total_connectivity: torch.Tensor) -> torch.Tensor:
      """Unfinished stub for a batch-restricted adjacency matrix.

      The current body only evaluates some intermediate expressions from
      ``batch`` and then returns ``None``. ``adjacency_matrix`` and
      ``total_connectivity`` are not used yet.

      NOTE(review): ``batch.edges[:, 1]`` is 1-D when ``edges`` is 2-D, so
      ``.sum(dim=0)`` collapses it to a single value. ``touched`` therefore
      holds at most one index -- confirm the intended computation when
      finishing this.
      """
      column_type = batch.vertex_type_column  # not used yet
      edge_list = batch.edges
      row_ids = edge_list[:, 0]  # not used yet
      touched = torch.nonzero(edge_list[:, 1].sum(dim=0).flatten())
      n_layers = len(self.layer_dimensions) - 1
      for _layer in range(n_layers):
          pass  # TODO: finish (per-layer update of `touched` not written yet)
      return None
  def temporary_adjacency_matrices(
          self, batch: TrainingBatch
  ) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:
      """Build per-batch adjacency matrices for every edge type.

      Returns:
          A dict keyed by ``(vertex_type_row, vertex_type_column)``. Each
          value is a list with one ``temporary_adjacency_matrix`` result
          per entry of that edge type's ``adjacency_matrices``.
      """
      # Removed dead code from the earlier version: an unused
      # ``vertex_type_column`` local, a no-op ``batch.edges[:, 1]``
      # expression, and a ``_nonzero_sum(et.adjacency_matrices)`` call whose
      # result was computed for every edge type and then discarded.
      res = {}
      for et in self.data.edge_types.values():
          res[et.vertex_type_row, et.vertex_type_column] = [
              self.temporary_adjacency_matrix(adj_mat, batch,
                                              et.total_connectivity)
              for adj_mat in et.adjacency_matrices
          ]
      return res
  87. def forward(self, initial_repr: List[torch.Tensor],
  88. batch: TrainingBatch) -> torch.Tensor:
  89. if not isinstance(initial_repr, list):
  90. raise TypeError('initial_repr must be a list')
  91. if len(initial_repr) != len(self.data.vertex_types):
  92. raise ValueError('initial_repr must contain representations for all vertex types')
  93. if not isinstance(batch, TrainingBatch):
  94. raise TypeError('batch must be an instance of TrainingBatch')
  95. adj_matrices = self.temporary_adjacency_matrices(batch)
  96. row_vertices = initial_repr[batch.vertex_type_row]
  97. column_vertices = initial_repr[batch.vertex_type_column]