If you would like to get an account, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or creating pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

72 lines
2.2KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
from dataclasses import dataclass
from typing import Callable, \
    Dict, \
    List, \
    Tuple
import types

import torch

from .util import _nonzero_sum
  11. @dataclass
  12. class DecodingMatrices(object):
  13. global_interaction: torch.Tensor
  14. local_variation: torch.Tensor
  15. @dataclass
  16. class VertexType(object):
  17. name: str
  18. count: int
  19. @dataclass
  20. class EdgeType(object):
  21. name: str
  22. vertex_type_row: int
  23. vertex_type_column: int
  24. adjacency_matrices: List[torch.Tensor]
  25. decoder_factory: Callable[[], DecodingMatrices]
  26. total_connectivity: torch.Tensor
  27. class Data(object):
  28. vertex_types: List[VertexType]
  29. edge_types: List[EdgeType]
  30. def __init__(self) -> None:
  31. self.vertex_types = []
  32. self.edge_types = {}
  33. def add_vertex_type(self, name: str, count: int) -> None:
  34. name = str(name)
  35. count = int(count)
  36. if not name:
  37. raise ValueError('You must provide a non-empty vertex type name')
  38. if count <= 0:
  39. raise ValueError('You must provide a positive vertex count')
  40. self.vertex_types.append(VertexType(name, count))
  41. def add_edge_type(self, name: str,
  42. vertex_type_row: int, vertex_type_column: int,
  43. adjacency_matrices: List[torch.Tensor],
  44. decoder_factory: Callable[[], DecodingMatrices]) -> None:
  45. name = str(name)
  46. vertex_type_row = int(vertex_type_row)
  47. vertex_type_column = int(vertex_type_column)
  48. if not isinstance(adjacency_matrices, list):
  49. raise TypeError('adjacency_matrices must be a list of tensors')
  50. if not isinstance(decoder_factory, types.FunctionType):
  51. raise TypeError('decoder_factory must be a function')
  52. if (vertex_type_row, vertex_type_column) in self.edge_types:
  53. raise KeyError('Edge type for given combination of row and column already exists')
  54. total_connectivity = _nonzero_sum(adjacency_matrices)
  55. self.edges_types[vertex_type_row, vertex_type_column] = \
  56. VertexType(name, vertex_type_row, vertex_type_column,
  57. adjacency_matrices, decoder_factory, total_connectivity)