IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

trainprep.py 7.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from .sampling import fixed_unigram_candidate_sampler
  6. import torch
  7. from dataclasses import dataclass, \
  8. field
  9. from typing import Any, \
  10. List, \
  11. Tuple, \
  12. Dict
  13. from .data import NodeType, \
  14. RelationType, \
  15. RelationTypeBase, \
  16. RelationFamily, \
  17. RelationFamilyBase, \
  18. Data
  19. from collections import defaultdict
  20. from .normalize import norm_adj_mat_one_node_type, \
  21. norm_adj_mat_two_node_types
  22. import numpy as np
  23. @dataclass
  24. class TrainValTest(object):
  25. train: Any
  26. val: Any
  27. test: Any
@dataclass
class PreparedRelationType(RelationTypeBase):
    """A relation type whose edges have been split for training.

    This module constructs it positionally as
    ``(name, node_type_row, node_type_column, <forward adjacency>,
    <backward adjacency>, edges_pos, edges_neg, edges_back_pos,
    edges_back_neg)``. The first five fields presumably come from
    ``RelationTypeBase`` -- TODO confirm against its definition.
    """
    # Positive (existing) edges of the forward direction, split train/val/test.
    edges_pos: TrainValTest
    # Sampled negative edges of the forward direction, split train/val/test.
    edges_neg: TrainValTest
    # Positive edges of the backward direction; empty edge lists when no
    # separate backward matrix is prepared.
    edges_back_pos: TrainValTest
    # Negative edges of the backward direction; empty when unused.
    edges_back_neg: TrainValTest
@dataclass
class PreparedRelationFamily(RelationFamilyBase):
    """A relation family whose relation types have been prepared for training.

    This module constructs it positionally as ``(data, name, node_type_row,
    node_type_column, is_symmetric, decoder_class, relation_types)``. The
    leading fields presumably come from ``RelationFamilyBase`` -- TODO confirm.
    """
    # Prepared counterparts of the family's relation types, in their
    # original order.
    relation_types: List[PreparedRelationType]
