| @@ -136,11 +136,6 @@ class FastGraphConv(torch.nn.Module): | |||
| class FastConvLayer(torch.nn.Module): | |||
| adjacency_matrix: List[torch.Tensor] | |||
| adjacency_matrix_backward: List[torch.Tensor] | |||
| weight: List[torch.Tensor] | |||
| weight_backward: List[torch.Tensor] | |||
| def __init__(self, | |||
| input_dim: List[int], | |||
| output_dim: List[int], | |||
| @@ -162,70 +157,86 @@ class FastConvLayer(torch.nn.Module): | |||
| self.rel_activation = rel_activation | |||
| self.layer_activation = layer_activation | |||
| self.adjacency_matrix = None | |||
| self.adjacency_matrix_backward = None | |||
| self.weight = None | |||
| self.weight_backward = None | |||
| self.is_sparse = False | |||
| self.next_layer_repr = None | |||
| self.build() | |||
| def build(self): | |||
| self.adjacency_matrix = [] | |||
| self.adjacency_matrix_backward = [] | |||
| self.weight = [] | |||
| self.weight_backward = [] | |||
| self.next_layer_repr = torch.nn.ModuleList([ | |||
| torch.nn.ModuleList() \ | |||
| for _ in range(len(self.data.node_types)) | |||
| ]) | |||
| for fam in self.data.relation_families: | |||
| adj_mat = [ rel.adjacency_matrix \ | |||
| for rel in fam.relation_types \ | |||
| if rel.adjacency_matrix is not None ] | |||
| adj_mat_back = [ rel.adjacency_matrix_backward \ | |||
| for rel in fam.relation_types \ | |||
| if rel.adjacency_matrix_backward is not None ] | |||
| weight = [ init_glorot(self.input_dim[fam.node_type_column], | |||
| self.output_dim[fam.node_type_row]) \ | |||
| for _ in range(len(adj_mat)) ] | |||
| weight_back = [ init_glorot(self.input_dim[fam.node_type_column], | |||
| self.output_dim[fam.node_type_row]) \ | |||
| for _ in range(len(adj_mat_back)) ] | |||
| adj_mat = torch.cat(adj_mat) \ | |||
| if len(adj_mat) > 0 \ | |||
| else None | |||
| adj_mat_back = torch.cat(adj_mat_back) \ | |||
| if len(adj_mat_back) > 0 \ | |||
| else None | |||
| self.adjacency_matrix.append(adj_mat) | |||
| self.adjacency_matrix_backward.append(adj_mat_back) | |||
| self.weight.append(weight) | |||
| self.weight_backward.append(weight_back) | |||
| self.build_family(fam) | |||
| def build_family(self, fam) -> None: | |||
| if fam.node_type_row == fam.node_type_column: | |||
| self.build_fam_one_node_type(fam) | |||
| else: | |||
| self.build_fam_two_node_types(fam) | |||
| def build_fam_one_node_type(self, fam) -> None: | |||
| adjacency_matrices = [ | |||
| r.adjacency_matrix \ | |||
| for r in fam.relation_types | |||
| ] | |||
| conv = FastGraphConv(self.input_dim[fam.node_type_column], | |||
| self.output_dim[fam.node_type_row], | |||
| adjacency_matrices, | |||
| self.keep_prob, | |||
| self.rel_activation) | |||
| conv.input_node_type = fam.node_type_column | |||
| self.next_layer_repr[fam.node_type_row].append(conv) | |||
| def build_fam_two_node_types(self, fam) -> None: | |||
| adjacency_matrices = [ | |||
| r.adjacency_matrix \ | |||
| for r in fam.relation_types \ | |||
| if r.adjacency_matrix is not None | |||
| ] | |||
| adjacency_matrices_backward = [ | |||
| r.adjacency_matrix_backward \ | |||
| for r in fam.relation_types \ | |||
| if r.adjacency_matrix_backward is not None | |||
| ] | |||
| conv = FastGraphConv(self.input_dim[fam.node_type_column], | |||
| self.output_dim[fam.node_type_row], | |||
| adjacency_matrices, | |||
| self.keep_prob, | |||
| self.rel_activation) | |||
| conv_backward = FastGraphConv(self.input_dim[fam.node_type_row], | |||
| self.output_dim[fam.node_type_column], | |||
| adjacency_matrices_backward, | |||
| self.keep_prob, | |||
| self.rel_activation) | |||
| conv.input_node_type = fam.node_type_column | |||
| conv_backward.input_node_type = fam.node_type_row | |||
| self.next_layer_repr[fam.node_type_row].append(conv) | |||
| self.next_layer_repr[fam.node_type_column].append(conv_backward) | |||
| def forward(self, prev_layer_repr): | |||
| for i, fam in enumerate(self.data.relation_families): | |||
| repr_row = prev_layer_repr[fam.node_type_row] | |||
| repr_column = prev_layer_repr[fam.node_type_column] | |||
| adj_mat = self.adjacency_matrix[i] | |||
| adj_mat_back = self.adjacency_matrix_backward[i] | |||
| if adj_mat is not None: | |||
| x = dropout(repr_column, keep_prob=self.keep_prob) | |||
| x = torch.sparse.mm(x, self.weight[i]) \ | |||
| if x.is_sparse \ | |||
| else torch.mm(x, self.weight[i]) | |||
| x = torch.sparse.mm(adj_mat, x) \ | |||
| if adj_mat.is_sparse \ | |||
| else torch.mm(adj_mat, x) | |||
| x = self.rel_activation(x) | |||
| x = x.view(len(fam.relation_types), len(repr_row), -1) | |||
| if adj_mat_back is not None: | |||
| x = dropout(repr_column, keep_prob=self.keep_prob) | |||
| x = torch.sparse.mm(x, self.weight_backward[i]) \ | |||
| if x.is_sparse \ | |||
| else torch.mm(x, self.weight_backward[i]) | |||
| x = torch.sparse.mm(adj_mat_back, x) \ | |||
| if adj_mat_back.is_sparse \ | |||
| else torch.mm(adj_mat_back, x) | |||
| x = self.rel_activation(x) | |||
| x = x.view(len(fam.relation_types), len(repr_row), -1) | |||
| next_layer_repr = [ [] \ | |||
| for _ in range(len(self.data.node_types)) ] | |||
| for output_node_type in range(len(self.data.node_types)): | |||
| for conv in self.next_layer_repr[output_node_type]: | |||
| rep = conv(prev_layer_repr[conv.input_node_type]) | |||
| rep = torch.sum(rep, dim=0) | |||
| rep = torch.nn.functional.normalize(rep, p=2, dim=1) | |||
| next_layer_repr[output_node_type].append(rep) | |||
| if len(next_layer_repr[output_node_type]) == 0: | |||
| next_layer_repr[output_node_type] = \ | |||
| torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type]) | |||
| else: | |||
| next_layer_repr[output_node_type] = \ | |||
| sum(next_layer_repr[output_node_type]) | |||
| next_layer_repr[output_node_type] = \ | |||
| self.layer_activation(next_layer_repr[output_node_type]) | |||
| return next_layer_repr | |||
| @staticmethod | |||
| def _check_params(input_dim, output_dim, data, keep_prob, | |||