IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

158 lines
4.9KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from .sampling import fixed_unigram_candidate_sampler
  6. import torch
  7. from dataclasses import dataclass, \
  8. field
  9. from typing import Any, \
  10. List, \
  11. Tuple, \
  12. Dict
  13. from .data import NodeType, \
  14. RelationType, \
  15. RelationTypeBase, \
  16. RelationFamily, \
  17. RelationFamilyBase, \
  18. Data
  19. from collections import defaultdict
  20. from .normalize import norm_adj_mat_one_node_type, \
  21. norm_adj_mat_two_node_types
  22. import numpy as np
  23. @dataclass
  24. class TrainValTest(object):
  25. train: Any
  26. val: Any
  27. test: Any
  28. @dataclass
  29. class PreparedRelationType(RelationTypeBase):
  30. edges_pos: TrainValTest
  31. edges_neg: TrainValTest
  32. @dataclass
  33. class PreparedRelationFamily(RelationFamilyBase):
  34. relation_types: Dict[Tuple[int, int], List[PreparedRelationType]]
  35. @dataclass
  36. class PreparedData(object):
  37. node_types: List[NodeType]
  38. relation_families: List[PreparedRelationFamily]
  39. def train_val_test_split_edges(edges: torch.Tensor,
  40. ratios: TrainValTest) -> TrainValTest:
  41. if not isinstance(edges, torch.Tensor):
  42. raise ValueError('edges must be a torch.Tensor')
  43. if len(edges.shape) != 2 or edges.shape[1] != 2:
  44. raise ValueError('edges shape must be (num_edges, 2)')
  45. if not isinstance(ratios, TrainValTest):
  46. raise ValueError('ratios must be a TrainValTest')
  47. if ratios.train + ratios.val + ratios.test != 1.0:
  48. raise ValueError('Train, validation and test ratios must add up to 1')
  49. order = torch.randperm(len(edges))
  50. edges = edges[order, :]
  51. n = round(len(edges) * ratios.train)
  52. edges_train = edges[:n]
  53. n_1 = round(len(edges) * (ratios.train + ratios.val))
  54. edges_val = edges[n:n_1]
  55. edges_test = edges[n_1:]
  56. return TrainValTest(edges_train, edges_val, edges_test)
  57. def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  58. if adj_mat.is_sparse:
  59. adj_mat = adj_mat.coalesce()
  60. degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64)
  61. degrees = degrees.index_add(0, adj_mat.indices()[1],
  62. torch.ones(adj_mat.indices().shape[1], dtype=torch.int64))
  63. edges_pos = adj_mat.indices().transpose(0, 1)
  64. else:
  65. degrees = adj_mat.sum(0)
  66. edges_pos = torch.nonzero(adj_mat)
  67. return edges_pos, degrees
  68. def prepare_adj_mat(adj_mat: torch.Tensor,
  69. ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
  70. if not isinstance(adj_mat, torch.Tensor):
  71. raise ValueError('adj_mat must be a torch.Tensor')
  72. edges_pos, degrees = get_edges_and_degrees(adj_mat)
  73. neg_neighbors = fixed_unigram_candidate_sampler(
  74. edges_pos[:, 1].view(-1, 1), degrees, 0.75)
  75. print(edges_pos.dtype)
  76. print(neg_neighbors.dtype)
  77. edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
  78. edges_pos = train_val_test_split_edges(edges_pos, ratios)
  79. edges_neg = train_val_test_split_edges(edges_neg, ratios)
  80. adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
  81. values=torch.ones(len(edges_pos.train), dtype=adj_mat.dtype))
  82. return adj_mat_train, edges_pos, edges_neg
  83. def prepare_relation_type(r: RelationType,
  84. ratios: TrainValTest) -> PreparedRelationType:
  85. if not isinstance(r, RelationType):
  86. raise ValueError('r must be a RelationType')
  87. if not isinstance(ratios, TrainValTest):
  88. raise ValueError('ratios must be a TrainValTest')
  89. adj_mat = r.adjacency_matrix
  90. adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
  91. print('adj_mat_train:', adj_mat_train)
  92. if r.node_type_row == r.node_type_column:
  93. adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  94. else:
  95. adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  96. return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
  97. adj_mat_train, r.two_way, edges_pos, edges_neg)
  98. def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily:
  99. relation_types = { (fam.node_type_row, fam.node_type_column): [],
  100. (fam.node_type_column, fam.node_type_row): [] }
  101. for (node_type_row, node_type_column), rels in fam.relation_types.items():
  102. for r in rels:
  103. relation_types[node_type_row, node_type_column].append(
  104. prepare_relation_type(r, ratios))
  105. return PreparedRelationFamily(fam.data, fam.name,
  106. fam.node_type_row, fam.node_type_column,
  107. fam.is_symmetric, fam.decoder_class,
  108. relation_types)
  109. def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
  110. if not isinstance(data, Data):
  111. raise ValueError('data must be of class Data')
  112. relation_families = [ prepare_relation_family(fam) \
  113. for fam in data.relation_families ]
  114. return PreparedData(data.node_types, relation_families)