@dataclass
class PreparedData(object):
    """Top-level result of ``prepare_training``."""
    # Node types copied unchanged from the input ``Data``.
    node_types: List[NodeType]
    # One prepared family per relation family of the input ``Data``.
    relation_families: List[PreparedRelationFamily]
  41. def _empty_edge_list_tvt() -> TrainValTest:
  42. return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ])
  43. def train_val_test_split_edges(edges: torch.Tensor,
  44. ratios: TrainValTest) -> TrainValTest:
  45. if not isinstance(edges, torch.Tensor):
  46. raise ValueError('edges must be a torch.Tensor')
  47. if len(edges.shape) != 2 or edges.shape[1] != 2:
  48. raise ValueError('edges shape must be (num_edges, 2)')
  49. if not isinstance(ratios, TrainValTest):
  50. raise ValueError('ratios must be a TrainValTest')
  51. if ratios.train + ratios.val + ratios.test != 1.0:
  52. raise ValueError('Train, validation and test ratios must add up to 1')
  53. order = torch.randperm(len(edges))
  54. edges = edges[order, :]
  55. n = round(len(edges) * ratios.train)
  56. edges_train = edges[:n]
  57. n_1 = round(len(edges) * (ratios.train + ratios.val))
  58. edges_val = edges[n:n_1]
  59. edges_test = edges[n_1:]
  60. return TrainValTest(edges_train, edges_val, edges_test)
  61. def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  62. if adj_mat.is_sparse:
  63. adj_mat = adj_mat.coalesce()
  64. degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
  65. device=adj_mat.device)
  66. degrees = degrees.index_add(0, adj_mat.indices()[1],
  67. torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
  68. device=adj_mat.device))
  69. edges_pos = adj_mat.indices().transpose(0, 1)
  70. else:
  71. degrees = adj_mat.sum(0)
  72. edges_pos = torch.nonzero(adj_mat)
  73. return edges_pos, degrees
  74. def prepare_adj_mat(adj_mat: torch.Tensor,
  75. ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
  76. if not isinstance(adj_mat, torch.Tensor):
  77. raise ValueError('adj_mat must be a torch.Tensor')
  78. edges_pos, degrees = get_edges_and_degrees(adj_mat)
  79. neg_neighbors = fixed_unigram_candidate_sampler(
  80. edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
  81. print(edges_pos.dtype)
  82. print(neg_neighbors.dtype)
  83. edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
  84. edges_pos = train_val_test_split_edges(edges_pos, ratios)
  85. edges_neg = train_val_test_split_edges(edges_neg, ratios)
  86. adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
  87. values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype,
  88. device=adj_mat.device)
  89. return adj_mat_train, edges_pos, edges_neg
  90. def prep_rel_one_node_type(r: RelationType,
  91. ratios: TrainValTest) -> PreparedRelationType:
  92. adj_mat = r.adjacency_matrix
  93. adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
  94. adj_mat_back_train, edges_back_pos, edges_back_neg = \
  95. None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
  96. print('adj_mat_train:', adj_mat_train)
  97. adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  98. return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
  99. adj_mat_train, adj_mat_back_train, edges_pos, edges_neg,
  100. edges_back_pos, edges_back_neg)
  101. def prep_rel_two_node_types_sym(r: RelationType,
  102. ratios: TrainValTest) -> PreparedRelationType:
  103. adj_mat = r.adjacency_matrix
  104. adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
  105. edges_back_pos, edges_back_neg = \
  106. _empty_edge_list_tvt(), _empty_edge_list_tvt()
  107. return PreparedRelationType(r.name, r.node_type_row,
  108. r.node_type_column,
  109. norm_adj_mat_two_node_types(adj_mat_train),
  110. norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)),
  111. edges_pos, edges_neg, edges_back_pos, edges_back_neg)
  112. def prep_rel_two_node_types_asym(r: RelationType,
  113. ratios: TrainValTest) -> PreparedRelationType:
  114. if r.adjacency_matrix is not None:
  115. adj_mat_train, edges_pos, edges_neg =\
  116. prepare_adj_mat(r.adjacency_matrix, ratios)
  117. else:
  118. adj_mat_train, edges_pos, edges_neg = \
  119. None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
  120. if r.adjacency_matrix_backward is not None:
  121. adj_mat_back_train, edges_back_pos, edges_back_neg = \
  122. prepare_adj_mat(r.adjacency_matrix_backward, ratios)
  123. else:
  124. adj_mat_back_train, edges_back_pos, edges_back_neg = \
  125. None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
  126. return PreparedRelationType(r.name, r.node_type_row,
  127. r.node_type_column,
  128. norm_adj_mat_two_node_types(adj_mat_train),
  129. norm_adj_mat_two_node_types(adj_mat_back_train),
  130. edges_pos, edges_neg, edges_back_pos, edges_back_neg)
  131. def prepare_relation_type(r: RelationType,
  132. ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType:
  133. if not isinstance(r, RelationType):
  134. raise ValueError('r must be a RelationType')
  135. if not isinstance(ratios, TrainValTest):
  136. raise ValueError('ratios must be a TrainValTest')
  137. if r.node_type_row == r.node_type_column:
  138. return prep_rel_one_node_type(r, ratios)
  139. elif is_symmetric:
  140. return prep_rel_two_node_types_sym(r, ratios)
  141. else:
  142. return prep_rel_two_node_types_asym(r, ratios)
  143. def prepare_relation_family(fam: RelationFamily,
  144. ratios: TrainValTest) -> PreparedRelationFamily:
  145. relation_types = []
  146. for r in fam.relation_types:
  147. relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric))
  148. return PreparedRelationFamily(fam.data, fam.name,
  149. fam.node_type_row, fam.node_type_column,
  150. fam.is_symmetric, fam.decoder_class,
  151. relation_types)
  152. def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
  153. if not isinstance(data, Data):
  154. raise ValueError('data must be of class Data')
  155. relation_families = [ prepare_relation_family(fam, ratios) \
  156. for fam in data.relation_families ]
  157. return PreparedData(data.node_types, relation_families)