| @@ -136,11 +136,6 @@ class FastGraphConv(torch.nn.Module): | |||||
| class FastConvLayer(torch.nn.Module): | class FastConvLayer(torch.nn.Module): | ||||
| adjacency_matrix: List[torch.Tensor] | |||||
| adjacency_matrix_backward: List[torch.Tensor] | |||||
| weight: List[torch.Tensor] | |||||
| weight_backward: List[torch.Tensor] | |||||
| def __init__(self, | def __init__(self, | ||||
| input_dim: List[int], | input_dim: List[int], | ||||
| output_dim: List[int], | output_dim: List[int], | ||||
| @@ -162,70 +157,86 @@ class FastConvLayer(torch.nn.Module): | |||||
| self.rel_activation = rel_activation | self.rel_activation = rel_activation | ||||
| self.layer_activation = layer_activation | self.layer_activation = layer_activation | ||||
| self.adjacency_matrix = None | |||||
| self.adjacency_matrix_backward = None | |||||
| self.weight = None | |||||
| self.weight_backward = None | |||||
| self.is_sparse = False | |||||
| self.next_layer_repr = None | |||||
| self.build() | self.build() | ||||
| def build(self): | def build(self): | ||||
| self.adjacency_matrix = [] | |||||
| self.adjacency_matrix_backward = [] | |||||
| self.weight = [] | |||||
| self.weight_backward = [] | |||||
| self.next_layer_repr = torch.nn.ModuleList([ | |||||
| torch.nn.ModuleList() \ | |||||
| for _ in range(len(self.data.node_types)) | |||||
| ]) | |||||
| for fam in self.data.relation_families: | for fam in self.data.relation_families: | ||||
| adj_mat = [ rel.adjacency_matrix \ | |||||
| for rel in fam.relation_types \ | |||||
| if rel.adjacency_matrix is not None ] | |||||
| adj_mat_back = [ rel.adjacency_matrix_backward \ | |||||
| for rel in fam.relation_types \ | |||||
| if rel.adjacency_matrix_backward is not None ] | |||||
| weight = [ init_glorot(self.input_dim[fam.node_type_column], | |||||
| self.output_dim[fam.node_type_row]) \ | |||||
| for _ in range(len(adj_mat)) ] | |||||
| weight_back = [ init_glorot(self.input_dim[fam.node_type_column], | |||||
| self.output_dim[fam.node_type_row]) \ | |||||
| for _ in range(len(adj_mat_back)) ] | |||||
| adj_mat = torch.cat(adj_mat) \ | |||||
| if len(adj_mat) > 0 \ | |||||
| else None | |||||
| adj_mat_back = torch.cat(adj_mat_back) \ | |||||
| if len(adj_mat_back) > 0 \ | |||||
| else None | |||||
| self.adjacency_matrix.append(adj_mat) | |||||
| self.adjacency_matrix_backward.append(adj_mat_back) | |||||
| self.weight.append(weight) | |||||
| self.weight_backward.append(weight_back) | |||||
| self.build_family(fam) | |||||
| def build_family(self, fam) -> None: | |||||
| if fam.node_type_row == fam.node_type_column: | |||||
| self.build_fam_one_node_type(fam) | |||||
| else: | |||||
| self.build_fam_two_node_types(fam) | |||||
| def build_fam_one_node_type(self, fam) -> None: | |||||
| adjacency_matrices = [ | |||||
| r.adjacency_matrix \ | |||||
| for r in fam.relation_types | |||||
| ] | |||||
| conv = FastGraphConv(self.input_dim[fam.node_type_column], | |||||
| self.output_dim[fam.node_type_row], | |||||
| adjacency_matrices, | |||||
| self.keep_prob, | |||||
| self.rel_activation) | |||||
| conv.input_node_type = fam.node_type_column | |||||
| self.next_layer_repr[fam.node_type_row].append(conv) | |||||
| def build_fam_two_node_types(self, fam) -> None: | |||||
| adjacency_matrices = [ | |||||
| r.adjacency_matrix \ | |||||
| for r in fam.relation_types \ | |||||
| if r.adjacency_matrix is not None | |||||
| ] | |||||
| adjacency_matrices_backward = [ | |||||
| r.adjacency_matrix_backward \ | |||||
| for r in fam.relation_types \ | |||||
| if r.adjacency_matrix_backward is not None | |||||
| ] | |||||
| conv = FastGraphConv(self.input_dim[fam.node_type_column], | |||||
| self.output_dim[fam.node_type_row], | |||||
| adjacency_matrices, | |||||
| self.keep_prob, | |||||
| self.rel_activation) | |||||
| conv_backward = FastGraphConv(self.input_dim[fam.node_type_row], | |||||
| self.output_dim[fam.node_type_column], | |||||
| adjacency_matrices_backward, | |||||
| self.keep_prob, | |||||
| self.rel_activation) | |||||
| conv.input_node_type = fam.node_type_column | |||||
| conv_backward.input_node_type = fam.node_type_row | |||||
| self.next_layer_repr[fam.node_type_row].append(conv) | |||||
| self.next_layer_repr[fam.node_type_column].append(conv_backward) | |||||
| def forward(self, prev_layer_repr): | def forward(self, prev_layer_repr): | ||||
| for i, fam in enumerate(self.data.relation_families): | |||||
| repr_row = prev_layer_repr[fam.node_type_row] | |||||
| repr_column = prev_layer_repr[fam.node_type_column] | |||||
| adj_mat = self.adjacency_matrix[i] | |||||
| adj_mat_back = self.adjacency_matrix_backward[i] | |||||
| if adj_mat is not None: | |||||
| x = dropout(repr_column, keep_prob=self.keep_prob) | |||||
| x = torch.sparse.mm(x, self.weight[i]) \ | |||||
| if x.is_sparse \ | |||||
| else torch.mm(x, self.weight[i]) | |||||
| x = torch.sparse.mm(adj_mat, x) \ | |||||
| if adj_mat.is_sparse \ | |||||
| else torch.mm(adj_mat, x) | |||||
| x = self.rel_activation(x) | |||||
| x = x.view(len(fam.relation_types), len(repr_row), -1) | |||||
| if adj_mat_back is not None: | |||||
| x = dropout(repr_column, keep_prob=self.keep_prob) | |||||
| x = torch.sparse.mm(x, self.weight_backward[i]) \ | |||||
| if x.is_sparse \ | |||||
| else torch.mm(x, self.weight_backward[i]) | |||||
| x = torch.sparse.mm(adj_mat_back, x) \ | |||||
| if adj_mat_back.is_sparse \ | |||||
| else torch.mm(adj_mat_back, x) | |||||
| x = self.rel_activation(x) | |||||
| x = x.view(len(fam.relation_types), len(repr_row), -1) | |||||
| next_layer_repr = [ [] \ | |||||
| for _ in range(len(self.data.node_types)) ] | |||||
| for output_node_type in range(len(self.data.node_types)): | |||||
| for conv in self.next_layer_repr[output_node_type]: | |||||
| rep = conv(prev_layer_repr[conv.input_node_type]) | |||||
| rep = torch.sum(rep, dim=0) | |||||
| rep = torch.nn.functional.normalize(rep, p=2, dim=1) | |||||
| next_layer_repr[output_node_type].append(rep) | |||||
| if len(next_layer_repr[output_node_type]) == 0: | |||||
| next_layer_repr[output_node_type] = \ | |||||
| torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type]) | |||||
| else: | |||||
| next_layer_repr[output_node_type] = \ | |||||
| sum(next_layer_repr[output_node_type]) | |||||
| next_layer_repr[output_node_type] = \ | |||||
| self.layer_activation(next_layer_repr[output_node_type]) | |||||
| return next_layer_repr | |||||
| @staticmethod | @staticmethod | ||||
| def _check_params(input_dim, output_dim, data, keep_prob, | def _check_params(input_dim, output_dim, data, keep_prob, | ||||