#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#


from dataclasses import dataclass
from typing import (
    Callable,
    Dict,
    Tuple,
    List,
)
import types
from .util import _nonzero_sum
import torch


@dataclass
class DecodingMatrices(object):
    """Pair of tensors returned by an edge type's ``decoder_factory``."""
    global_interaction: torch.Tensor
    local_variation: torch.Tensor


@dataclass
class VertexType(object):
    """A named vertex type together with its number of vertices."""
    name: str
    count: int


@dataclass
class EdgeType(object):
    """A named edge type between a row vertex type and a column vertex type.

    ``vertex_type_row`` / ``vertex_type_column`` are plain integers;
    presumably indices into ``Data.vertex_types`` (they are not range-checked
    in this module -- TODO confirm against callers).

    ``total_connectivity`` is the result of ``_nonzero_sum`` applied to
    ``adjacency_matrices``; see ``util._nonzero_sum`` for its exact meaning.
    """
    name: str
    vertex_type_row: int
    vertex_type_column: int
    adjacency_matrices: List[torch.Tensor]
    decoder_factory: Callable[[], DecodingMatrices]
    total_connectivity: torch.Tensor


class Data(object):
    """Registry of vertex types and edge types.

    Edge types are stored in a dict keyed by the pair
    ``(vertex_type_row, vertex_type_column)``. At most one edge type can be
    registered per pair.
    """
    vertex_types: List[VertexType]
    # Keyed by (vertex_type_row, vertex_type_column); see add_edge_type.
    edge_types: Dict[Tuple[int, int], EdgeType]
    # Stored as an int; not otherwise used within this module.
    target_value: int

    def __init__(self, target_value: int = 1) -> None:
        """Create an empty registry.

        Args:
            target_value: Converted with ``int()`` and stored on the instance.
        """
        self.vertex_types = []
        self.edge_types = {}
        self.target_value = int(target_value)

    def add_vertex_type(self, name: str, count: int) -> None:
        """Append a new vertex type.

        Args:
            name: Type name; converted with ``str()`` and must be non-empty.
            count: Number of vertices; converted with ``int()`` and must be
                positive.

        Raises:
            ValueError: If ``name`` is empty or ``count`` is not positive.
        """
        name = str(name)
        count = int(count)
        if not name:
            raise ValueError('You must provide a non-empty vertex type name')
        if count <= 0:
            raise ValueError('You must provide a positive vertex count')
        self.vertex_types.append(VertexType(name, count))

    def add_edge_type(self, name: str,
                      vertex_type_row: int, vertex_type_column: int,
                      adjacency_matrices: List[torch.Tensor],
                      decoder_factory: Callable[[], DecodingMatrices]) -> None:
        """Register an edge type for a ``(row, column)`` vertex-type pair.

        Args:
            name: Edge type name; converted with ``str()``.
            vertex_type_row: Row vertex type; converted with ``int()``.
            vertex_type_column: Column vertex type; converted with ``int()``.
            adjacency_matrices: A ``list`` of tensors, passed to
                ``_nonzero_sum`` to compute ``total_connectivity``.
            decoder_factory: Zero-argument callable expected to return a
                ``DecodingMatrices``. Any callable is accepted: plain
                functions, bound methods, ``functools.partial`` objects and
                classes.

        Raises:
            TypeError: If ``adjacency_matrices`` is not a list or
                ``decoder_factory`` is not callable.
            KeyError: If an edge type already exists for the given
                ``(vertex_type_row, vertex_type_column)`` pair.
        """
        name = str(name)
        vertex_type_row = int(vertex_type_row)
        vertex_type_column = int(vertex_type_column)
        if not isinstance(adjacency_matrices, list):
            raise TypeError('adjacency_matrices must be a list of tensors')
        # callable() rather than isinstance(..., types.FunctionType): the
        # latter rejected bound methods, partials and other valid callables.
        if not callable(decoder_factory):
            raise TypeError('decoder_factory must be callable')
        key = (vertex_type_row, vertex_type_column)
        if key in self.edge_types:
            raise KeyError('Edge type for given combination of row and column already exists')
        total_connectivity = _nonzero_sum(adjacency_matrices)
        self.edge_types[key] = EdgeType(name,
                                        vertex_type_row, vertex_type_column,
                                        adjacency_matrices, decoder_factory,
                                        total_connectivity)