IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

277 Zeilen
10KB

  1. from .data import Data, \
  2. EdgeType
  3. import torch
  4. from dataclasses import dataclass
  5. from .weights import init_glorot
  6. import types
  7. from typing import List, \
  8. Dict, \
  9. Callable, \
  10. Tuple
  11. from .util import _sparse_coo_tensor, \
  12. _sparse_diag_cat, \
  13. _mm
  14. from .normalize import norm_adj_mat_one_node_type, \
  15. norm_adj_mat_two_node_types
  16. from .dropout import dropout
  17. @dataclass
  18. class TrainingBatch(object):
  19. vertex_type_row: int
  20. vertex_type_column: int
  21. relation_type_index: int
  22. edges: torch.Tensor
  23. target_values: torch.Tensor
  24. def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
  25. num_layers: int) -> List[List[EdgeType]]:
  26. Q = [
  27. ( batch.vertex_type_row, batch.edges[:, 0] ),
  28. ( batch.vertex_type_column, batch.edges[:, 1] )
  29. ]
  30. print('Q:', Q)
  31. res = []
  32. for _ in range(num_layers):
  33. R = []
  34. required_rows = [ [] for _ in range(len(data.vertex_types)) ]
  35. for vertex_type, vertices in Q:
  36. for et in data.edge_types.values():
  37. if et.vertex_type_row == vertex_type:
  38. required_rows[vertex_type].append(vertices)
  39. indices = et.total_connectivity.indices()
  40. mask = torch.zeros(et.total_connectivity.shape[0])
  41. mask[vertices] = 1
  42. mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
  43. R.append((et.vertex_type_column,
  44. indices[1, mask]))
  45. else:
  46. pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
  47. required_rows = [ torch.unique(torch.cat(x)) \
  48. if len(x) > 0 \
  49. else None \
  50. for x in required_rows ]
  51. res.append(required_rows)
  52. Q = R
  53. return res
  54. class Model(torch.nn.Module):
  55. def __init__(self, data: Data, layer_dimensions: List[int],
  56. keep_prob: float,
  57. conv_activation: Callable[[torch.Tensor], torch.Tensor],
  58. dec_activation: Callable[[torch.Tensor], torch.Tensor],
  59. **kwargs) -> None:
  60. super().__init__(**kwargs)
  61. if not isinstance(data, Data):
  62. raise TypeError('data must be an instance of Data')
  63. if not callable(conv_activation):
  64. raise TypeError('conv_activation must be callable')
  65. if not callable(dec_activation):
  66. raise TypeError('dec_activation must be callable')
  67. self.data = data
  68. self.layer_dimensions = list(layer_dimensions)
  69. self.keep_prob = float(keep_prob)
  70. self.conv_activation = conv_activation
  71. self.dec_activation = dec_activation
  72. self.adj_matrices = None
  73. self.conv_weights = None
  74. self.dec_weights = None
  75. self.build()
  76. def build(self) -> None:
  77. self.adj_matrices = torch.nn.ParameterDict()
  78. for _, et in self.data.edge_types.items():
  79. adj_matrices = [
  80. norm_adj_mat_one_node_type(x) \
  81. if et.vertex_type_row == et.vertex_type_column \
  82. else norm_adj_mat_two_node_types(x) \
  83. for x in et.adjacency_matrices
  84. ]
  85. adj_matrices = _sparse_diag_cat(et.adjacency_matrices)
  86. print('adj_matrices:', adj_matrices)
  87. self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
  88. torch.nn.Parameter(adj_matrices, requires_grad=False)
  89. self.conv_weights = torch.nn.ParameterDict()
  90. for i in range(len(self.layer_dimensions) - 1):
  91. in_dimension = self.layer_dimensions[i]
  92. out_dimension = self.layer_dimensions[i + 1]
  93. for _, et in self.data.edge_types.items():
  94. weights = [ init_glorot(in_dimension, out_dimension) \
  95. for _ in range(len(et.adjacency_matrices)) ]
  96. weights = torch.cat(weights, dim=1)
  97. self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
  98. torch.nn.Parameter(weights)
  99. self.dec_weights = torch.nn.ParameterDict()
  100. for _, et in self.data.edge_types.items():
  101. global_interaction, local_variation = \
  102. et.decoder_factory(self.layer_dimensions[-1],
  103. len(et.adjacency_matrices))
  104. self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
  105. torch.nn.Parameter(global_interaction)
  106. for i in range(len(local_variation)):
  107. self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
  108. torch.nn.Parameter(local_variation[i])
  109. def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
  110. List[torch.Tensor]:
  111. cur_layer_repr = in_layer_repr
  112. for i in range(len(self.layer_dimensions) - 1):
  113. next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ]
  114. for _, et in self.data.edge_types.items():
  115. vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
  116. adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
  117. conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]
  118. num_relation_types = len(et.adjacency_matrices)
  119. x = cur_layer_repr[vt_col]
  120. if self.keep_prob != 1:
  121. x = dropout(x, self.keep_prob)
  122. # print('a, Layer:', i, 'x.shape:', x.shape)
  123. x = _mm(x, conv_weights)
  124. x = torch.split(x,
  125. x.shape[1] // num_relation_types,
  126. dim=1)
  127. x = torch.cat(x)
  128. x = _mm(adj_matrices, x)
  129. x = x.view(num_relation_types,
  130. self.data.vertex_types[vt_row].count,
  131. self.layer_dimensions[i + 1])
  132. # print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
  133. x = x.sum(dim=0)
  134. x = torch.nn.functional.normalize(x, p=2, dim=1)
  135. # x = self.rel_activation(x)
  136. # print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
  137. next_layer_repr[vt_row].append(x)
  138. next_layer_repr = [ self.conv_activation(sum(x)) \
  139. for x in next_layer_repr ]
  140. cur_layer_repr = next_layer_repr
  141. return cur_layer_repr
  142. def decode(self, last_layer_repr: List[torch.Tensor],
  143. batch: TrainingBatch) -> torch.Tensor:
  144. vt_row = batch.vertex_type_row
  145. vt_col = batch.vertex_type_column
  146. rel_idx = batch.relation_type_index
  147. global_interaction = \
  148. self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)]
  149. local_variation = \
  150. self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)]
  151. in_row = last_layer_repr[vt_row]
  152. in_col = last_layer_repr[vt_col]
  153. if in_row.is_sparse or in_col.is_sparse:
  154. raise ValueError('Inputs to Model.decode() must be dense')
  155. in_row = in_row[batch.edges[:, 0]]
  156. in_col = in_col[batch.edges[:, 1]]
  157. in_row = dropout(in_row, self.keep_prob)
  158. in_col = dropout(in_col, self.keep_prob)
  159. # in_row = in_row.to_dense()
  160. # in_col = in_col.to_dense()
  161. print('in_row.is_sparse:', in_row.is_sparse)
  162. print('in_col.is_sparse:', in_col.is_sparse)
  163. x = torch.mm(in_row, local_variation)
  164. x = torch.mm(x, global_interaction)
  165. x = torch.mm(x, local_variation)
  166. x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]),
  167. in_col.view(in_col.shape[0], in_col.shape[1], 1))
  168. x = torch.flatten(x)
  169. x = self.dec_activation(x)
  170. return x
  171. def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
  172. edges = []
  173. cur_edges = batch.edges
  174. for _ in range(len(self.layer_dimensions) - 1):
  175. edges.append(cur_edges)
  176. key = (batch.vertex_type_row, batch.vertex_type_column)
  177. tot_conn = self.data.relation_types[key].total_connectivity
  178. cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])
  179. def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
  180. batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
  181. col = batch.vertex_type_column
  182. rows = batch.edges[:, 0]
  183. columns = batch.edges[:, 1].sum(dim=0).flatten()
  184. columns = torch.nonzero(columns)
  185. for i in range(len(self.layer_dimensions) - 1):
  186. pass # columns =
  187. # TODO: finish
  188. return None
  189. def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:
  190. col = batch.vertex_type_column
  191. batch.edges[:, 1]
  192. res = {}
  193. for _, et in self.data.edge_types.items():
  194. sum_nonzero = _nonzero_sum(et.adjacency_matrices)
  195. res[et.vertex_type_row, et.vertex_type_column] = \
  196. [ self.temporary_adjacency_matrix(adj_mat, batch,
  197. et.total_connectivity) \
  198. for adj_mat in et.adjacency_matrices ]
  199. return res
  200. def forward(self, initial_repr: List[torch.Tensor],
  201. batch: TrainingBatch) -> torch.Tensor:
  202. if not isinstance(initial_repr, list):
  203. raise TypeError('initial_repr must be a list')
  204. if len(initial_repr) != len(self.data.vertex_types):
  205. raise ValueError('initial_repr must contain representations for all vertex types')
  206. if not isinstance(batch, TrainingBatch):
  207. raise TypeError('batch must be an instance of TrainingBatch')
  208. adj_matrices = self.temporary_adjacency_matrices(batch)
  209. row_vertices = initial_repr[batch.vertex_type_row]
  210. column_vertices = initial_repr[batch.vertex_type_column]