IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

158 lines
4.9KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from .sampling import fixed_unigram_candidate_sampler
  6. import torch
  7. from dataclasses import dataclass, \
  8. field
  9. from typing import Any, \
  10. List, \
  11. Tuple, \
  12. Dict
  13. from .data import NodeType, \
  14. RelationType, \
  15. RelationTypeBase, \
  16. RelationFamily, \
  17. RelationFamilyBase, \
  18. Data
  19. from collections import defaultdict
  20. from .normalize import norm_adj_mat_one_node_type, \
  21. norm_adj_mat_two_node_types
  22. import numpy as np
  23. @dataclass
  24. class TrainValTest(object):
  25. train: Any
  26. val: Any
  27. test: Any
  28. @dataclass
  29. class PreparedRelationType(RelationTypeBase):
  30. edges_pos: TrainValTest
  31. edges_neg: TrainValTest
  32. @dataclass
  33. class PreparedRelationFamily(RelationFamilyBase):
  34. relation_types: Dict[Tuple[int, int], List[PreparedRelationType]]
  35. @dataclass
  36. class PreparedData(object):
  37. node_types: List[NodeType]
  38. relation_families: List[PreparedRelationFamily]
  39. def train_val_test_split_edges(edges: torch.Tensor,
  40. ratios: TrainValTest) -> TrainValTest:
  41. if not isinstance(edges, torch.Tensor):
  42. raise ValueError('edges must be a torch.Tensor')
  43. if len(edges.shape) != 2 or edges.shape[1] != 2:
  44. raise ValueError('edges shape must be (num_edges, 2)')
  45. if not isinstance(ratios, TrainValTest):
  46. raise ValueError('ratios must be a TrainValTest')
  47. if ratios.train + ratios.val + ratios.test != 1.0:
  48. raise ValueError('Train, validation and test ratios must add up to 1')
  49. order = torch.randperm(len(edges))
  50. edges = edges[order, :]
  51. n = round(len(edges) * ratios.train)
  52. edges_train = edges[:n]
  53. n_1 = round(len(edges) * (ratios.train + ratios.val))
  54. edges_val = edges[n:n_1]
  55. edges_test = edges[n_1:]
  56. return TrainValTest(edges_train, edges_val, edges_test)
  57. def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  58. if adj_mat.is_sparse:
  59. adj_mat = adj_mat.coalesce()
  60. degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64)
  61. degrees = degrees.index_add(0, adj_mat.indices()[1],
  62. torch.ones(adj_mat.indices().shape[1], dtype=torch.int64))
  63. edges_pos = adj_mat.indices().transpose(0, 1)
  64. else:
  65. degrees = adj_mat.sum(0)
  66. edges_pos = torch.nonzero(adj_mat)
  67. return edges_pos, degrees
  68. def prepare_adj_mat(adj_mat: torch.Tensor,
  69. ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
  70. if not isinstance(adj_mat, torch.Tensor):
  71. raise ValueError('adj_mat must be a torch.Tensor')
  72. edges_pos, degrees = get_edges_and_degrees(adj_mat)
  73. neg_neighbors = fixed_unigram_candidate_sampler(
  74. edges_pos[:, 1].view(-1, 1), degrees, 0.75)
  75. print(edges_pos.dtype)
  76. print(neg_neighbors.dtype)
  77. edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
  78. edges_pos = train_val_test_split_edges(edges_pos, ratios)
  79. edges_neg = train_val_test_split_edges(edges_neg, ratios)
  80. adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
  81. values=torch.ones(len(edges_pos.train), dtype=adj_mat.dtype))
  82. return adj_mat_train, edges_pos, edges_neg
  83. def prepare_relation_type(r: RelationType,
  84. ratios: TrainValTest) -> PreparedRelationType:
  85. if not isinstance(r, RelationType):
  86. raise ValueError('r must be a RelationType')
  87. if not isinstance(ratios, TrainValTest):
  88. raise ValueError('ratios must be a TrainValTest')
  89. adj_mat = r.adjacency_matrix
  90. adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
  91. print('adj_mat_train:', adj_mat_train)
  92. if r.node_type_row == r.node_type_column:
  93. adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  94. else:
  95. adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  96. return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
  97. adj_mat_train, r.two_way, edges_pos, edges_neg)
  98. def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily:
  99. relation_types = { (fam.node_type_row, fam.node_type_column): [],
  100. (fam.node_type_column, fam.node_type_row): [] }
  101. for (node_type_row, node_type_column), rels in fam.relation_types.items():
  102. for r in rels:
  103. relation_types[node_type_row, node_type_column].append(
  104. prepare_relation_type(r, ratios))
  105. return PreparedRelationFamily(fam.data, fam.name,
  106. fam.node_type_row, fam.node_type_column,
  107. fam.is_symmetric, fam.decoder_class,
  108. relation_types)
  109. def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
  110. if not isinstance(data, Data):
  111. raise ValueError('data must be of class Data')
  112. relation_families = [ prepare_relation_family(fam) \
  113. for fam in data.relation_families ]
  114. return PreparedData(data.node_types, relation_families)