From 60d8a43c12ece902984496ff82e2835d9c16a130 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Wed, 27 May 2020 18:11:32 +0200 Subject: [PATCH] Fix the DecagonLayer logic. --- src/decagon_pytorch/layer.py | 43 +++++++++--------------------------- 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py index 61911ff..76f2721 100644 --- a/src/decagon_pytorch/layer.py +++ b/src/decagon_pytorch/layer.py @@ -85,51 +85,28 @@ class DecagonLayer(Layer): self.keep_prob = keep_prob self.rel_activation = rel_activation self.layer_activation = layer_activation - self.convolutions = None + self.next_layer_repr = None self.build() def build(self): - self.convolutions = {} - for (node_type_row, node_type_column) in self.data.relation_types.keys(): - adjacency_matrices = \ - self.data.get_adjacency_matrices(node_type_row, node_type_column) - self.convolutions[node_type_row, node_type_column] = SparseMultiDGCA(self.input_dim, - self.output_dim, adjacency_matrices, - self.keep_prob, self.rel_activation) - - # for node_type_row, node_type_col in enumerate(self.data.node_ - # if rt.node_type_row == i or rt.node_type_col == i: - - def __call__(self): - prev_layer_repr = self.previous_layer() - next_layer_repr = defaultdict(list) + self.next_layer_repr = defaultdict(list) for (nt_row, nt_col), rel in self.data.relation_types.items(): conv = SparseDropoutGraphConvActivation(self.input_dim[nt_col], self.output_dim[nt_row], rel.adjacency_matrix, self.keep_prob, self.rel_activation) - next_layer_repr[nt_row].append(conv) + self.next_layer_repr[nt_row].append((conv, nt_col)) conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row], self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1), self.keep_prob, self.rel_activation) - next_layer_repr[nt_col].append(conv) + self.next_layer_repr[nt_col].append((conv, nt_row)) + def __call__(self): + prev_layer_repr = self.previous_layer() + next_layer_repr = self.next_layer_repr + for i in range(len(self.data.node_types)): + next_layer_repr[i] = map(lambda conv, neighbor_type: \ + conv(prev_layer_repr[neighbor_type]), next_layer_repr[i]) next_layer_repr = list(map(sum, next_layer_repr)) return next_layer_repr - - - #for i, nt in enumerate(self.data.node_types): - # new_repr = [] - # for nt_row, nt_col in self.data.relation_types.keys(): - # if nt_row != i and nt_col != i: - # continue - # if nt_row == i: - # x = prev_layer_repr[nt_col] - # else: - # x = prev_layer_repr[nt_row] - # conv = self.convolutions[key] - # new_repr.append(conv(x)) - # new_repr = sum(new_repr) - # new_layer_repr.append(new_repr) - # return new_layer_repr