IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

107 lines
3.0KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from .sampling import fixed_unigram_candidate_sampler
  6. import torch
  7. from dataclasses import dataclass
  8. from typing import Any, \
  9. List, \
  10. Tuple, \
  11. Dict
  12. from .data import NodeType
  13. from collections import defaultdict
  14. @dataclass
  15. class TrainValTest(object):
  16. train: Any
  17. val: Any
  18. test: Any
  19. @dataclass
  20. class PreparedEdges(object):
  21. positive: TrainValTest
  22. negative: TrainValTest
  23. @dataclass
  24. class PreparedRelationType(object):
  25. name: str
  26. node_type_row: int
  27. node_type_column: int
  28. adj_mat_train: torch.Tensor
  29. edges_pos: TrainValTest
  30. edges_neg: TrainValTest
  31. @dataclass
  32. class PreparedData(object):
  33. node_types: List[NodeType]
  34. relation_types: Dict[int, Dict[int, List[PreparedRelationType]]]
  35. def train_val_test_split_edges(edges: torch.Tensor,
  36. ratios: TrainValTest) -> TrainValTest:
  37. if not isinstance(edges, torch.Tensor):
  38. raise ValueError('edges must be a torch.Tensor')
  39. if len(edges.shape) != 2 or edges.shape[1] != 2:
  40. raise ValueError('edges shape must be (num_edges, 2)')
  41. if not isinstance(ratios, TrainValTest):
  42. raise ValueError('ratios must be a TrainValTest')
  43. if ratios.train + ratios.val + ratios.test != 1.0:
  44. raise ValueError('Train, validation and test ratios must add up to 1')
  45. order = torch.randperm(len(edges))
  46. edges = edges[order, :]
  47. n = round(len(edges) * ratios.train)
  48. edges_train = edges[:n]
  49. n_1 = round(len(edges) * (ratios.train + ratios.val))
  50. edges_val = edges[n:n_1]
  51. edges_test = edges[n_1:]
  52. return TrainValTest(edges_train, edges_val, edges_test)
  53. def prepare_adj_mat(adj_mat: torch.Tensor,
  54. ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
  55. degrees = adj_mat.sum(0)
  56. edges_pos = torch.nonzero(adj_mat)
  57. neg_neighbors = fixed_unigram_candidate_sampler(edges_pos[:, 1],
  58. len(edges), degrees, 0.75)
  59. edges_neg = torch.cat((edges_pos[:, 0], neg_neighbors.view(-1, 1)), 1)
  60. edges_pos = train_val_test_split_edges(edges_pos, ratios)
  61. edges_neg = train_val_test_split_edges(edges_neg, ratios)
  62. return edges_pos, edges_neg
  63. def prepare_relation(r, ratios):
  64. adj_mat = r.adjacency_matrix
  65. edges_pos, edges_neg = prepare_adj_mat(adj_mat)
  66. adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos[0].transpose(0, 1),
  67. values=torch.ones(len(edges_pos[0]), dtype=adj_mat.dtype))
  68. return PreparedRelation(r.name, r.node_type_row, r.node_type_column,
  69. adj_mat_train, edges_pos, edges_neg)
  70. def prepare_training(data):
  71. relation_types = defaultdict(lambda: defaultdict(list))
  72. for (node_type_row, node_type_column), rels in data.relation_types:
  73. for r in rels:
  74. relation_types[node_type_row][node_type_column].append(
  75. prep_relation(r))
  76. return PreparedData(data.node_types, relation_types)