From 82d5c06eee4c68028e6dba621b42bd8ce809527d Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 6 Aug 2020 15:25:50 +0200 Subject: [PATCH] Fix aggregation of representations from different edge types. --- src/triacontagon/model.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/triacontagon/model.py b/src/triacontagon/model.py index 49ca1bb..49d37bd 100644 --- a/src/triacontagon/model.py +++ b/src/triacontagon/model.py @@ -133,9 +133,10 @@ class Model(torch.nn.Module): List[torch.Tensor]: cur_layer_repr = in_layer_repr - next_layer_repr = [ None ] * len(self.data.vertex_types) - + for i in range(len(self.layer_dimensions) - 1): + next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ] + for _, et in self.data.edge_types.items(): vt_row, vt_col = et.vertex_type_row, et.vertex_type_column adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)] @@ -158,13 +159,18 @@ class Model(torch.nn.Module): self.data.vertex_types[vt_row].count, self.layer_dimensions[i + 1]) - print('b, Layer:', i, 'x.shape:', x.shape) + print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) x = x.sum(dim=0) x = torch.nn.functional.normalize(x, p=2, dim=1) - x = self.conv_activation(x) + # x = self.rel_activation(x) + print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) + + next_layer_repr[vt_row].append(x) + + next_layer_repr = [ self.conv_activation(sum(x)) \ + for x in next_layer_repr ] - next_layer_repr[vt_row] = x cur_layer_repr = next_layer_repr return next_layer_repr