IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

132 行
3.9KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, \
    List, \
    Tuple, \
    Dict

import numpy as np
import torch

from .data import NodeType
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
from .sampling import fixed_unigram_candidate_sampler
  17. @dataclass
  18. class TrainValTest(object):
  19. train: Any
  20. val: Any
  21. test: Any
  22. @dataclass
  23. class PreparedEdges(object):
  24. positive: TrainValTest
  25. negative: TrainValTest
  26. @dataclass
  27. class PreparedRelationType(object):
  28. name: str
  29. node_type_row: int
  30. node_type_column: int
  31. adj_mat_train: torch.Tensor
  32. edges_pos: TrainValTest
  33. edges_neg: TrainValTest
  34. @dataclass
  35. class PreparedData(object):
  36. node_types: List[NodeType]
  37. relation_types: Dict[int, Dict[int, List[PreparedRelationType]]]
  38. def train_val_test_split_edges(edges: torch.Tensor,
  39. ratios: TrainValTest) -> TrainValTest:
  40. if not isinstance(edges, torch.Tensor):
  41. raise ValueError('edges must be a torch.Tensor')
  42. if len(edges.shape) != 2 or edges.shape[1] != 2:
  43. raise ValueError('edges shape must be (num_edges, 2)')
  44. if not isinstance(ratios, TrainValTest):
  45. raise ValueError('ratios must be a TrainValTest')
  46. if ratios.train + ratios.val + ratios.test != 1.0:
  47. raise ValueError('Train, validation and test ratios must add up to 1')
  48. order = torch.randperm(len(edges))
  49. edges = edges[order, :]
  50. n = round(len(edges) * ratios.train)
  51. edges_train = edges[:n]
  52. n_1 = round(len(edges) * (ratios.train + ratios.val))
  53. edges_val = edges[n:n_1]
  54. edges_test = edges[n_1:]
  55. return TrainValTest(edges_train, edges_val, edges_test)
  56. def get_edges_and_degrees(adj_mat):
  57. if adj_mat.is_sparse:
  58. adj_mat = adj_mat.coalesce()
  59. degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64)
  60. degrees = degrees.index_add(0, adj_mat.indices()[1],
  61. torch.ones(adj_mat.indices().shape[1], dtype=torch.int64))
  62. edges_pos = adj_mat.indices().transpose(0, 1)
  63. else:
  64. degrees = adj_mat.sum(0)
  65. edges_pos = torch.nonzero(adj_mat)
  66. return edges_pos, degrees
  67. def prepare_adj_mat(adj_mat: torch.Tensor,
  68. ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
  69. if not isinstance(adj_mat, torch.Tensor):
  70. raise ValueError('adj_mat must be a torch.Tensor')
  71. edges_pos, degrees = get_edges_and_degrees(adj_mat)
  72. neg_neighbors = fixed_unigram_candidate_sampler(
  73. edges_pos[:, 1].view(-1, 1), degrees, 0.75)
  74. print(edges_pos.dtype)
  75. print(neg_neighbors.dtype)
  76. edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
  77. edges_pos = train_val_test_split_edges(edges_pos, ratios)
  78. edges_neg = train_val_test_split_edges(edges_neg, ratios)
  79. adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
  80. values=torch.ones(len(edges_pos.train), dtype=adj_mat.dtype))
  81. return adj_mat_train, edges_pos, edges_neg
  82. def prepare_relation(r, ratios):
  83. adj_mat = r.adjacency_matrix
  84. adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)
  85. if r.node_type_row == r.node_type_column:
  86. adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
  87. else:
  88. adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
  89. return PreparedRelation(r.name, r.node_type_row, r.node_type_column,
  90. adj_mat_train, edges_pos, edges_neg)
  91. def prepare_training(data):
  92. relation_types = defaultdict(lambda: defaultdict(list))
  93. for (node_type_row, node_type_column), rels in data.relation_types:
  94. for r in rels:
  95. relation_types[node_type_row][node_type_column].append(
  96. prep_relation(r))
  97. return PreparedData(data.node_types, relation_types)