from .data import Data, \
    EdgeType
import torch
from dataclasses import dataclass
from .weights import init_glorot
import types
from typing import List, \
    Dict, \
    Callable, \
    Tuple
from .util import _sparse_coo_tensor, \
    _sparse_diag_cat, \
    _mm
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
from .dropout import dropout
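

# A TrainingBatch groups the edges of a single relation type that are scored
# together in one pass: the row/column vertex types, the index of the relation
# type within that edge type, the edge index pairs and their target values.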
@dataclass
class TrainingBatch(object):
    vertex_type_row: int
    vertex_type_column: int
    relation_type_index: int
    edges: torch.Tensor
    target_values: torch.Tensor
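

# Starting from the vertices touched by the batch, walk the edge types layer
# by layer and collect, for every vertex type, the row indices whose
# representations are needed at each convolution layer (None when a vertex
# type is not needed at that layer).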
def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
    num_layers: int) -> List[List[torch.Tensor]]:

    Q = [
        ( batch.vertex_type_row, batch.edges[:, 0] ),
        ( batch.vertex_type_column, batch.edges[:, 1] )
    ]
    print('Q:', Q)
    res = []

    for _ in range(num_layers):
        R = []
        required_rows = [ [] for _ in range(len(data.vertex_types)) ]

        for vertex_type, vertices in Q:
            for et in data.edge_types.values():
                if et.vertex_type_row == vertex_type:
                    required_rows[vertex_type].append(vertices)
                    indices = et.total_connectivity.indices()
                    mask = torch.zeros(et.total_connectivity.shape[0])
                    mask[vertices] = 1
                    mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
                    R.append((et.vertex_type_column,
                        indices[1, mask]))
                else:
                    pass # required_rows[et.vertex_type_row].append(torch.zeros(0))

        required_rows = [ torch.unique(torch.cat(x)) \
            if len(x) > 0 \
            else None \
            for x in required_rows ]

        res.append(required_rows)
        Q = R

    return res
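

# Multi-relational graph convolutional encoder plus a per-relation bilinear
# decoder over the heterogeneous graph described by a Data object.
#
# A rough usage sketch (the names in_dim, data, initial_repr and batch below
# are hypothetical and only illustrate the expected call order):
#
#     model = Model(data,
#         layer_dimensions=[in_dim, 32, 16],
#         keep_prob=0.9,
#         conv_activation=torch.relu,
#         dec_activation=torch.sigmoid)
#     repr_ = model.convolve(initial_repr)
#     scores = model.decode(repr_, batch)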
class Model(torch.nn.Module):
    def __init__(self, data: Data, layer_dimensions: List[int],
        keep_prob: float,
        conv_activation: Callable[[torch.Tensor], torch.Tensor],
        dec_activation: Callable[[torch.Tensor], torch.Tensor],
        **kwargs) -> None:

        super().__init__(**kwargs)

        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        if not callable(conv_activation):
            raise TypeError('conv_activation must be callable')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be callable')

        self.data = data
        self.layer_dimensions = list(layer_dimensions)
        self.keep_prob = float(keep_prob)
        self.conv_activation = conv_activation
        self.dec_activation = dec_activation

        self.adj_matrices = None
        self.conv_weights = None
        self.dec_weights = None

        self.build()
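
    # build() registers all parameters: the block-diagonal, normalized
    # adjacency matrices (non-trainable), one Glorot-initialized convolution
    # weight block per edge type and layer, and the global-interaction /
    # local-variation decoder weights produced by each edge type's
    # decoder_factory.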
    def build(self) -> None:
        self.adj_matrices = torch.nn.ParameterDict()
        for _, et in self.data.edge_types.items():
            adj_matrices = [
                norm_adj_mat_one_node_type(x) \
                    if et.vertex_type_row == et.vertex_type_column \
                    else norm_adj_mat_two_node_types(x) \
                    for x in et.adjacency_matrices
            ]
            # Concatenate the normalized matrices into one block-diagonal
            # sparse matrix. (The original line passed et.adjacency_matrices
            # here, which discarded the normalization computed just above.)
            adj_matrices = _sparse_diag_cat(adj_matrices)
            print('adj_matrices:', adj_matrices)
            self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
                torch.nn.Parameter(adj_matrices, requires_grad=False)

        self.conv_weights = torch.nn.ParameterDict()
        for i in range(len(self.layer_dimensions) - 1):
            in_dimension = self.layer_dimensions[i]
            out_dimension = self.layer_dimensions[i + 1]

            for _, et in self.data.edge_types.items():
                weights = [ init_glorot(in_dimension, out_dimension) \
                    for _ in range(len(et.adjacency_matrices)) ]
                weights = torch.cat(weights, dim=1)
                self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
                    torch.nn.Parameter(weights)

        self.dec_weights = torch.nn.ParameterDict()
        for _, et in self.data.edge_types.items():
            global_interaction, local_variation = \
                et.decoder_factory(self.layer_dimensions[-1],
                    len(et.adjacency_matrices))
            self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
                torch.nn.Parameter(global_interaction)
            for i in range(len(local_variation)):
                self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
                    torch.nn.Parameter(local_variation[i])
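
    # Each convolution layer multiplies the (optionally dropped-out) column
    # representations by the per-relation weight blocks, propagates them
    # through the block-diagonal adjacency matrix, sums over relation types,
    # L2-normalizes and applies conv_activation, accumulating the result per
    # row vertex type.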

    def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
        List[torch.Tensor]:

        cur_layer_repr = in_layer_repr

        for i in range(len(self.layer_dimensions) - 1):
            next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ]

            for _, et in self.data.edge_types.items():
                vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
                adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
                conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]

                num_relation_types = len(et.adjacency_matrices)
                x = cur_layer_repr[vt_col]
                if self.keep_prob != 1:
                    x = dropout(x, self.keep_prob)

                # print('a, Layer:', i, 'x.shape:', x.shape)

                x = _mm(x, conv_weights)
                x = torch.split(x,
                    x.shape[1] // num_relation_types,
                    dim=1)
                x = torch.cat(x)
                x = _mm(adj_matrices, x)
                x = x.view(num_relation_types,
                    self.data.vertex_types[vt_row].count,
                    self.layer_dimensions[i + 1])

                # print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)

                x = x.sum(dim=0)
                x = torch.nn.functional.normalize(x, p=2, dim=1)
                # x = self.rel_activation(x)
                # print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)

                next_layer_repr[vt_row].append(x)

            next_layer_repr = [ self.conv_activation(sum(x)) \
                for x in next_layer_repr ]

            cur_layer_repr = next_layer_repr

        return cur_layer_repr
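
    # Score the edges in the batch with a DEDICOM-style bilinear form:
    # row · local_variation · global_interaction · local_variation · column,
    # followed by dec_activation.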

    def decode(self, last_layer_repr: List[torch.Tensor],
        batch: TrainingBatch) -> torch.Tensor:

        vt_row = batch.vertex_type_row
        vt_col = batch.vertex_type_column
        rel_idx = batch.relation_type_index
        global_interaction = \
            self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)]
        local_variation = \
            self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)]

        in_row = last_layer_repr[vt_row]
        in_col = last_layer_repr[vt_col]

        if in_row.is_sparse or in_col.is_sparse:
            raise ValueError('Inputs to Model.decode() must be dense')

        in_row = in_row[batch.edges[:, 0]]
        in_col = in_col[batch.edges[:, 1]]

        in_row = dropout(in_row, self.keep_prob)
        in_col = dropout(in_col, self.keep_prob)

        # in_row = in_row.to_dense()
        # in_col = in_col.to_dense()

        print('in_row.is_sparse:', in_row.is_sparse)
        print('in_col.is_sparse:', in_col.is_sparse)

        x = torch.mm(in_row, local_variation)
        x = torch.mm(x, global_interaction)
        x = torch.mm(x, local_variation)

        x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]),
            in_col.view(in_col.shape[0], in_col.shape[1], 1))

        x = torch.flatten(x)

        x = self.dec_activation(x)

        return x
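
    # Legacy mini-batch convolution, kept for reference: it relies on an
    # _edges_for_rows() helper that is not defined in this module and does not
    # return a value.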

    def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
        edges = []
        cur_edges = batch.edges
        for _ in range(len(self.layer_dimensions) - 1):
            edges.append(cur_edges)
            key = (batch.vertex_type_row, batch.vertex_type_column)
            tot_conn = self.data.relation_types[key].total_connectivity
            cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])

    def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
        batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:

        col = batch.vertex_type_column
        rows = batch.edges[:, 0]
        columns = batch.edges[:, 1].sum(dim=0).flatten()
        columns = torch.nonzero(columns)

        for i in range(len(self.layer_dimensions) - 1):
            pass # columns =
            # TODO: finish

        return None
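
    # Work in progress: per-edge-type restriction of the adjacency matrices to
    # the vertices reachable from the batch. Note that _nonzero_sum() is not
    # imported in this module, so this method will raise a NameError as it
    # stands.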

    def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:
        col = batch.vertex_type_column
        batch.edges[:, 1]

        res = {}

        for _, et in self.data.edge_types.items():
            sum_nonzero = _nonzero_sum(et.adjacency_matrices)
            res[et.vertex_type_row, et.vertex_type_column] = \
                [ self.temporary_adjacency_matrix(adj_mat, batch,
                    et.total_connectivity) \
                    for adj_mat in et.adjacency_matrices ]

        return res
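
    # Forward pass over a single TrainingBatch: validate the inputs, build the
    # temporary adjacency matrices for the batch and gather the initial
    # row/column representations.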

    def forward(self, initial_repr: List[torch.Tensor],
        batch: TrainingBatch) -> torch.Tensor:

        if not isinstance(initial_repr, list):
            raise TypeError('initial_repr must be a list')

        if len(initial_repr) != len(self.data.vertex_types):
            raise ValueError('initial_repr must contain representations for all vertex types')

        if not isinstance(batch, TrainingBatch):
            raise TypeError('batch must be an instance of TrainingBatch')

        adj_matrices = self.temporary_adjacency_matrices(batch)

        row_vertices = initial_repr[batch.vertex_type_row]
        column_vertices = initial_repr[batch.vertex_type_column]