IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

69 lines
2.8KB

  1. from .matrix import NodeType
  2. import torch
  3. from collections import defaultdict
  4. class AdjListRelationType(object):
  5. def __init__(self, name, node_type_row, node_type_column,
  6. adjacency_list, adjacency_list_transposed=None):
  7. #if adjacency_matrix_transposed is not None and \
  8. # adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
  9. # raise ValueError('adjacency_matrix_transposed has incorrect shape')
  10. self.name = name
  11. self.node_type_row = node_type_row
  12. self.node_type_column = node_type_column
  13. self.adjacency_list = adjacency_list
  14. self.adjacency_list_transposed = adjacency_list_transposed
  15. def get_adjacency_list(self, node_type_row, node_type_column):
  16. if self.node_type_row == node_type_row and \
  17. self.node_type_column == node_type_column:
  18. return self.adjacency_list
  19. elif self.node_type_row == node_type_column and \
  20. self.node_type_column == node_type_row:
  21. if self.adjacency_list_transposed is not None:
  22. return self.adjacency_list_transposed
  23. else:
  24. return torch.index_select(self.adjacency_list, 1,
  25. torch.LongTensor([1, 0]))
  26. else:
  27. raise ValueError('Specified row/column types do not correspond to this relation')
  28. def _verify_adjacency_list(adjacency_list, node_count_row, node_count_col):
  29. assert isinstance(adjacency_list, torch.Tensor)
  30. assert len(adjacency_list.shape) == 2
  31. assert torch.all(adjacency_list[:, 0] >= 0)
  32. assert torch.all(adjacency_list[:, 0] < node_count_row)
  33. assert torch.all(adjacency_list[:, 1] >= 0)
  34. assert torch.all(adjacency_list[:, 1] < node_count_col)
  35. class AdjListData(object):
  36. def __init__(self):
  37. self.node_types = []
  38. self.relation_types = defaultdict(list)
  39. def add_node_type(self, name, count): # , latent_length):
  40. self.node_types.append(NodeType(name, count))
  41. def add_relation_type(self, name, node_type_row, node_type_col, adjacency_list, adjacency_list_transposed=None):
  42. assert node_type_row >= 0 and node_type_row < len(self.node_types)
  43. assert node_type_col >= 0 and node_type_col < len(self.node_types)
  44. node_count_row = self.node_types[node_type_row].count
  45. node_count_col = self.node_types[node_type_col].count
  46. _verify_adjacency_list(adjacency_list, node_count_row, node_count_col)
  47. if adjacency_list_transposed is not None:
  48. _verify_adjacency_list(adjacency_list_transposed,
  49. node_count_col, node_count_row)
  50. self.relation_types[node_type_row, node_type_col].append(
  51. AdjListRelationType(name, node_type_row, node_type_col,
  52. adjacency_list, adjacency_list_transposed))