From 15e95bdc727b5710888add2d752000fd7d3da810 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Mon, 8 Jun 2020 10:54:32 +0200 Subject: [PATCH] Use hints instead of is_autogenerated. --- src/icosagon/data.py | 15 +++++++++------ src/icosagon/declayer.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/icosagon/data.py b/src/icosagon/data.py index b52dadc..86b17d1 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -5,11 +5,12 @@ from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field import torch from typing import List, \ Dict, \ - Tuple + Tuple, \ + Any @dataclass @@ -24,7 +25,7 @@ class RelationType(object): node_type_row: int node_type_column: int adjacency_matrix: torch.Tensor - is_autogenerated: bool = False + hints: Dict[str, Any] = field(default_factory=dict) class Data(object): @@ -81,14 +82,16 @@ class Data(object): self.relation_types[node_type_row, node_type_column].append( RelationType(name, node_type_row, node_type_column, - adjacency_matrix, False)) + adjacency_matrix)) if node_type_row != node_type_column and two_way: + hints = { 'display': False } if adjacency_matrix_backward is None: adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) + hints['symmetric'] = True self.relation_types[node_type_column, node_type_row].append( RelationType(name, node_type_column, node_type_row, - adjacency_matrix_backward, True)) + adjacency_matrix_backward, hints)) def __repr__(self): n = len(self.node_types) @@ -114,7 +117,7 @@ class Data(object): self.node_types[node_type_column].name + ':\n' for r in self.relation_types[node_type_row, node_type_column]: - if r.is_autogenerated: + if not r.hints.get('display', True): continue s_1 += ' - ' + r.name + '\n' count += 1 diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index 5b85c06..f5ed5b1 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -71,6 +71,6 @@ class DecodeLayer(torch.nn.Module): for (node_type_row, node_type_column), dec in self.decoders.items(): inputs_row = last_layer_repr[node_type_row] inputs_column = last_layer_repr[node_type_column] - pred_adj_matrices = dec(inputs_row, inputs_column) + pred_adj_matrices = [ dec(inputs_row, inputs_column, k) for k in range(dec.num_relation_types) ] res[node_type_row, node_type_column] = pred_adj_matrices return res