IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

90 lines
2.9KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
from dataclasses import dataclass
import types
from typing import Callable, Dict, List, Tuple

import torch

from .util import _nonzero_sum, _diag
  13. @dataclass
  14. class DecodingMatrices(object):
  15. global_interaction: torch.Tensor
  16. local_variation: torch.Tensor
  17. @dataclass
  18. class VertexType(object):
  19. name: str
  20. count: int
  21. @dataclass
  22. class EdgeType(object):
  23. name: str
  24. vertex_type_row: int
  25. vertex_type_column: int
  26. adjacency_matrices: List[torch.Tensor]
  27. decoder_factory: Callable[[], DecodingMatrices]
  28. total_connectivity: torch.Tensor
  29. class Data(object):
  30. vertex_types: List[VertexType]
  31. edge_types: List[EdgeType]
  32. def __init__(self, target_value: int = 1) -> None:
  33. self.vertex_types = []
  34. self.edge_types = {}
  35. self.target_value = int(target_value)
  36. def add_vertex_type(self, name: str, count: int) -> None:
  37. name = str(name)
  38. count = int(count)
  39. if not name:
  40. raise ValueError('You must provide a non-empty vertex type name')
  41. if count <= 0:
  42. raise ValueError('You must provide a positive vertex count')
  43. self.vertex_types.append(VertexType(name, count))
  44. def add_edge_type(self, name: str,
  45. vertex_type_row: int, vertex_type_column: int,
  46. adjacency_matrices: List[torch.Tensor],
  47. decoder_factory: Callable[[], DecodingMatrices]) -> None:
  48. name = str(name)
  49. vertex_type_row = int(vertex_type_row)
  50. vertex_type_column = int(vertex_type_column)
  51. if not isinstance(adjacency_matrices, list):
  52. raise TypeError('adjacency_matrices must be a list of tensors')
  53. if not callable(decoder_factory):
  54. raise TypeError('decoder_factory must be callable')
  55. if (vertex_type_row, vertex_type_column) in self.edge_types:
  56. raise KeyError('Edge type for given combination of row and column already exists')
  57. if vertex_type_row == vertex_type_column and \
  58. any(torch.any(_diag(adj_mat).to(torch.bool)) \
  59. for adj_mat in adjacency_matrices):
  60. raise ValueError('Adjacency matrices for same row/column vertex types must have empty diagonals')
  61. if any(adj_mat.shape[0] != self.vertex_types[vertex_type_row].count \
  62. or adj_mat.shape[1] != self.vertex_types[vertex_type_column].count \
  63. for adj_mat in adjacency_matrices):
  64. raise ValueError('Adjacency matrices must have as many rows as row vertex type count and as many columns as column vertex type count')
  65. total_connectivity = _nonzero_sum(adjacency_matrices)
  66. self.edge_types[vertex_type_row, vertex_type_column] = \
  67. EdgeType(name, vertex_type_row, vertex_type_column,
  68. adjacency_matrices, decoder_factory, total_connectivity)