IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

74 lines
2.3KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
from dataclasses import dataclass
import types
from typing import Callable, \
    Dict, \
    List, \
    Tuple

import torch

from .util import _nonzero_sum
  12. @dataclass
  13. class DecodingMatrices(object):
  14. global_interaction: torch.Tensor
  15. local_variation: torch.Tensor
  16. @dataclass
  17. class VertexType(object):
  18. name: str
  19. count: int
  20. @dataclass
  21. class EdgeType(object):
  22. name: str
  23. vertex_type_row: int
  24. vertex_type_column: int
  25. adjacency_matrices: List[torch.Tensor]
  26. decoder_factory: Callable[[], DecodingMatrices]
  27. total_connectivity: torch.Tensor
  28. class Data(object):
  29. vertex_types: List[VertexType]
  30. edge_types: List[EdgeType]
  31. def __init__(self, target_value: int = 1) -> None:
  32. self.vertex_types = []
  33. self.edge_types = {}
  34. self.target_value = int(target_value)
  35. def add_vertex_type(self, name: str, count: int) -> None:
  36. name = str(name)
  37. count = int(count)
  38. if not name:
  39. raise ValueError('You must provide a non-empty vertex type name')
  40. if count <= 0:
  41. raise ValueError('You must provide a positive vertex count')
  42. self.vertex_types.append(VertexType(name, count))
  43. def add_edge_type(self, name: str,
  44. vertex_type_row: int, vertex_type_column: int,
  45. adjacency_matrices: List[torch.Tensor],
  46. decoder_factory: Callable[[], DecodingMatrices]) -> None:
  47. name = str(name)
  48. vertex_type_row = int(vertex_type_row)
  49. vertex_type_column = int(vertex_type_column)
  50. if not isinstance(adjacency_matrices, list):
  51. raise TypeError('adjacency_matrices must be a list of tensors')
  52. if not isinstance(decoder_factory, types.FunctionType):
  53. raise TypeError('decoder_factory must be a function')
  54. if (vertex_type_row, vertex_type_column) in self.edge_types:
  55. raise KeyError('Edge type for given combination of row and column already exists')
  56. total_connectivity = _nonzero_sum(adjacency_matrices)
  57. self.edge_types[vertex_type_row, vertex_type_column] = \
  58. EdgeType(name, vertex_type_row, vertex_type_column,
  59. adjacency_matrices, decoder_factory, total_connectivity)