IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

trainprep.py 6.9KB


  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from .sampling import fixed_unigram_candidate_sampler
  6. import torch
  7. from dataclasses import dataclass, \
  8. field
  9. from typing import Any, \
  10. List, \
  11. Tuple, \
  12. Dict
  13. from .data import NodeType, \
  14. RelationType, \
  15. RelationTypeBase, \
  16. RelationFamily, \
  17. RelationFamilyBase, \
  18. Data
  19. from collections import defaultdict
  20. from .normalize import norm_adj_mat_one_node_type, \
  21. norm_adj_mat_two_node_types
  22. import numpy as np
  23. @dataclass
  24. class TrainValTest(object):
  25. train: Any
  26. val: Any
  27. test: Any
  28. @dataclass
  29. class PreparedRelationType(RelationTypeBase):
  30. edges_pos: TrainValTest
  31. edges_neg: TrainValTest
  32. edges_back_pos: TrainValTest
  33. edges_back_neg: TrainValTest
  34. @dataclass
  35. class PreparedRelationFamily(RelationFamilyBase):
  36. relation_types: List[PreparedRelationType]
  37. @dataclass
  38. class PreparedData(object):
  39. node_types: List[NodeType]
  40. relation_families: List[PreparedRelationFamily]
  41. def _empty_edge_list_tvt() -> TrainValTest:
  42. return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ])
  43. def train_val_test_split_edges(edges: torch.Tensor,
  44. ratios: TrainValTest) -> TrainValTest:
  45. if not isinstance(edges, torch.Tensor):
  46. raise ValueError('edges must be a torch.Tensor')
  47. if len(edges.shape) != 2 or edges.shape[1] != 2:
  48. raise ValueError('edges shape must be (num_edges, 2)')
  49. if not isinstance(ratios, TrainValTest):
  50. raise ValueError('ratios must be a TrainValTest')
  51. if ratios.train + ratios.val + ratios.test != 1.0:
  52. raise ValueError('Train, validation and test ratios must add up to 1')
  53. order = torch.randperm(len(edges))
  54. edges = edges[order, :]
  55. n = round(len(edges) * ratios.train)
  56. edges_train = edges[:n]
  57. n_1 = round(len(edges) * (ratios.train + ratios.val))
  58. edges_val = edges[n:n_1]
  59. edges_test = edges[n_1:]
  60. return TrainValTest(edges_train, edges_val, edges_test)
  61. def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  62. if adj_mat.is_sparse:
  63. adj_mat = adj_mat.coalesce()
  64. degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64)
  65. degrees = degrees.index_add(0, adj_mat.indices()[1],
  66. torch.ones(adj_mat.indices().shape[1], dtype=torch.int64))
  67. edges_pos = adj_mat.indices().transpose(0, 1)
  68. else:
  69. degrees = adj_mat.sum(0)
  70. edges_pos = torch.nonzero(adj_mat)
  71. return edges_pos, degrees
  72. def prepare_adj_mat(adj_mat: torch.Tensor,
  73. ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
  74. if not isinstance(adj_mat, torch.Tensor):
  75. raise ValueError('adj_mat must be a torch.Tensor')
  76. edges_pos, degrees = get_edges_and_degrees(adj_mat)
  77. neg_neighbors = fixed_unigram_candidate_sampler(
  78. edges_pos[:, 1].view(-1, 1), degrees, 0.75)
  79. print(edges_pos.dtype)
  80. print(neg_neighbors.dtype)
  81. edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
  82. edges_pos = train_val_test_split_edges(edges_pos, ratios)
  83. edges_neg = train_val_test_split_edges(edges_neg, ratios)
  84. adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
  85. values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype)
  86. return adj_mat_train, edges_pos, edges_neg
  87. def prep_rel_one_node_type(r: RelationType,
  88. ratios: TrainValTest) -> PreparedRelationType:
  89. adj_mat = r.adjacency_matrix
  90. adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
  91. adj_mat_back_train, edges_back_pos, edges_back_neg = \
  92. None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
  93. print('adj_mat_train:', adj_mat_train)
  94. adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  95. return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
  96. adj_mat_train, adj_mat_back_train, edges_pos, edges_neg,
  97. edges_back_pos, edges_back_neg)
  98. def prep_rel_two_node_types_sym(r: RelationType,
  99. ratios: TrainValTest) -> PreparedRelationType:
  100. adj_mat = r.adjacency_matrix
  101. adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
  102. edges_back_pos, edges_back_neg = \
  103. _empty_edge_list_tvt(), _empty_edge_list_tvt()
  104. return PreparedRelationType(r.name, r.node_type_row,
  105. r.node_type_column,
  106. norm_adj_mat_two_node_types(adj_mat_train),
  107. norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)),
  108. edges_pos, edges_neg, edges_back_pos, edges_back_neg)
  109. def prep_rel_two_node_types_asym(r: RelationType,
  110. ratios: TrainValTest) -> PreparedRelationType:
  111. if r.adjacency_matrix is not None:
  112. adj_mat_train, edges_pos, edges_neg =\
  113. prepare_adj_mat(r.adjacency_matrix, ratios)
  114. else:
  115. adj_mat_train, edges_pos, edges_neg = \
  116. None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
  117. if r.adjacency_matrix_backward is not None:
  118. adj_mat_back_train, edges_back_pos, edges_back_neg = \
  119. prepare_adj_mat(r.adjacency_matrix_backward, ratios)
  120. else:
  121. adj_mat_back_train, edges_back_pos, edges_back_neg = \
  122. None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
  123. return PreparedRelationType(r.name, r.node_type_row,
  124. r.node_type_column,
  125. norm_adj_mat_two_node_types(adj_mat_train),
  126. norm_adj_mat_two_node_types(adj_mat_back_train),
  127. edges_pos, edges_neg, edges_back_pos, edges_back_neg)
  128. def prepare_relation_type(r: RelationType,
  129. ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType:
  130. if not isinstance(r, RelationType):
  131. raise ValueError('r must be a RelationType')
  132. if not isinstance(ratios, TrainValTest):
  133. raise ValueError('ratios must be a TrainValTest')
  134. if r.node_type_row == r.node_type_column:
  135. return prep_rel_one_node_type(r, ratios)
  136. elif is_symmetric:
  137. return prep_rel_two_node_types_sym(r, ratios)
  138. else:
  139. return prep_rel_two_node_types_asym(r, ratios)
  140. def prepare_relation_family(fam: RelationFamily,
  141. ratios: TrainValTest) -> PreparedRelationFamily:
  142. relation_types = []
  143. for r in fam.relation_types:
  144. relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric))
  145. return PreparedRelationFamily(fam.data, fam.name,
  146. fam.node_type_row, fam.node_type_column,
  147. fam.is_symmetric, fam.decoder_class,
  148. relation_types)
  149. def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
  150. if not isinstance(data, Data):
  151. raise ValueError('data must be of class Data')
  152. relation_families = [ prepare_relation_family(fam, ratios) \
  153. for fam in data.relation_families ]
  154. return PreparedData(data.node_types, relation_families)