|  |  | @@ -78,80 +78,61 @@ def _cat(matrices: List[torch.Tensor]): | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class FastGraphConv(torch.nn.Module): | 
		
	
		
			
			|  |  |  | def __init__(self, | 
		
	
		
			
			|  |  |  | in_channels: List[int], | 
		
	
		
			
			|  |  |  | out_channels: List[int], | 
		
	
		
			
			|  |  |  | data: Union[Data, PreparedData], | 
		
	
		
			
			|  |  |  | relation_family: Union[RelationFamily, PreparedRelationFamily], | 
		
	
		
			
			|  |  |  | in_channels: int, | 
		
	
		
			
			|  |  |  | out_channels: int, | 
		
	
		
			
			|  |  |  | adjacency_matrices: List[torch.Tensor], | 
		
	
		
			
			|  |  |  | keep_prob: float = 1., | 
		
	
		
			
			|  |  |  | acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | 
		
	
		
			
			|  |  |  | activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | 
		
	
		
			
			|  |  |  | **kwargs) -> None: | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | super().__init__(**kwargs) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | in_channels = int(in_channels) | 
		
	
		
			
			|  |  |  | out_channels = int(out_channels) | 
		
	
		
			
			|  |  |  | if not isinstance(data, Data) and not isinstance(data, PreparedData): | 
		
	
		
			
			|  |  |  | raise TypeError('data must be an instance of Data or PreparedData') | 
		
	
		
			
			|  |  |  | if not isinstance(relation_family, RelationFamily) and \ | 
		
	
		
			
			|  |  |  | not isinstance(relation_family, PreparedRelationFamily): | 
		
	
		
			
			|  |  |  | raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily') | 
		
	
		
			
			|  |  |  | if not isinstance(adjacency_matrices, list): | 
		
	
		
			
			|  |  |  | raise TypeError('adjacency_matrices must be a list') | 
		
	
		
			
			|  |  |  | if len(adjacency_matrices) == 0: | 
		
	
		
			
			|  |  |  | raise ValueError('adjacency_matrices must not be empty') | 
		
	
		
			
			|  |  |  | if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices): | 
		
	
		
			
			|  |  |  | raise TypeError('adjacency_matrices elements must be of class torch.Tensor') | 
		
	
		
			
			|  |  |  | if not all(m.is_sparse for m in adjacency_matrices): | 
		
	
		
			
			|  |  |  | raise ValueError('adjacency_matrices elements must be sparse') | 
		
	
		
			
			|  |  |  | keep_prob = float(keep_prob) | 
		
	
		
			
			|  |  |  | if not isinstance(activation, types.FunctionType): | 
		
	
		
			
			|  |  |  | raise TypeError('activation must be a function') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | n_nodes_row = data.node_types[relation_family.node_type_row].count | 
		
	
		
			
			|  |  |  | n_nodes_column = data.node_types[relation_family.node_type_column].count | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.in_channels = in_channels | 
		
	
		
			
			|  |  |  | self.out_channels = out_channels | 
		
	
		
			
			|  |  |  | self.data = data | 
		
	
		
			
			|  |  |  | self.relation_family = relation_family | 
		
	
		
			
			|  |  |  | self.adjacency_matrices = adjacency_matrices | 
		
	
		
			
			|  |  |  | self.keep_prob = keep_prob | 
		
	
		
			
			|  |  |  | self.activation = activation | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.weight = torch.cat([ | 
		
	
		
			
			|  |  |  | init_glorot(in_channels, out_channels) \ | 
		
	
		
			
			|  |  |  | for _ in range(len(relation_family.relation_types)) | 
		
	
		
			
			|  |  |  | ], dim=1) | 
		
	
		
			
			|  |  |  | self.num_row_nodes = len(adjacency_matrices[0]) | 
		
	
		
			
			|  |  |  | self.num_relation_types = len(adjacency_matrices) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.weight_backward = torch.cat([ | 
		
	
		
			
			|  |  |  | self.weights = torch.cat([ | 
		
	
		
			
			|  |  |  | init_glorot(in_channels, out_channels) \ | 
		
	
		
			
			|  |  |  | for _ in range(len(relation_family.relation_types)) | 
		
	
		
			
			|  |  |  | for _ in range(self.num_relation_types) | 
		
	
		
			
			|  |  |  | ], dim=1) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.adjacency_matrix = _sparse_diag_cat([ | 
		
	
		
			
			|  |  |  | rel.adjacency_matrix \ | 
		
	
		
			
			|  |  |  | if rel.adjacency_matrix is not None \ | 
		
	
		
			
			|  |  |  | else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \ | 
		
	
		
			
			|  |  |  | for rel in relation_family.relation_types ]) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.adjacency_matrix_backward = _sparse_diag_cat([ | 
		
	
		
			
			|  |  |  | rel.adjacency_matrix_backward \ | 
		
	
		
			
			|  |  |  | if rel.adjacency_matrix_backward is not None \ | 
		
	
		
			
			|  |  |  | else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \ | 
		
	
		
			
			|  |  |  | for rel in relation_family.relation_types ]) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]: | 
		
	
		
			
			|  |  |  | repr_row = prev_layer_repr[self.relation_family.node_type_row] | 
		
	
		
			
			|  |  |  | repr_column = prev_layer_repr[self.relation_family.node_type_column] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | new_repr_row = torch.sparse.mm(repr_column, self.weight) \ | 
		
	
		
			
			|  |  |  | if repr_column.is_sparse \ | 
		
	
		
			
			|  |  |  | else torch.mm(repr_column, self.weight) | 
		
	
		
			
			|  |  |  | new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \ | 
		
	
		
			
			|  |  |  | if self.adjacency_matrix.is_sparse \ | 
		
	
		
			
			|  |  |  | else torch.mm(self.adjacency_matrix, new_repr_row) | 
		
	
		
			
			|  |  |  | new_repr_row = new_repr_row.view(len(self.relation_family.relation_types), | 
		
	
		
			
			|  |  |  | len(repr_row), self.out_channels) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | new_repr_column = torch.sparse.mm(repr_row, self.weight) \ | 
		
	
		
			
			|  |  |  | if repr_row.is_sparse \ | 
		
	
		
			
			|  |  |  | else torch.mm(repr_row, self.weight) | 
		
	
		
			
			|  |  |  | new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \ | 
		
	
		
			
			|  |  |  | if self.adjacency_matrix_backward.is_sparse \ | 
		
	
		
			
			|  |  |  | else torch.mm(self.adjacency_matrix_backward, new_repr_column) | 
		
	
		
			
			|  |  |  | new_repr_column = new_repr_column.view(len(self.relation_family.relation_types), | 
		
	
		
			
			|  |  |  | len(repr_column), self.out_channels) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | return (new_repr_row, new_repr_column) | 
		
	
		
			
			|  |  |  | def forward(self, x) -> torch.Tensor: | 
		
	
		
			
			|  |  |  | if self.keep_prob < 1.: | 
		
	
		
			
			|  |  |  | x = dropout(x, self.keep_prob) | 
		
	
		
			
			|  |  |  | res = torch.sparse.mm(x, self.weights) \ | 
		
	
		
			
			|  |  |  | if x.is_sparse \ | 
		
	
		
			
			|  |  |  | else torch.mm(x, self.weights) | 
		
	
		
			
			|  |  |  | res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1) | 
		
	
		
			
			|  |  |  | res = torch.cat(res) | 
		
	
		
			
			|  |  |  | res = torch.sparse.mm(self.adjacency_matrices, res) \ | 
		
	
		
			
			|  |  |  | if self.adjacency_matrices.is_sparse \ | 
		
	
		
			
			|  |  |  | else torch.mm(self.adjacency_matrices, res) | 
		
	
		
			
			|  |  |  | res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels) | 
		
	
		
			
			|  |  |  | if self.activation is not None: | 
		
	
		
			
			|  |  |  | res = self.activation(res) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | return res | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class FastConvLayer(torch.nn.Module): | 
		
	
	
		
			
				|  |  | 
 |