IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

229 lines
8.4KB

import types
from dataclasses import dataclass
from typing import List, \
    Dict, \
    Callable, \
    Optional, \
    Tuple

import torch

from .data import Data, \
    EdgeType
from .dropout import dropout
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
from .util import _sparse_coo_tensor, \
    _sparse_diag_cat, \
    _mm
from .weights import init_glorot
  17. @dataclass
  18. class TrainingBatch(object):
  19. vertex_type_row: int
  20. vertex_type_column: int
  21. relation_type_index: int
  22. edges: torch.Tensor
  23. def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
  24. num_layers: int) -> List[List[EdgeType]]:
  25. Q = [
  26. ( batch.vertex_type_row, batch.edges[:, 0] ),
  27. ( batch.vertex_type_column, batch.edges[:, 1] )
  28. ]
  29. print('Q:', Q)
  30. res = []
  31. for _ in range(num_layers):
  32. R = []
  33. required_rows = [ [] for _ in range(len(data.vertex_types)) ]
  34. for vertex_type, vertices in Q:
  35. for et in data.edge_types.values():
  36. if et.vertex_type_row == vertex_type:
  37. required_rows[vertex_type].append(vertices)
  38. indices = et.total_connectivity.indices()
  39. mask = torch.zeros(et.total_connectivity.shape[0])
  40. mask[vertices] = 1
  41. mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
  42. R.append((et.vertex_type_column,
  43. indices[1, mask]))
  44. else:
  45. pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
  46. required_rows = [ torch.unique(torch.cat(x)) \
  47. if len(x) > 0 \
  48. else None \
  49. for x in required_rows ]
  50. res.append(required_rows)
  51. Q = R
  52. return res
  53. class Model(torch.nn.Module):
  54. def __init__(self, data: Data, layer_dimensions: List[int],
  55. keep_prob: float,
  56. conv_activation: Callable[[torch.Tensor], torch.Tensor],
  57. dec_activation: Callable[[torch.Tensor], torch.Tensor],
  58. **kwargs) -> None:
  59. super().__init__(**kwargs)
  60. if not isinstance(data, Data):
  61. raise TypeError('data must be an instance of Data')
  62. if not callable(conv_activation):
  63. raise TypeError('conv_activation must be callable')
  64. if not callable(dec_activation):
  65. raise TypeError('dec_activation must be callable')
  66. self.data = data
  67. self.layer_dimensions = list(layer_dimensions)
  68. self.keep_prob = float(keep_prob)
  69. self.conv_activation = conv_activation
  70. self.dec_activation = dec_activation
  71. self.adj_matrices = None
  72. self.conv_weights = None
  73. self.dec_weights = None
  74. self.build()
  75. def build(self) -> None:
  76. self.adj_matrices = torch.nn.ParameterDict()
  77. for _, et in self.data.edge_types.items():
  78. adj_matrices = [
  79. norm_adj_mat_one_node_type(x) \
  80. if et.vertex_type_row == et.vertex_type_column \
  81. else norm_adj_mat_two_node_types(x) \
  82. for x in et.adjacency_matrices
  83. ]
  84. adj_matrices = _sparse_diag_cat(et.adjacency_matrices)
  85. print('adj_matrices:', adj_matrices)
  86. self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
  87. torch.nn.Parameter(adj_matrices, requires_grad=False)
  88. self.conv_weights = torch.nn.ParameterDict()
  89. for i in range(len(self.layer_dimensions) - 1):
  90. in_dimension = self.layer_dimensions[i]
  91. out_dimension = self.layer_dimensions[i + 1]
  92. for _, et in self.data.edge_types.items():
  93. weights = [ init_glorot(in_dimension, out_dimension) \
  94. for _ in range(len(et.adjacency_matrices)) ]
  95. weights = torch.cat(weights, dim=1)
  96. self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
  97. torch.nn.Parameter(weights)
  98. self.dec_weights = torch.nn.ParameterDict()
  99. for _, et in self.data.edge_types.items():
  100. global_interaction, local_variation = \
  101. et.decoder_factory(self.layer_dimensions[-1],
  102. len(et.adjacency_matrices))
  103. self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
  104. torch.nn.Parameter(global_interaction)
  105. for i in range(len(local_variation)):
  106. self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
  107. torch.nn.Parameter(local_variation[i])
  108. def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
  109. List[torch.Tensor]:
  110. cur_layer_repr = in_layer_repr
  111. next_layer_repr = [ None ] * len(self.data.vertex_types)
  112. for i in range(len(self.layer_dimensions) - 1):
  113. for _, et in self.data.edge_types.items():
  114. vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
  115. adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
  116. conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]
  117. num_relation_types = len(et.adjacency_matrices)
  118. x = cur_layer_repr[vt_col]
  119. if self.keep_prob != 1:
  120. x = dropout(x, self.keep_prob)
  121. print('a, Layer:', i, 'x.shape:', x.shape)
  122. x = _mm(x, conv_weights)
  123. x = torch.split(x,
  124. x.shape[1] // num_relation_types,
  125. dim=1)
  126. x = torch.cat(x)
  127. x = _mm(adj_matrices, x)
  128. x = x.view(num_relation_types,
  129. self.data.vertex_types[vt_row].count,
  130. self.layer_dimensions[i + 1])
  131. print('b, Layer:', i, 'x.shape:', x.shape)
  132. x = x.sum(dim=0)
  133. x = torch.nn.functional.normalize(x, p=2, dim=1)
  134. x = self.conv_activation(x)
  135. next_layer_repr[vt_row] = x
  136. cur_layer_repr = next_layer_repr
  137. return next_layer_repr
  138. def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
  139. edges = []
  140. cur_edges = batch.edges
  141. for _ in range(len(self.layer_dimensions) - 1):
  142. edges.append(cur_edges)
  143. key = (batch.vertex_type_row, batch.vertex_type_column)
  144. tot_conn = self.data.relation_types[key].total_connectivity
  145. cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])
  146. def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
  147. batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
  148. col = batch.vertex_type_column
  149. rows = batch.edges[:, 0]
  150. columns = batch.edges[:, 1].sum(dim=0).flatten()
  151. columns = torch.nonzero(columns)
  152. for i in range(len(self.layer_dimensions) - 1):
  153. pass # columns =
  154. # TODO: finish
  155. return None
  156. def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:
  157. col = batch.vertex_type_column
  158. batch.edges[:, 1]
  159. res = {}
  160. for _, et in self.data.edge_types.items():
  161. sum_nonzero = _nonzero_sum(et.adjacency_matrices)
  162. res[et.vertex_type_row, et.vertex_type_column] = \
  163. [ self.temporary_adjacency_matrix(adj_mat, batch,
  164. et.total_connectivity) \
  165. for adj_mat in et.adjacency_matrices ]
  166. return res
  167. def forward(self, initial_repr: List[torch.Tensor],
  168. batch: TrainingBatch) -> torch.Tensor:
  169. if not isinstance(initial_repr, list):
  170. raise TypeError('initial_repr must be a list')
  171. if len(initial_repr) != len(self.data.vertex_types):
  172. raise ValueError('initial_repr must contain representations for all vertex types')
  173. if not isinstance(batch, TrainingBatch):
  174. raise TypeError('batch must be an instance of TrainingBatch')
  175. adj_matrices = self.temporary_adjacency_matrices(batch)
  176. row_vertices = initial_repr[batch.vertex_type_row]
  177. column_vertices = initial_repr[batch.vertex_type_column]