| @@ -24,33 +24,8 @@ def _equal(x: torch.Tensor, y: torch.Tensor): | |||||
| if not x.is_sparse: | if not x.is_sparse: | ||||
| return (x == y) | return (x == y) | ||||
| # if x.shape != y.shape: | |||||
| # return torch.tensor(0, dtype=torch.uint8) | |||||
| return ((x - y).coalesce().values() == 0) | return ((x - y).coalesce().values() == 0) | ||||
| x = x.coalesce() | |||||
| indices_x = np.empty(x.indices().shape[1], dtype=np.object) | |||||
| indices_x[:] = list(map(tuple, x.indices().transpose(0, 1))) | |||||
| order_x = np.argsort(indices_x) | |||||
| #order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx]) | |||||
| y = y.coalesce() | |||||
| indices_y = np.empty(y.indices().shape[1], dtype=np.object) | |||||
| indices_y[:] = list(map(tuple, y.indices().transpose(0, 1))) | |||||
| order_y = np.argsort(indices_y) | |||||
| # order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx]) | |||||
| # print(indices_x.shape, indices_y.shape) | |||||
| if not len(indices_x) == len(indices_y): | |||||
| return torch.tensor(0, dtype=torch.uint8) | |||||
| if not np.all(indices_x[order_x] == indices_y[order_y]): | |||||
| return torch.tensor(0, dtype=torch.uint8) | |||||
| return (x.values()[order_x] == y.values()[order_y]) | |||||
| @dataclass | @dataclass | ||||
| class NodeType(object): | class NodeType(object): | ||||