IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

73 lines
2.2KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple
import types

import torch

from .util import _nonzero_sum
  12. @dataclass
  13. class DecodingMatrices(object):
  14. global_interaction: torch.Tensor
  15. local_variation: torch.Tensor
  16. @dataclass
  17. class VertexType(object):
  18. name: str
  19. count: int
  20. @dataclass
  21. class EdgeType(object):
  22. name: str
  23. vertex_type_row: int
  24. vertex_type_column: int
  25. adjacency_matrices: List[torch.Tensor]
  26. decoder_factory: Callable[[], DecodingMatrices]
  27. total_connectivity: torch.Tensor
  28. class Data(object):
  29. vertex_types: List[VertexType]
  30. edge_types: List[EdgeType]
  31. def __init__(self) -> None:
  32. self.vertex_types = []
  33. self.edge_types = {}
  34. def add_vertex_type(self, name: str, count: int) -> None:
  35. name = str(name)
  36. count = int(count)
  37. if not name:
  38. raise ValueError('You must provide a non-empty vertex type name')
  39. if count <= 0:
  40. raise ValueError('You must provide a positive vertex count')
  41. self.vertex_types.append(VertexType(name, count))
  42. def add_edge_type(self, name: str,
  43. vertex_type_row: int, vertex_type_column: int,
  44. adjacency_matrices: List[torch.Tensor],
  45. decoder_factory: Callable[[], DecodingMatrices]) -> None:
  46. name = str(name)
  47. vertex_type_row = int(vertex_type_row)
  48. vertex_type_column = int(vertex_type_column)
  49. if not isinstance(adjacency_matrices, list):
  50. raise TypeError('adjacency_matrices must be a list of tensors')
  51. if not isinstance(decoder_factory, types.FunctionType):
  52. raise TypeError('decoder_factory must be a function')
  53. if (vertex_type_row, vertex_type_column) in self.edge_types:
  54. raise KeyError('Edge type for given combination of row and column already exists')
  55. total_connectivity = _nonzero_sum(adjacency_matrices)
  56. self.edge_types[vertex_type_row, vertex_type_column] = \
  57. EdgeType(name, vertex_type_row, vertex_type_column,
  58. adjacency_matrices, decoder_factory, total_connectivity)