IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

132 Zeilen
3.9KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
  5. from .sampling import fixed_unigram_candidate_sampler
  6. import torch
  7. from dataclasses import dataclass
  8. from typing import Any, \
  9. List, \
  10. Tuple, \
  11. Dict
  12. from .data import NodeType
  13. from collections import defaultdict
  14. from .normalize import norm_adj_mat_one_node_type, \
  15. norm_adj_mat_two_node_types
  16. import numpy as np
  17. @dataclass
  18. class TrainValTest(object):
  19. train: Any
  20. val: Any
  21. test: Any
  22. @dataclass
  23. class PreparedEdges(object):
  24. positive: TrainValTest
  25. negative: TrainValTest
  26. @dataclass
  27. class PreparedRelationType(object):
  28. name: str
  29. node_type_row: int
  30. node_type_column: int
  31. adj_mat_train: torch.Tensor
  32. edges_pos: TrainValTest
  33. edges_neg: TrainValTest
  34. @dataclass
  35. class PreparedData(object):
  36. node_types: List[NodeType]
  37. relation_types: Dict[int, Dict[int, List[PreparedRelationType]]]
  38. def train_val_test_split_edges(edges: torch.Tensor,
  39. ratios: TrainValTest) -> TrainValTest:
  40. if not isinstance(edges, torch.Tensor):
  41. raise ValueError('edges must be a torch.Tensor')
  42. if len(edges.shape) != 2 or edges.shape[1] != 2:
  43. raise ValueError('edges shape must be (num_edges, 2)')
  44. if not isinstance(ratios, TrainValTest):
  45. raise ValueError('ratios must be a TrainValTest')
  46. if ratios.train + ratios.val + ratios.test != 1.0:
  47. raise ValueError('Train, validation and test ratios must add up to 1')
  48. order = torch.randperm(len(edges))
  49. edges = edges[order, :]
  50. n = round(len(edges) * ratios.train)
  51. edges_train = edges[:n]
  52. n_1 = round(len(edges) * (ratios.train + ratios.val))
  53. edges_val = edges[n:n_1]
  54. edges_test = edges[n_1:]
  55. return TrainValTest(edges_train, edges_val, edges_test)
  56. def get_edges_and_degrees(adj_mat):
  57. if adj_mat.is_sparse:
  58. adj_mat = adj_mat.coalesce()
  59. degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64)
  60. degrees = degrees.index_add(0, adj_mat.indices()[1],
  61. torch.ones(adj_mat.indices().shape[1], dtype=torch.int64))
  62. edges_pos = adj_mat.indices().transpose(0, 1)
  63. else:
  64. degrees = adj_mat.sum(0)
  65. edges_pos = torch.nonzero(adj_mat)
  66. return edges_pos, degrees
  67. def prepare_adj_mat(adj_mat: torch.Tensor,
  68. ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
  69. if not isinstance(adj_mat, torch.Tensor):
  70. raise ValueError('adj_mat must be a torch.Tensor')
  71. edges_pos, degrees = get_edges_and_degrees(adj_mat)
  72. neg_neighbors = fixed_unigram_candidate_sampler(
  73. edges_pos[:, 1].view(-1, 1), degrees, 0.75)
  74. print(edges_pos.dtype)
  75. print(neg_neighbors.dtype)
  76. edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
  77. edges_pos = train_val_test_split_edges(edges_pos, ratios)
  78. edges_neg = train_val_test_split_edges(edges_neg, ratios)
  79. adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
  80. values=torch.ones(len(edges_pos.train), dtype=adj_mat.dtype))
  81. return adj_mat_train, edges_pos, edges_neg
  82. def prepare_relation(r, ratios):
  83. adj_mat = r.adjacency_matrix
  84. adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)
  85. if r.node_type_row == r.node_type_column:
  86. adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  87. else:
  88. adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  89. return PreparedRelation(r.name, r.node_type_row, r.node_type_column,
  90. adj_mat_train, edges_pos, edges_neg)
  91. def prepare_training(data):
  92. relation_types = defaultdict(lambda: defaultdict(list))
  93. for (node_type_row, node_type_column), rels in data.relation_types:
  94. for r in rels:
  95. relation_types[node_type_row][node_type_column].append(
  96. prep_relation(r))
  97. return PreparedData(data.node_types, relation_types)