|
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
import types
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

import torch

from .util import _nonzero_sum
-
-
@dataclass
class DecodingMatrices:
    """Pair of tensors returned by an edge type's ``decoder_factory``.

    NOTE(review): the field names suggest the global-interaction and
    local-variation matrices of a DEDICOM-style decoder; shapes are not
    fixed here — confirm against the decoder implementation.
    """

    # Tensor shared across relations of an edge type (presumably).
    global_interaction: torch.Tensor
    # Tensor(s) varying per relation (presumably).
    local_variation: torch.Tensor
-
-
@dataclass
class VertexType:
    """A named vertex type together with how many vertices it has."""

    # Human-readable identifier of the vertex type.
    name: str
    # Number of vertices of this type.
    count: int
-
-
@dataclass
class EdgeType:
    """Description of one edge type between a row and a column vertex type."""

    # Human-readable identifier of the edge type.
    name: str
    # Row-side vertex type; presumably an index into Data.vertex_types.
    vertex_type_row: int
    # Column-side vertex type; presumably an index into Data.vertex_types.
    vertex_type_column: int
    # One adjacency tensor per relation of this edge type.
    adjacency_matrices: List[torch.Tensor]
    # Zero-argument callable that builds the DecodingMatrices for this type.
    decoder_factory: Callable[[], DecodingMatrices]
    # Aggregate of adjacency_matrices (computed by the caller, e.g. via
    # _nonzero_sum); exact semantics depend on that helper.
    total_connectivity: torch.Tensor
-
-
class Data(object):
    """Registry of vertex types and edge types describing a graph.

    Vertex types are stored in insertion order in ``vertex_types``. Edge
    types are stored in ``edge_types``, keyed by the
    ``(vertex_type_row, vertex_type_column)`` pair they connect, with at
    most one edge type per pair.
    """

    vertex_types: List[VertexType]
    # Keyed by (vertex_type_row, vertex_type_column); the value is a dict,
    # not a list, so the annotation reflects that.
    edge_types: Dict[Tuple[int, int], EdgeType]

    def __init__(self) -> None:
        self.vertex_types = []
        self.edge_types = {}

    def add_vertex_type(self, name: str, count: int) -> None:
        """Append a new vertex type.

        Args:
            name: Vertex type name; converted with ``str()`` and must be
                non-empty after conversion.
            count: Number of vertices of this type; converted with
                ``int()`` and must be positive.

        Raises:
            ValueError: If ``name`` is empty or ``count`` is not positive.
        """
        name = str(name)
        count = int(count)
        if not name:
            raise ValueError('You must provide a non-empty vertex type name')
        if count <= 0:
            raise ValueError('You must provide a positive vertex count')
        self.vertex_types.append(VertexType(name, count))

    def add_edge_type(self, name: str,
                      vertex_type_row: int, vertex_type_column: int,
                      adjacency_matrices: List[torch.Tensor],
                      decoder_factory: Callable[[], DecodingMatrices]) -> None:
        """Register an edge type between two vertex types.

        Args:
            name: Edge type name; converted with ``str()``.
            vertex_type_row: Row-side vertex type; converted with ``int()``.
            vertex_type_column: Column-side vertex type; converted with
                ``int()``.
            adjacency_matrices: List of adjacency tensors for this edge
                type; must be a ``list``.
            decoder_factory: Any callable; per the annotation it is
                expected to take no arguments and return
                ``DecodingMatrices``. It is stored, not called, here.

        Raises:
            TypeError: If ``adjacency_matrices`` is not a list or
                ``decoder_factory`` is not callable.
            KeyError: If an edge type already exists for the
                ``(vertex_type_row, vertex_type_column)`` pair.
        """
        name = str(name)
        vertex_type_row = int(vertex_type_row)
        vertex_type_column = int(vertex_type_column)
        # NOTE(review): the row/column values are presumably indices into
        # self.vertex_types, but they are not range-checked here.
        if not isinstance(adjacency_matrices, list):
            raise TypeError('adjacency_matrices must be a list of tensors')
        # callable() matches the Callable[[], ...] annotation; a
        # types.FunctionType check would reject functools.partial objects,
        # bound methods, classes and builtins.
        if not callable(decoder_factory):
            raise TypeError('decoder_factory must be callable')
        key = (vertex_type_row, vertex_type_column)
        if key in self.edge_types:
            raise KeyError('Edge type for given combination of row and column already exists')
        # Aggregate of all adjacency tensors, as computed by .util._nonzero_sum.
        total_connectivity = _nonzero_sum(adjacency_matrices)
        self.edge_types[key] = EdgeType(
            name, vertex_type_row, vertex_type_column,
            adjacency_matrices, decoder_factory, total_connectivity)
|