IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

171 lines
5.8KB

import types
from dataclasses import dataclass
from typing import Callable, \
    Dict, \
    List, \
    Optional, \
    Tuple

import torch

from .data import Data, \
    EdgeType
from .util import _sparse_coo_tensor
from .weights import init_glorot
  12. @dataclass
  13. class TrainingBatch(object):
  14. vertex_type_row: int
  15. vertex_type_column: int
  16. relation_type_index: int
  17. edges: torch.Tensor
  18. def _per_layer_required_rows(data: Data, batch: TrainingBatch,
  19. num_layers: int) -> List[List[EdgeType]]:
  20. Q = [
  21. ( batch.vertex_type_row, batch.edges[:, 0] ),
  22. ( batch.vertex_type_column, batch.edges[:, 1] )
  23. ]
  24. print('Q:', Q)
  25. res = []
  26. for _ in range(num_layers):
  27. R = []
  28. required_rows = [ [] for _ in range(len(data.vertex_types)) ]
  29. for vertex_type, vertices in Q:
  30. for et in data.edge_types.values():
  31. if et.vertex_type_row == vertex_type:
  32. required_rows[vertex_type].append(vertices)
  33. indices = et.total_connectivity.indices()
  34. mask = torch.zeros(et.total_connectivity.shape[0])
  35. mask[vertices] = 1
  36. mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
  37. R.append((et.vertex_type_column,
  38. indices[1, mask]))
  39. else:
  40. pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
  41. required_rows = [ torch.unique(torch.cat(x)) \
  42. if len(x) > 0 \
  43. else None \
  44. for x in required_rows ]
  45. res.append(required_rows)
  46. Q = R
  47. return res
  48. class Model(torch.nn.Module):
  49. def __init__(self, data: Data, layer_dimensions: List[int],
  50. keep_prob: float,
  51. conv_activation: Callable[[torch.Tensor], torch.Tensor],
  52. dec_activation: Callable[[torch.Tensor], torch.Tensor],
  53. **kwargs) -> None:
  54. super().__init__(**kwargs)
  55. if not isinstance(data, Data):
  56. raise TypeError('data must be an instance of Data')
  57. if not isinstance(conv_activation, types.FunctionType):
  58. raise TypeError('conv_activation must be a function')
  59. if not isinstance(dec_activation, types.FunctionType):
  60. raise TypeError('dec_activation must be a function')
  61. self.data = data
  62. self.layer_dimensions = list(layer_dimensions)
  63. self.keep_prob = float(keep_prob)
  64. self.conv_activation = conv_activation
  65. self.dec_activation = dec_activation
  66. self.conv_weights = None
  67. self.dec_weights = None
  68. self.build()
  69. def build(self) -> None:
  70. self.conv_weights = torch.nn.ParameterDict()
  71. for i in range(len(self.layer_dimensions) - 1):
  72. in_dimension = self.layer_dimensions[i]
  73. out_dimension = self.layer_dimensions[i + 1]
  74. for _, et in self.data.edge_types.items():
  75. weight = init_glorot(in_dimension, out_dimension)
  76. self.conv_weights[et.vertex_type_row, et.vertex_type_column, i] = \
  77. torch.nn.Parameter(weight)
  78. self.dec_weights = torch.nn.ParameterDict()
  79. for _, et in self.data.edge_types.items():
  80. global_interaction, local_variation = \
  81. et.decoder_factory(self.layer_dimensions[-1],
  82. len(et.adjacency_matrices))
  83. self.dec_weights[et.vertex_type_row, et.vertex_type_column] = \
  84. torch.nn.ParameterList([
  85. torch.nn.Parameter(global_interaction),
  86. torch.nn.Parameter(local_variation)
  87. ])
  88. def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]:
  89. edges = []
  90. cur_edges = batch.edges
  91. for _ in range(len(self.layer_dimensions) - 1):
  92. edges.append(cur_edges)
  93. key = (batch.vertex_type_row, batch.vertex_type_column)
  94. tot_conn = self.data.relation_types[key].total_connectivity
  95. cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])
  96. def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
  97. batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
  98. col = batch.vertex_type_column
  99. rows = batch.edges[:, 0]
  100. columns = batch.edges[:, 1].sum(dim=0).flatten()
  101. columns = torch.nonzero(columns)
  102. for i in range(len(self.layer_dimensions) - 1):
  103. pass # columns =
  104. # TODO: finish
  105. return None
  106. def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:
  107. col = batch.vertex_type_column
  108. batch.edges[:, 1]
  109. res = {}
  110. for _, et in self.data.edge_types.items():
  111. sum_nonzero = _nonzero_sum(et.adjacency_matrices)
  112. res[et.vertex_type_row, et.vertex_type_column] = \
  113. [ self.temporary_adjacency_matrix(adj_mat, batch,
  114. et.total_connectivity) \
  115. for adj_mat in et.adjacency_matrices ]
  116. return res
  117. def forward(self, initial_repr: List[torch.Tensor],
  118. batch: TrainingBatch) -> torch.Tensor:
  119. if not isinstance(initial_repr, list):
  120. raise TypeError('initial_repr must be a list')
  121. if len(initial_repr) != len(self.data.vertex_types):
  122. raise ValueError('initial_repr must contain representations for all vertex types')
  123. if not isinstance(batch, TrainingBatch):
  124. raise TypeError('batch must be an instance of TrainingBatch')
  125. adj_matrices = self.temporary_adjacency_matrices(batch)
  126. row_vertices = initial_repr[batch.vertex_type_row]
  127. column_vertices = initial_repr[batch.vertex_type_column]