IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

171 Zeilen
5.8KB

from dataclasses import dataclass
import types
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
)

import torch

from .data import Data, EdgeType
from .util import _sparse_coo_tensor
from .weights import init_glorot
  12. @dataclass
  13. class TrainingBatch(object):
  14. vertex_type_row: int
  15. vertex_type_column: int
  16. relation_type_index: int
  17. edges: torch.Tensor
  18. def _per_layer_required_rows(data: Data, batch: TrainingBatch,
  19. num_layers: int) -> List[List[EdgeType]]:
  20. Q = [
  21. ( batch.vertex_type_row, batch.edges[:, 0] ),
  22. ( batch.vertex_type_column, batch.edges[:, 1] )
  23. ]
  24. print('Q:', Q)
  25. res = []
  26. for _ in range(num_layers):
  27. R = []
  28. required_rows = [ [] for _ in range(len(data.vertex_types)) ]
  29. for vertex_type, vertices in Q:
  30. for et in data.edge_types.values():
  31. if et.vertex_type_row == vertex_type:
  32. required_rows[vertex_type].append(vertices)
  33. indices = et.total_connectivity.indices()
  34. mask = torch.zeros(et.total_connectivity.shape[0])
  35. mask[vertices] = 1
  36. mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
  37. R.append((et.vertex_type_column,
  38. indices[1, mask]))
  39. else:
  40. pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
  41. required_rows = [ torch.unique(torch.cat(x)) \
  42. if len(x) > 0 \
  43. else None \
  44. for x in required_rows ]
  45. res.append(required_rows)
  46. Q = R
  47. return res
  48. class Model(torch.nn.Module):
  49. def __init__(self, data: Data, layer_dimensions: List[int],
  50. keep_prob: float,
  51. conv_activation: Callable[[torch.Tensor], torch.Tensor],
  52. dec_activation: Callable[[torch.Tensor], torch.Tensor],
  53. **kwargs) -> None:
  54. super().__init__(**kwargs)
  55. if not isinstance(data, Data):
  56. raise TypeError('data must be an instance of Data')
  57. if not isinstance(conv_activation, types.FunctionType):
  58. raise TypeError('conv_activation must be a function')
  59. if not isinstance(dec_activation, types.FunctionType):
  60. raise TypeError('dec_activation must be a function')
  61. self.data = data
  62. self.layer_dimensions = list(layer_dimensions)
  63. self.keep_prob = float(keep_prob)
  64. self.conv_activation = conv_activation
  65. self.dec_activation = dec_activation
  66. self.conv_weights = None
  67. self.dec_weights = None
  68. self.build()
  69. def build(self) -> None:
  70. self.conv_weights = torch.nn.ParameterDict()
  71. for i in range(len(self.layer_dimensions) - 1):
  72. in_dimension = self.layer_dimensions[i]
  73. out_dimension = self.layer_dimensions[i + 1]
  74. for _, et in self.data.edge_types.items():
  75. weight = init_glorot(in_dimension, out_dimension)
  76. self.conv_weights[et.vertex_type_row, et.vertex_type_column, i] = \
  77. torch.nn.Parameter(weight)
  78. self.dec_weights = torch.nn.ParameterDict()
  79. for _, et in self.data.edge_types.items():
  80. global_interaction, local_variation = \
  81. et.decoder_factory(self.layer_dimensions[-1],
  82. len(et.adjacency_matrices))
  83. self.dec_weights[et.vertex_type_row, et.vertex_type_column] = \
  84. torch.nn.ParameterList([
  85. torch.nn.Parameter(global_interaction),
  86. torch.nn.Parameter(local_variation)
  87. ])
  88. def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]:
  89. edges = []
  90. cur_edges = batch.edges
  91. for _ in range(len(self.layer_dimensions) - 1):
  92. edges.append(cur_edges)
  93. key = (batch.vertex_type_row, batch.vertex_type_column)
  94. tot_conn = self.data.relation_types[key].total_connectivity
  95. cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])
  96. def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
  97. batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
  98. col = batch.vertex_type_column
  99. rows = batch.edges[:, 0]
  100. columns = batch.edges[:, 1].sum(dim=0).flatten()
  101. columns = torch.nonzero(columns)
  102. for i in range(len(self.layer_dimensions) - 1):
  103. pass # columns =
  104. # TODO: finish
  105. return None
  106. def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:
  107. col = batch.vertex_type_column
  108. batch.edges[:, 1]
  109. res = {}
  110. for _, et in self.data.edge_types.items():
  111. sum_nonzero = _nonzero_sum(et.adjacency_matrices)
  112. res[et.vertex_type_row, et.vertex_type_column] = \
  113. [ self.temporary_adjacency_matrix(adj_mat, batch,
  114. et.total_connectivity) \
  115. for adj_mat in et.adjacency_matrices ]
  116. return res
  117. def forward(self, initial_repr: List[torch.Tensor],
  118. batch: TrainingBatch) -> torch.Tensor:
  119. if not isinstance(initial_repr, list):
  120. raise TypeError('initial_repr must be a list')
  121. if len(initial_repr) != len(self.data.vertex_types):
  122. raise ValueError('initial_repr must contain representations for all vertex types')
  123. if not isinstance(batch, TrainingBatch):
  124. raise TypeError('batch must be an instance of TrainingBatch')
  125. adj_matrices = self.temporary_adjacency_matrices(batch)
  126. row_vertices = initial_repr[batch.vertex_type_row]
  127. column_vertices = initial_repr[batch.vertex_type_column]