|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
- from collections import defaultdict
- from dataclasses import dataclass
- import torch
-
-
@dataclass
class NodeType:
    """A named node type together with how many nodes it contains."""

    name: str  # display name of the node type
    count: int  # number of nodes belonging to this type
-
-
@dataclass
class RelationType:
    """A named relation between two node types, stored as an adjacency matrix.

    ``node_type_row`` / ``node_type_column`` identify the node types whose
    nodes label the rows / columns of ``adjacency_matrix``.
    """

    name: str
    node_type_row: int  # node type labelling the matrix rows
    node_type_column: int  # node type labelling the matrix columns
    adjacency_matrix: torch.Tensor
    # Set to True by Data.add_relation_type for the reverse-direction copy
    # it stores automatically; such entries are hidden by Data.__repr__.
    is_autogenerated: bool = False
-
-
class Data(object):
    """Collection of node types and the typed relations between them.

    A node type is identified by its position in ``self.node_types``.
    Relations are stored as ``self.relation_types[row][column]``: a list of
    ``RelationType`` records whose adjacency matrices map nodes of type
    ``row`` (matrix rows) to nodes of type ``column`` (matrix columns).
    """

    def __init__(self) -> None:
        # Ordered NodeType records; a node type's index here is its id.
        self.node_types = []
        # relation_types[row][column] -> list of RelationType. The nested
        # defaultdicts create missing rows/columns on first write.
        self.relation_types = defaultdict(lambda: defaultdict(list))

    def add_node_type(self, name: str, count: int) -> None:
        """Register a new node type; its id is its position in ``node_types``.

        Args:
            name: Node type name (coerced with ``str``); must be non-empty.
            count: Number of nodes of this type (coerced with ``int``);
                must be positive.

        Raises:
            ValueError: If ``name`` is empty or ``count`` is not positive.
        """
        name = str(name)
        count = int(count)
        if not name:
            raise ValueError('You must provide a non-empty node type name')
        if count <= 0:
            raise ValueError('You must provide a positive node count')
        self.node_types.append(NodeType(name, count))

    def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
                          adjacency_matrix: torch.Tensor,
                          adjacency_matrix_backward: torch.Tensor = None,
                          two_way: bool = True) -> None:
        """Register a relation between two (already added) node types.

        Args:
            name: Relation name (coerced with ``str``).
            node_type_row: Index of the node type labelling the matrix rows.
            node_type_column: Index of the node type labelling the matrix
                columns.
            adjacency_matrix: Tensor of shape
                ``(count[node_type_row], count[node_type_column])``.
            adjacency_matrix_backward: Optional tensor of the transposed
                shape used for the reverse direction. Defaults to the
                transpose of ``adjacency_matrix``. Not allowed when both node
                types are the same, and ignored when ``two_way`` is False.
            two_way: When True and the node types differ, also store the
                reverse-direction relation, flagged ``is_autogenerated``.

        Raises:
            ValueError: On an out-of-range node type index, a non-tensor
                matrix, a shape mismatch, or a backward matrix supplied for
                a same-type relation.
        """
        name = str(name)
        node_type_row = int(node_type_row)
        node_type_column = int(node_type_column)

        # Valid indices are 0 .. len - 1; an index equal to len must be
        # rejected here rather than crash with IndexError below.
        num_node_types = len(self.node_types)
        if not 0 <= node_type_row < num_node_types:
            raise ValueError('node_type_row outside of the valid range of node types')

        if not 0 <= node_type_column < num_node_types:
            raise ValueError('node_type_column outside of the valid range of node types')

        if not isinstance(adjacency_matrix, torch.Tensor):
            raise ValueError('adjacency_matrix must be a torch.Tensor')

        # Compare against None explicitly: truth-testing a multi-element
        # tensor raises RuntimeError.
        if adjacency_matrix_backward is not None and \
                not isinstance(adjacency_matrix_backward, torch.Tensor):
            raise ValueError('adjacency_matrix_backward must be a torch.Tensor')

        num_row_nodes = self.node_types[node_type_row].count
        num_column_nodes = self.node_types[node_type_column].count

        if adjacency_matrix.shape != (num_row_nodes, num_column_nodes):
            raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')

        if adjacency_matrix_backward is not None and \
                adjacency_matrix_backward.shape != (num_column_nodes, num_row_nodes):
            raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')

        two_way = bool(two_way)

        if node_type_row == node_type_column and \
                adjacency_matrix_backward is not None:
            raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')

        self.relation_types[node_type_row][node_type_column].append(
            RelationType(name, node_type_row, node_type_column,
                         adjacency_matrix, False))

        # Same-type relations are stored once; otherwise optionally add the
        # reverse direction, marked as autogenerated so __repr__ hides it.
        if node_type_row != node_type_column and two_way:
            if adjacency_matrix_backward is None:
                adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
            self.relation_types[node_type_column][node_type_row].append(
                RelationType(name, node_type_column, node_type_row,
                             adjacency_matrix_backward, True))

    def __repr__(self) -> str:
        """Summarize node types and user-added (non-autogenerated) relations."""
        n = len(self.node_types)
        if n == 0:
            return 'Empty Icosagon Data'

        lines = ['Icosagon Data with:', '- %d node type(s):' % n]
        lines += [' - ' + nt.name for nt in self.node_types]

        if not self.relation_types:
            lines.append('- No relation types')
            return '\n'.join(lines).strip()

        relation_lines = []
        count = 0
        for i in range(n):
            # .get avoids inserting empty entries into the defaultdicts.
            row = self.relation_types.get(i)
            if row is None:
                continue
            for j in range(n):
                names = [r.name for r in row.get(j, ())
                         if not r.is_autogenerated]
                # Skip pairs holding only autogenerated reverse relations so
                # no empty "A -- B:" header is emitted.
                if not names:
                    continue
                relation_lines.append(' - ' + self.node_types[i].name + ' -- ' +
                                      self.node_types[j].name + ':')
                relation_lines += [' - ' + rel_name for rel_name in names]
                count += len(names)

        lines.append('- %d relation type(s):' % count)
        lines += relation_lines
        return '\n'.join(lines).strip()
|