From bd894a02572531a43cceac114f8b7c2299052a2b Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Sun, 7 Jun 2020 18:10:14 +0200 Subject: [PATCH] Move back to using single-level dictionary for Data.relation_types. --- src/icosagon/convlayer.py | 33 ++++++++---------------- src/icosagon/data.py | 38 +++++++++++---------------- src/icosagon/declayer.py | 44 ++++++++++++-------------------- src/icosagon/trainprep.py | 8 +++--- tests/icosagon/test_convlayer.py | 4 +-- 5 files changed, 49 insertions(+), 78 deletions(-) diff --git a/src/icosagon/convlayer.py b/src/icosagon/convlayer.py index 88f15b8..d356afc 100644 --- a/src/icosagon/convlayer.py +++ b/src/icosagon/convlayer.py @@ -51,34 +51,23 @@ class DecagonLayer(torch.nn.Module): self.build() def build(self): - n = len(self.data.node_types) - rel_types = self.data.relation_types - - self.next_layer_repr = [ [] for _ in range(n) ] + self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] - for node_type_row in range(n): - if node_type_row not in rel_types: + for (node_type_row, node_type_column), rels in self.data.relation_types.items(): + if len(rels) == 0: continue - for node_type_column in range(n): - if node_type_column not in rel_types[node_type_row]: - continue - - rels = rel_types[node_type_row][node_type_column] - if len(rels) == 0: - continue - - convolutions = [] + convolutions = [] - for r in rels: - conv = DropoutGraphConvActivation(self.input_dim[node_type_column], - self.output_dim[node_type_row], r.adjacency_matrix, - self.keep_prob, self.rel_activation) + for r in rels: + conv = DropoutGraphConvActivation(self.input_dim[node_type_column], + self.output_dim[node_type_row], r.adjacency_matrix, + self.keep_prob, self.rel_activation) - convolutions.append(conv) + convolutions.append(conv) - self.next_layer_repr[node_type_row].append( - Convolutions(node_type_column, convolutions)) + self.next_layer_repr[node_type_row].append( + Convolutions(node_type_column, convolutions)) def __call__(self, prev_layer_repr): next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] diff --git a/src/icosagon/data.py b/src/icosagon/data.py index 4166d40..b52dadc 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -7,6 +7,9 @@ from collections import defaultdict from dataclasses import dataclass import torch +from typing import List, \ + Dict, \ + Tuple @dataclass @@ -25,9 +28,12 @@ class RelationType(object): class Data(object): + node_types: List[NodeType] + relation_types: Dict[Tuple[int, int], List[RelationType]] + def __init__(self) -> None: self.node_types = [] - self.relation_types = defaultdict(lambda: defaultdict(list)) + self.relation_types = defaultdict(list) def add_node_type(self, name: str, count: int) -> None: name = str(name) @@ -73,14 +79,14 @@ class Data(object): adjacency_matrix_backward is not None: raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') - self.relation_types[node_type_row][node_type_column].append( + self.relation_types[node_type_row, node_type_column].append( RelationType(name, node_type_row, node_type_column, adjacency_matrix, False)) if node_type_row != node_type_column and two_way: if adjacency_matrix_backward is None: adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) - self.relation_types[node_type_column][node_type_row].append( + self.relation_types[node_type_column, node_type_row].append( RelationType(name, node_type_column, node_type_row, adjacency_matrix_backward, True)) @@ -99,16 +105,15 @@ class Data(object): s_1 = '' count = 0 - for i in range(n): - for j in range(n): - if i not in self.relation_types or \ - j not in self.relation_types[i]: + for node_type_row in range(n): + for node_type_column in range(n): + if (node_type_row, node_type_column) not in self.relation_types: continue - s_1 += ' - ' + self.node_types[i].name + ' -- ' + \ - self.node_types[j].name + ':\n' + s_1 += ' - ' + self.node_types[node_type_row].name + ' -- ' + \ + self.node_types[node_type_column].name + ':\n' - for r in self.relation_types[i][j]: + for r in self.relation_types[node_type_row, node_type_column]: if r.is_autogenerated: continue s_1 += ' - ' + r.name + '\n' @@ -118,16 +123,3 @@ class Data(object): s += s_1 return s.strip() - - # n = sum(map(len, self.relation_types)) - # - # for i in range(n): - # for j in range(n): - # key = (i, j) - # if key not in self.relation_types: - # continue - # rels = self.relation_types[key] - # - # for r in rels: - # - # return s.strip() diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index 68dfdd1..5b85c06 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -44,37 +44,27 @@ class DecodeLayer(torch.nn.Module): def build(self) -> None: self.decoders = {} - n = len(self.data.node_types) - relation_types = self.data.relation_types - for node_type_row in range(n): - if node_type_row not in relation_types: + for (node_type_row, node_type_column), rels in self.data.relation_types.items(): + if len(rels) == 0: continue - for node_type_column in range(n): - if node_type_column not in relation_types[node_type_row]: - continue - - rels = relation_types[node_type_row][node_type_column] - if len(rels) == 0: - continue - - if isinstance(self.decoder_class, dict): - if (node_type_row, node_type_column) in self.decoder_class: - decoder_class = self.decoder_class[node_type_row, node_type_column] - elif (node_type_column, node_type_row) in self.decoder_class: - decoder_class = self.decoder_class[node_type_column, node_type_row] - else: - raise KeyError('Decoder not specified for edge type: %s -- %s' % ( - self.data.node_types[node_type_row].name, - self.data.node_types[node_type_column].name)) + if isinstance(self.decoder_class, dict): + if (node_type_row, node_type_column) in self.decoder_class: + decoder_class = self.decoder_class[node_type_row, node_type_column] + elif (node_type_column, node_type_row) in self.decoder_class: + decoder_class = self.decoder_class[node_type_column, node_type_row] else: - decoder_class = self.decoder_class + raise KeyError('Decoder not specified for edge type: %s -- %s' % ( + self.data.node_types[node_type_row].name, + self.data.node_types[node_type_column].name)) + else: + decoder_class = self.decoder_class - self.decoders[node_type_row, node_type_column] = \ - decoder_class(self.input_dim[node_type_row], - num_relation_types = len(rels), - keep_prob = self.keep_prob, - activation = self.activation) + self.decoders[node_type_row, node_type_column] = \ + decoder_class(self.input_dim[node_type_row], + num_relation_types = len(rels), + keep_prob = self.keep_prob, + activation = self.activation) def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: res = {} diff --git a/src/icosagon/trainprep.py b/src/icosagon/trainprep.py index 10886a1..a979290 100644 --- a/src/icosagon/trainprep.py +++ b/src/icosagon/trainprep.py @@ -46,7 +46,7 @@ class PreparedRelationType(object): @dataclass class PreparedData(object): node_types: List[NodeType] - relation_types: Dict[int, Dict[int, List[PreparedRelationType]]] + relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] def train_val_test_split_edges(edges: torch.Tensor, @@ -137,9 +137,9 @@ def prepare_training(data: Data) -> PreparedData: if not isinstance(data, Data): raise ValueError('data must be of class Data') - relation_types = defaultdict(lambda: defaultdict(list)) - for (node_type_row, node_type_column), rels in data.relation_types: + relation_types = defaultdict(list) + for (node_type_row, node_type_column), rels in data.relation_types.items(): for r in rels: - relation_types[node_type_row][node_type_column].append( + relation_types[node_type_row, node_type_column].append( prep_relation_type(r)) return PreparedData(data.node_types, relation_types) diff --git a/tests/icosagon/test_convlayer.py b/tests/icosagon/test_convlayer.py index 82b7f56..8e713c2 100644 --- a/tests/icosagon/test_convlayer.py +++ b/tests/icosagon/test_convlayer.py @@ -86,7 +86,7 @@ def test_decagon_layer_04(): in_layer = OneHotInputLayer(d) multi_dgca = MultiDGCA([10], 32, - [r.adjacency_matrix for r in d.relation_types[0][0]], + [r.adjacency_matrix for r in d.relation_types[0, 0]], keep_prob=1., activation=lambda x: x) d_layer = DecagonLayer(in_layer.output_dim, 32, d, @@ -129,7 +129,7 @@ def test_decagon_layer_05(): in_layer = OneHotInputLayer(d) multi_dgca = MultiDGCA([100, 100], 32, - [r.adjacency_matrix for r in d.relation_types[0][0]], + [r.adjacency_matrix for r in d.relation_types[0, 0]], keep_prob=1., activation=lambda x: x) d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,