IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

146 lignes
4.5KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
from collections import defaultdict
from dataclasses import dataclass
import math
from typing import Any, \
    Dict, \
    List, \
    Tuple

import numpy as np
import torch

from .data import Data, \
    NodeType, \
    RelationType
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
from .sampling import fixed_unigram_candidate_sampler
  19. @dataclass
  20. class TrainValTest(object):
  21. train: Any
  22. val: Any
  23. test: Any
  24. @dataclass
  25. class PreparedEdges(object):
  26. positive: TrainValTest
  27. negative: TrainValTest
  28. @dataclass
  29. class PreparedRelationType(object):
  30. name: str
  31. node_type_row: int
  32. node_type_column: int
  33. adjacency_matrix: torch.Tensor
  34. edges_pos: TrainValTest
  35. edges_neg: TrainValTest
  36. @dataclass
  37. class PreparedData(object):
  38. node_types: List[NodeType]
  39. relation_types: Dict[Tuple[int, int], List[PreparedRelationType]]
  40. def train_val_test_split_edges(edges: torch.Tensor,
  41. ratios: TrainValTest) -> TrainValTest:
  42. if not isinstance(edges, torch.Tensor):
  43. raise ValueError('edges must be a torch.Tensor')
  44. if len(edges.shape) != 2 or edges.shape[1] != 2:
  45. raise ValueError('edges shape must be (num_edges, 2)')
  46. if not isinstance(ratios, TrainValTest):
  47. raise ValueError('ratios must be a TrainValTest')
  48. if ratios.train + ratios.val + ratios.test != 1.0:
  49. raise ValueError('Train, validation and test ratios must add up to 1')
  50. order = torch.randperm(len(edges))
  51. edges = edges[order, :]
  52. n = round(len(edges) * ratios.train)
  53. edges_train = edges[:n]
  54. n_1 = round(len(edges) * (ratios.train + ratios.val))
  55. edges_val = edges[n:n_1]
  56. edges_test = edges[n_1:]
  57. return TrainValTest(edges_train, edges_val, edges_test)
  58. def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  59. if adj_mat.is_sparse:
  60. adj_mat = adj_mat.coalesce()
  61. degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64)
  62. degrees = degrees.index_add(0, adj_mat.indices()[1],
  63. torch.ones(adj_mat.indices().shape[1], dtype=torch.int64))
  64. edges_pos = adj_mat.indices().transpose(0, 1)
  65. else:
  66. degrees = adj_mat.sum(0)
  67. edges_pos = torch.nonzero(adj_mat)
  68. return edges_pos, degrees
  69. def prepare_adj_mat(adj_mat: torch.Tensor,
  70. ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
  71. if not isinstance(adj_mat, torch.Tensor):
  72. raise ValueError('adj_mat must be a torch.Tensor')
  73. edges_pos, degrees = get_edges_and_degrees(adj_mat)
  74. neg_neighbors = fixed_unigram_candidate_sampler(
  75. edges_pos[:, 1].view(-1, 1), degrees, 0.75)
  76. print(edges_pos.dtype)
  77. print(neg_neighbors.dtype)
  78. edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
  79. edges_pos = train_val_test_split_edges(edges_pos, ratios)
  80. edges_neg = train_val_test_split_edges(edges_neg, ratios)
  81. adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
  82. values=torch.ones(len(edges_pos.train), dtype=adj_mat.dtype))
  83. return adj_mat_train, edges_pos, edges_neg
  84. def prepare_relation_type(r: RelationType,
  85. ratios: TrainValTest) -> PreparedRelationType:
  86. if not isinstance(r, RelationType):
  87. raise ValueError('r must be a RelationType')
  88. if not isinstance(ratios, TrainValTest):
  89. raise ValueError('ratios must be a TrainValTest')
  90. adj_mat = r.adjacency_matrix
  91. adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
  92. print('adj_mat_train:', adj_mat_train)
  93. if r.node_type_row == r.node_type_column:
  94. adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  95. else:
  96. adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  97. return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
  98. adj_mat_train, edges_pos, edges_neg)
  99. def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
  100. if not isinstance(data, Data):
  101. raise ValueError('data must be of class Data')
  102. relation_types = defaultdict(list)
  103. for (node_type_row, node_type_column), rels in data.relation_types.items():
  104. for r in rels:
  105. relation_types[node_type_row, node_type_column].append(
  106. prepare_relation_type(r, ratios))
  107. return PreparedData(data.node_types, relation_types)