IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

229 lignes
8.4KB

from dataclasses import dataclass
import types
from typing import Callable, \
    Dict, \
    List, \
    Optional, \
    Tuple

import torch

from .data import Data, \
    EdgeType
from .dropout import dropout
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
from .util import _mm, \
    _sparse_coo_tensor, \
    _sparse_diag_cat
from .weights import init_glorot
  17. @dataclass
  18. class TrainingBatch(object):
  19. vertex_type_row: int
  20. vertex_type_column: int
  21. relation_type_index: int
  22. edges: torch.Tensor
  23. def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
  24. num_layers: int) -> List[List[EdgeType]]:
  25. Q = [
  26. ( batch.vertex_type_row, batch.edges[:, 0] ),
  27. ( batch.vertex_type_column, batch.edges[:, 1] )
  28. ]
  29. print('Q:', Q)
  30. res = []
  31. for _ in range(num_layers):
  32. R = []
  33. required_rows = [ [] for _ in range(len(data.vertex_types)) ]
  34. for vertex_type, vertices in Q:
  35. for et in data.edge_types.values():
  36. if et.vertex_type_row == vertex_type:
  37. required_rows[vertex_type].append(vertices)
  38. indices = et.total_connectivity.indices()
  39. mask = torch.zeros(et.total_connectivity.shape[0])
  40. mask[vertices] = 1
  41. mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
  42. R.append((et.vertex_type_column,
  43. indices[1, mask]))
  44. else:
  45. pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
  46. required_rows = [ torch.unique(torch.cat(x)) \
  47. if len(x) > 0 \
  48. else None \
  49. for x in required_rows ]
  50. res.append(required_rows)
  51. Q = R
  52. return res
  53. class Model(torch.nn.Module):
  54. def __init__(self, data: Data, layer_dimensions: List[int],
  55. keep_prob: float,
  56. conv_activation: Callable[[torch.Tensor], torch.Tensor],
  57. dec_activation: Callable[[torch.Tensor], torch.Tensor],
  58. **kwargs) -> None:
  59. super().__init__(**kwargs)
  60. if not isinstance(data, Data):
  61. raise TypeError('data must be an instance of Data')
  62. if not callable(conv_activation):
  63. raise TypeError('conv_activation must be callable')
  64. if not callable(dec_activation):
  65. raise TypeError('dec_activation must be callable')
  66. self.data = data
  67. self.layer_dimensions = list(layer_dimensions)
  68. self.keep_prob = float(keep_prob)
  69. self.conv_activation = conv_activation
  70. self.dec_activation = dec_activation
  71. self.adj_matrices = None
  72. self.conv_weights = None
  73. self.dec_weights = None
  74. self.build()
  75. def build(self) -> None:
  76. self.adj_matrices = torch.nn.ParameterDict()
  77. for _, et in self.data.edge_types.items():
  78. adj_matrices = [
  79. norm_adj_mat_one_node_type(x) \
  80. if et.vertex_type_row == et.vertex_type_column \
  81. else norm_adj_mat_two_node_types(x) \
  82. for x in et.adjacency_matrices
  83. ]
  84. adj_matrices = _sparse_diag_cat(et.adjacency_matrices)
  85. print('adj_matrices:', adj_matrices)
  86. self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
  87. torch.nn.Parameter(adj_matrices, requires_grad=False)
  88. self.conv_weights = torch.nn.ParameterDict()
  89. for i in range(len(self.layer_dimensions) - 1):
  90. in_dimension = self.layer_dimensions[i]
  91. out_dimension = self.layer_dimensions[i + 1]
  92. for _, et in self.data.edge_types.items():
  93. weights = [ init_glorot(in_dimension, out_dimension) \
  94. for _ in range(len(et.adjacency_matrices)) ]
  95. weights = torch.cat(weights, dim=1)
  96. self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
  97. torch.nn.Parameter(weights)
  98. self.dec_weights = torch.nn.ParameterDict()
  99. for _, et in self.data.edge_types.items():
  100. global_interaction, local_variation = \
  101. et.decoder_factory(self.layer_dimensions[-1],
  102. len(et.adjacency_matrices))
  103. self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
  104. torch.nn.Parameter(global_interaction)
  105. for i in range(len(local_variation)):
  106. self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
  107. torch.nn.Parameter(local_variation[i])
  108. def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
  109. List[torch.Tensor]:
  110. cur_layer_repr = in_layer_repr
  111. next_layer_repr = [ None ] * len(self.data.vertex_types)
  112. for i in range(len(self.layer_dimensions) - 1):
  113. for _, et in self.data.edge_types.items():
  114. vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
  115. adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
  116. conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]
  117. num_relation_types = len(et.adjacency_matrices)
  118. x = cur_layer_repr[vt_col]
  119. if self.keep_prob != 1:
  120. x = dropout(x, self.keep_prob)
  121. print('a, Layer:', i, 'x.shape:', x.shape)
  122. x = _mm(x, conv_weights)
  123. x = torch.split(x,
  124. x.shape[1] // num_relation_types,
  125. dim=1)
  126. x = torch.cat(x)
  127. x = _mm(adj_matrices, x)
  128. x = x.view(num_relation_types,
  129. self.data.vertex_types[vt_row].count,
  130. self.layer_dimensions[i + 1])
  131. print('b, Layer:', i, 'x.shape:', x.shape)
  132. x = x.sum(dim=0)
  133. x = torch.nn.functional.normalize(x, p=2, dim=1)
  134. x = self.conv_activation(x)
  135. next_layer_repr[vt_row] = x
  136. cur_layer_repr = next_layer_repr
  137. return next_layer_repr
  138. def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
  139. edges = []
  140. cur_edges = batch.edges
  141. for _ in range(len(self.layer_dimensions) - 1):
  142. edges.append(cur_edges)
  143. key = (batch.vertex_type_row, batch.vertex_type_column)
  144. tot_conn = self.data.relation_types[key].total_connectivity
  145. cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])
  146. def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
  147. batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
  148. col = batch.vertex_type_column
  149. rows = batch.edges[:, 0]
  150. columns = batch.edges[:, 1].sum(dim=0).flatten()
  151. columns = torch.nonzero(columns)
  152. for i in range(len(self.layer_dimensions) - 1):
  153. pass # columns =
  154. # TODO: finish
  155. return None
  156. def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:
  157. col = batch.vertex_type_column
  158. batch.edges[:, 1]
  159. res = {}
  160. for _, et in self.data.edge_types.items():
  161. sum_nonzero = _nonzero_sum(et.adjacency_matrices)
  162. res[et.vertex_type_row, et.vertex_type_column] = \
  163. [ self.temporary_adjacency_matrix(adj_mat, batch,
  164. et.total_connectivity) \
  165. for adj_mat in et.adjacency_matrices ]
  166. return res
  167. def forward(self, initial_repr: List[torch.Tensor],
  168. batch: TrainingBatch) -> torch.Tensor:
  169. if not isinstance(initial_repr, list):
  170. raise TypeError('initial_repr must be a list')
  171. if len(initial_repr) != len(self.data.vertex_types):
  172. raise ValueError('initial_repr must contain representations for all vertex types')
  173. if not isinstance(batch, TrainingBatch):
  174. raise TypeError('batch must be an instance of TrainingBatch')
  175. adj_matrices = self.temporary_adjacency_matrices(batch)
  176. row_vertices = initial_repr[batch.vertex_type_row]
  177. column_vertices = initial_repr[batch.vertex_type_column]