#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

"""Containers for node types, relation types and relation families.

A ``Data`` object holds a list of ``NodeType`` entries (name + node count)
and a list of ``RelationFamily`` objects. Each family groups ``RelationType``
entries, each carrying an adjacency matrix between two node types.
"""

from collections import defaultdict
from dataclasses import dataclass, field
import torch
from typing import List, \
    Dict, \
    Tuple, \
    Any, \
    Type
from .decode import DEDICOMDecoder, \
    BilinearDecoder


def _equal(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compare two tensors element-wise, supporting sparse COO tensors.

    Both arguments must be dense, or both must be sparse.

    For dense tensors the result is simply ``x == y``. For sparse tensors the
    result is the element-wise comparison of the stored values when both
    tensors store the same set of indices. Otherwise it is a scalar ``0``
    (uint8). In this module the result is reduced with ``torch.all``.

    Raises:
        ValueError: if exactly one of ``x`` / ``y`` is sparse.
    """
    if x.is_sparse ^ y.is_sparse:
        raise ValueError('Cannot mix sparse and dense tensors')

    if not x.is_sparse:
        return x == y

    # coalesce() merges duplicate entries and sorts the indices
    # lexicographically. Two coalesced tensors with the same sparsity
    # pattern therefore have identical index tensors, and their values line
    # up position by position. torch.equal returns False (rather than
    # raising) when the index tensors differ in size.
    x = x.coalesce()
    y = y.coalesce()

    if not torch.equal(x.indices(), y.indices()):
        return torch.tensor(0, dtype=torch.uint8)

    return x.values() == y.values()


@dataclass
class NodeType(object):
    """A named node type with a fixed number of nodes."""

    name: str
    count: int


@dataclass
class RelationType(object):
    """A single relation between two node types, given as an adjacency matrix."""

    name: str
    # Indices into Data.node_types.
    node_type_row: int
    node_type_column: int
    adjacency_matrix: torch.Tensor
    two_way: bool
    # Free-form flags; 'display': False hides the entry from the reprs below.
    hints: Dict[str, Any] = field(default_factory=dict)


class RelationFamily(object):
    """A group of relation types connecting two node types.

    Relation types are stored in ``relation_types``, a dict keyed by
    ``(node_type_row, node_type_column)`` for both directions of the pair.
    """

    def __init__(self,
                 data: 'Data',
                 name: str,
                 node_type_row: int,
                 node_type_column: int,
                 is_symmetric: bool,
                 decoder_class: Type) -> None:
        """Create an empty family.

        Args:
            data: the owning ``Data`` object; its node types are used for
                validation in ``add_relation_type``.
            name: family name, used in the reprs.
            node_type_row, node_type_column: indices of the two node types.
            is_symmetric: whether relations in this family must be symmetric.
            decoder_class: decoder type associated with the family.

        Raises:
            TypeError: if the family is asymmetric and ``decoder_class`` is
                neither ``DEDICOMDecoder`` nor ``BilinearDecoder``.
        """
        if not is_symmetric and \
                decoder_class not in (DEDICOMDecoder, BilinearDecoder):
            raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')

        self.data = data
        self.name = name
        self.node_type_row = node_type_row
        self.node_type_column = node_type_column
        self.is_symmetric = is_symmetric
        self.decoder_class = decoder_class

        # Both directions are registered. When row == column this collapses
        # to a single key.
        self.relation_types = {
            (node_type_row, node_type_column): [],
            (node_type_column, node_type_row): [],
        }

    def add_relation_type(self, name: str, node_type_row: int,
                          node_type_column: int, adjacency_matrix: torch.Tensor,
                          adjacency_matrix_backward: torch.Tensor = None,
                          two_way: bool = True) -> None:
        """Validate and register a relation type in this family.

        Args:
            name: relation name (converted with ``str``).
            node_type_row, node_type_column: node type indices. The pair must
                be one of the family's two orientations.
            adjacency_matrix: tensor of shape
                ``(count of row type, count of column type)``.
            adjacency_matrix_backward: optional tensor of the transposed shape.
                It is only allowed when the two node types differ and the
                family is not symmetric.
            two_way: when true and the node types differ, a backward
                relation is also registered under
                ``(node_type_column, node_type_row)``. That relation is marked
                ``{'display': False}`` and uses ``adjacency_matrix_backward``,
                or the transpose of ``adjacency_matrix`` if none is given.

        Raises:
            ValueError: on any failed validation (orientation, index range,
                tensor type, shape, symmetry constraints).
        """
        name = str(name)
        node_type_row = int(node_type_row)
        node_type_column = int(node_type_column)

        if (node_type_row, node_type_column) not in self.relation_types:
            raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family')

        num_node_types = len(self.data.node_types)
        if node_type_row < 0 or node_type_row >= num_node_types:
            raise ValueError('node_type_row outside of the valid range of node types')
        if node_type_column < 0 or node_type_column >= num_node_types:
            raise ValueError('node_type_column outside of the valid range of node types')

        if not isinstance(adjacency_matrix, torch.Tensor):
            raise ValueError('adjacency_matrix must be a torch.Tensor')
        if adjacency_matrix_backward is not None \
                and not isinstance(adjacency_matrix_backward, torch.Tensor):
            raise ValueError('adjacency_matrix_backward must be a torch.Tensor')

        # Safe to index now that both node type indices are range-checked.
        num_row_nodes = self.data.node_types[node_type_row].count
        num_column_nodes = self.data.node_types[node_type_column].count

        if adjacency_matrix.shape != (num_row_nodes, num_column_nodes):
            raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
        if adjacency_matrix_backward is not None and \
                adjacency_matrix_backward.shape != (num_column_nodes, num_row_nodes):
            raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')

        if node_type_row == node_type_column and \
                adjacency_matrix_backward is not None:
            raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')

        if self.is_symmetric and adjacency_matrix_backward is not None:
            raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')

        if self.is_symmetric and node_type_row == node_type_column and \
                not torch.all(_equal(adjacency_matrix, adjacency_matrix.transpose(0, 1))):
            raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')

        two_way = bool(two_way)

        # NOTE(review): when two_way is False, a supplied
        # adjacency_matrix_backward is ignored without an error.
        if node_type_row != node_type_column and two_way:
            if adjacency_matrix_backward is None:
                adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
            self.relation_types[node_type_column, node_type_row].append(
                RelationType(name, node_type_column, node_type_row,
                             adjacency_matrix_backward, two_way,
                             {'display': False}))

        self.relation_types[node_type_row, node_type_column].append(
            RelationType(name, node_type_row, node_type_column,
                         adjacency_matrix, two_way))

    def node_name(self, index):
        """Return the name of the node type at ``index`` in the owning Data."""
        return self.data.node_types[index].name

    def _relation_lines(self) -> List[str]:
        """Return one description per displayed relation type, in dict order."""
        lines = []
        for (node_type_row, node_type_column), rels in self.relation_types.items():
            for r in rels:
                # Entries with hints['display'] == False (the hidden backward
                # copies of two-way relations) are skipped.
                if not r.hints.get('display', True):
                    continue
                if r.two_way:
                    suffix = ' (two-way)'
                else:
                    suffix = ' (%s <- %s)' % (self.node_name(node_type_row),
                                              self.node_name(node_type_column))
                lines.append(r.name + suffix)
        return lines

    def __repr__(self):
        header = 'Relation family %s' % self.name
        return header + ''.join('\n - %s' % line
                                for line in self._relation_lines())

    def repr_indented(self):
        """Like ``__repr__`` but formatted as a nested item for ``Data.__repr__``."""
        header = ' - %s' % self.name
        return header + ''.join('\n - %s' % line
                                for line in self._relation_lines())


class Data(object):
    """Top-level container of node types and relation families."""

    node_types: List[NodeType]
    relation_families: List[RelationFamily]

    def __init__(self) -> None:
        self.node_types = []
        self.relation_families = []

    def add_node_type(self, name: str, count: int) -> None:
        """Append a node type; its index is its position in ``node_types``.

        Raises:
            ValueError: if ``name`` is empty or ``count`` is not positive.
        """
        name = str(name)
        count = int(count)
        if not name:
            raise ValueError('You must provide a non-empty node type name')
        if count <= 0:
            raise ValueError('You must provide a positive node count')
        self.node_types.append(NodeType(name, count))

    def add_relation_family(self, name: str, node_type_row: int,
                            node_type_column: int, is_symmetric: bool,
                            decoder_class: Type = DEDICOMDecoder) -> RelationFamily:
        """Create a ``RelationFamily`` bound to this object, store it, and return it."""
        fam = RelationFamily(self, name, node_type_row, node_type_column,
                             is_symmetric, decoder_class)
        self.relation_families.append(fam)
        return fam

    def __repr__(self):
        n = len(self.node_types)
        if n == 0:
            return 'Empty Icosagon Data'

        parts = ['Icosagon Data with:', '- %d node type(s):' % n]
        parts.extend(' - ' + nt.name for nt in self.node_types)

        if not self.relation_families:
            parts.append('- No relation families')
        else:
            parts.append('- %d relation families:' % len(self.relation_families))
            parts.extend(fam.repr_indented() for fam in self.relation_families)

        return '\n'.join(parts).strip()