From aa3cc1f3ad62cd0480932d2beebac14892bb1821 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Tue, 28 Jul 2020 10:20:09 +0200
Subject: [PATCH] Reimplement FastConvLayer.

---
 src/icosagon/fastconv.py | 137 +++++++++++++++++++++------------------
 1 file changed, 74 insertions(+), 63 deletions(-)

diff --git a/src/icosagon/fastconv.py b/src/icosagon/fastconv.py
index 8838c9a..038e2fc 100644
--- a/src/icosagon/fastconv.py
+++ b/src/icosagon/fastconv.py
@@ -136,11 +136,6 @@ class FastGraphConv(torch.nn.Module):
 
 
 class FastConvLayer(torch.nn.Module):
-    adjacency_matrix: List[torch.Tensor]
-    adjacency_matrix_backward: List[torch.Tensor]
-    weight: List[torch.Tensor]
-    weight_backward: List[torch.Tensor]
-
     def __init__(self,
         input_dim: List[int],
         output_dim: List[int],
@@ -162,70 +157,86 @@ class FastConvLayer(torch.nn.Module):
         self.rel_activation = rel_activation
         self.layer_activation = layer_activation
 
-        self.adjacency_matrix = None
-        self.adjacency_matrix_backward = None
-        self.weight = None
-        self.weight_backward = None
+        self.is_sparse = False
+        self.next_layer_repr = None
 
         self.build()
 
     def build(self):
-        self.adjacency_matrix = []
-        self.adjacency_matrix_backward = []
-        self.weight = []
-        self.weight_backward = []
+        self.next_layer_repr = torch.nn.ModuleList([
+            torch.nn.ModuleList() \
+                for _ in range(len(self.data.node_types))
+        ])
         for fam in self.data.relation_families:
-            adj_mat = [ rel.adjacency_matrix \
-                for rel in fam.relation_types \
-                if rel.adjacency_matrix is not None ]
-            adj_mat_back = [ rel.adjacency_matrix_backward \
-                for rel in fam.relation_types \
-                if rel.adjacency_matrix_backward is not None ]
-            weight = [ init_glorot(self.input_dim[fam.node_type_column],
-                self.output_dim[fam.node_type_row]) \
-                for _ in range(len(adj_mat)) ]
-            weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
-                self.output_dim[fam.node_type_row]) \
-                for _ in range(len(adj_mat_back)) ]
-            adj_mat = torch.cat(adj_mat) \
-                if len(adj_mat) > 0 \
-                else None
-            adj_mat_back = torch.cat(adj_mat_back) \
-                if len(adj_mat_back) > 0 \
-                else None
-            self.adjacency_matrix.append(adj_mat)
-            self.adjacency_matrix_backward.append(adj_mat_back)
-            self.weight.append(weight)
-            self.weight_backward.append(weight_back)
+            self.build_family(fam)
+
+    def build_family(self, fam) -> None:
+        if fam.node_type_row == fam.node_type_column:
+            self.build_fam_one_node_type(fam)
+        else:
+            self.build_fam_two_node_types(fam)
+
+    def build_fam_one_node_type(self, fam) -> None:
+        adjacency_matrices = [
+            r.adjacency_matrix \
+                for r in fam.relation_types
+        ]
+        conv = FastGraphConv(self.input_dim[fam.node_type_column],
+            self.output_dim[fam.node_type_row],
+            adjacency_matrices,
+            self.keep_prob,
+            self.rel_activation)
+        conv.input_node_type = fam.node_type_column
+        self.next_layer_repr[fam.node_type_row].append(conv)
+
+    def build_fam_two_node_types(self, fam) -> None:
+        adjacency_matrices = [
+            r.adjacency_matrix \
+                for r in fam.relation_types \
+                if r.adjacency_matrix is not None
+        ]
+
+        adjacency_matrices_backward = [
+            r.adjacency_matrix_backward \
+                for r in fam.relation_types \
+                if r.adjacency_matrix_backward is not None
+        ]
+
+        conv = FastGraphConv(self.input_dim[fam.node_type_column],
+            self.output_dim[fam.node_type_row],
+            adjacency_matrices,
+            self.keep_prob,
+            self.rel_activation)
+
+        conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
+            self.output_dim[fam.node_type_column],
+            adjacency_matrices_backward,
+            self.keep_prob,
+            self.rel_activation)
+
+        conv.input_node_type = fam.node_type_column
+        conv_backward.input_node_type = fam.node_type_row
+
+        self.next_layer_repr[fam.node_type_row].append(conv)
+        self.next_layer_repr[fam.node_type_column].append(conv_backward)
 
     def forward(self, prev_layer_repr):
-        for i, fam in enumerate(self.data.relation_families):
-            repr_row = prev_layer_repr[fam.node_type_row]
-            repr_column = prev_layer_repr[fam.node_type_column]
-
-            adj_mat = self.adjacency_matrix[i]
-            adj_mat_back = self.adjacency_matrix_backward[i]
-
-            if adj_mat is not None:
-                x = dropout(repr_column, keep_prob=self.keep_prob)
-                x = torch.sparse.mm(x, self.weight[i]) \
-                    if x.is_sparse \
-                    else torch.mm(x, self.weight[i])
-                x = torch.sparse.mm(adj_mat, x) \
-                    if adj_mat.is_sparse \
-                    else torch.mm(adj_mat, x)
-                x = self.rel_activation(x)
-                x = x.view(len(fam.relation_types), len(repr_row), -1)
-
-            if adj_mat_back is not None:
-                x = dropout(repr_column, keep_prob=self.keep_prob)
-                x = torch.sparse.mm(x, self.weight_backward[i]) \
-                    if x.is_sparse \
-                    else torch.mm(x, self.weight_backward[i])
-                x = torch.sparse.mm(adj_mat_back, x) \
-                    if adj_mat_back.is_sparse \
-                    else torch.mm(adj_mat_back, x)
-                x = self.rel_activation(x)
-                x = x.view(len(fam.relation_types), len(repr_row), -1)
+        next_layer_repr = [ [] \
+            for _ in range(len(self.data.node_types)) ]
+        for output_node_type in range(len(self.data.node_types)):
+            for conv in self.next_layer_repr[output_node_type]:
+                rep = conv(prev_layer_repr[conv.input_node_type])
+                rep = torch.sum(rep, dim=0)
+                rep = torch.nn.functional.normalize(rep, p=2, dim=1)
+                next_layer_repr[output_node_type].append(rep)
+            if len(next_layer_repr[output_node_type]) == 0:
+                next_layer_repr[output_node_type] = \
+                    torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
+            else:
+                next_layer_repr[output_node_type] = \
+                    sum(next_layer_repr[output_node_type])
+                next_layer_repr[output_node_type] = \
+                    self.layer_activation(next_layer_repr[output_node_type])
+        return next_layer_repr
 
     @staticmethod
     def _check_params(input_dim, output_dim, data, keep_prob,
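
Usage note (not part of the patch): after this change, FastConvLayer no
longer holds raw weight and adjacency tensors; it registers one
FastGraphConv per relation family and direction in nested
torch.nn.ModuleLists indexed by output node type, so their parameters are
visible to state_dict() and optimizers. In forward(), each convolution
consumes the representation of its input node type; its per-relation
outputs are summed over dim=0, row-wise L2-normalized, summed across
families, and passed through layer_activation.

A minimal sketch of constructing and calling the layer follows. It
assumes the icosagon Data API (add_node_type, add_relation_family,
add_relation_type) as used elsewhere in the package; exact signatures and
defaults may differ, and the names below ('Gene', 'Interaction') are
illustrative only:

    import torch
    from icosagon.data import Data
    from icosagon.fastconv import FastConvLayer

    d = Data()
    d.add_node_type('Gene', 100)              # node type 0 with 100 nodes
    # assumed signature: (name, node_type_row, node_type_column, is_symmetric)
    fam = d.add_relation_family('Gene-Gene', 0, 0, True)
    r = (torch.rand(100, 100) < 0.1).float()
    adj = (r + r.t()).clamp(max=1)            # symmetric 0/1 adjacency
    fam.add_relation_type('Interaction', adj)

    layer = FastConvLayer(
        input_dim=[100],                      # one entry per node type
        output_dim=[32],
        data=d,
        keep_prob=0.9,
        rel_activation=torch.sigmoid,
        layer_activation=torch.nn.functional.relu)

    x = [ torch.eye(100) ]                    # one-hot features per node type
    out = layer(x)                            # list of (count, output_dim) tensors
    print(out[0].shape)                       # torch.Size([100, 32])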