|  |  | @@ -24,33 +24,8 @@ def _equal(x: torch.Tensor, y: torch.Tensor): | 
		
	
		
			
			|  |  |  | if not x.is_sparse: | 
		
	
		
			
			|  |  |  | return (x == y) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | # if x.shape != y.shape: | 
		
	
		
			
			|  |  |  | # return torch.tensor(0, dtype=torch.uint8) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | return ((x - y).coalesce().values() == 0) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | x = x.coalesce() | 
		
	
		
			
			|  |  |  | indices_x = np.empty(x.indices().shape[1], dtype=np.object) | 
		
	
		
			
			|  |  |  | indices_x[:] = list(map(tuple, x.indices().transpose(0, 1))) | 
		
	
		
			
			|  |  |  | order_x = np.argsort(indices_x) | 
		
	
		
			
			|  |  |  | #order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx]) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | y = y.coalesce() | 
		
	
		
			
			|  |  |  | indices_y = np.empty(y.indices().shape[1], dtype=np.object) | 
		
	
		
			
			|  |  |  | indices_y[:] = list(map(tuple, y.indices().transpose(0, 1))) | 
		
	
		
			
			|  |  |  | order_y = np.argsort(indices_y) | 
		
	
		
			
			|  |  |  | # order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx]) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | # print(indices_x.shape, indices_y.shape) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not len(indices_x) == len(indices_y): | 
		
	
		
			
			|  |  |  | return torch.tensor(0, dtype=torch.uint8) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if not np.all(indices_x[order_x] == indices_y[order_y]): | 
		
	
		
			
			|  |  |  | return torch.tensor(0, dtype=torch.uint8) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | return (x.values()[order_x] == y.values()[order_y]) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | @dataclass | 
		
	
		
			
			|  |  |  | class NodeType(object): | 
		
	
	
		
			
				|  |  | 
 |