IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

78 lines
2.9KB

  1. from .sampling import fixed_unigram_candidate_sampler
  2. import torch
  3. def train_val_test_split_edges(edges, ratios):
  4. train_ratio, val_ratio, test_ratio = ratios
  5. if train_ratio + val_ratio + test_ratio != 1.0:
  6. raise ValueError('Train, validation and test ratios must add up to 1')
  7. order = torch.randperm(len(edges))
  8. edges = edges[order, :]
  9. n = round(len(edges) * train_ratio)
  10. edges_train = edges[:n]
  11. n_1 = round(len(edges) * (train_ratio + val_ratio))
  12. edges_val = edges[n:n_1]
  13. edges_test = edges[n_1:]
  14. return edges_train, edges_val, edges_test
  15. def prepare_adj_mat(adj_mat, ratios):
  16. degrees = adj_mat.sum(0)
  17. edges_pos = torch.nonzero(adj_mat)
  18. neg_neighbors = fixed_unigram_candidate_sampler(edges_pos[:, 1],
  19. len(edges), degrees, 0.75)
  20. edges_neg = torch.cat((edges_pos[:, 0], neg_neighbors.view(-1, 1)), 1)
  21. edges_pos = (edges_pos_train, edges_pos_val, edges_pos_test) = \
  22. train_val_test_split_edges(edges_pos, ratios)
  23. edges_neg = (edges_neg_train, edges_neg_val, edges_neg_test) = \
  24. train_val_test_split_edges(edges_neg, ratios)
  25. return edges_pos, edges_neg
  26. class PreparedRelation(object):
  27. def __init__(self, node_type_row, node_type_column,
  28. adj_mat_train, adj_mat_train_trans,
  29. edges_pos, edges_neg, edges_pos_trans, edges_neg_trans):
  30. self.adj_mat_train = adj_mat_train
  31. self.adj_mat_train_trans = adj_mat_train_trans
  32. self.edges_pos = edges_pos
  33. self.edges_neg = edges_neg
  34. self.edges_pos_trans = edges_pos_trans
  35. self.edges_neg_trans = edges_neg_trans
  36. def prepare_relation(r, ratios):
  37. adj_mat = r.get_adjacency_matrix(r.node_type_row, r.node_type_column)
  38. edges_pos, edges_neg = prepare_adj_mat(adj_mat)
  39. # adj_mat_train = torch.zeros_like(adj_mat)
  40. # adj_mat_train[edges_pos[0][:, 0], edges_pos[0][:, 0]] = 1
  41. adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos[0].transpose(0, 1),
  42. values=torch.ones(len(edges_pos[0]), dtype=adj_mat.dtype))
  43. if r.node_type_row != r.node_type_col:
  44. adj_mat_trans = r.get_adjacency_matrix(r.node_type_col, r.node_type_row)
  45. edges_pos_trans, edges_neg_trans = prepare_adj_mat(adj_mat_trans)
  46. adj_mat_train_trans = torch.sparse_coo_tensor(indices = edges_pos_trans[0].transpose(0, 1),
  47. values=torch.ones(len(edges_pos_trans[0]), dtype=adj_mat_trans.dtype))
  48. else:
  49. adj_mat_train_trans = adj_mat_trans = \
  50. edge_pos_trans = edge_neg_trans = None
  51. return PreparedRelation(r.node_type_row, r.node_type_column,
  52. adj_mat_train, adj_mat_trans_train,
  53. edges_pos, edges_neg, edges_pos_trans, edges_neg_trans)
  54. def prepare_training(data):
  55. for (node_type_row, node_type_column), rels in data.relation_types:
  56. for r in rels:
  57. prep_relation_edges()