|
@@ -31,6 +31,9 @@ def _equal(x: torch.Tensor, y: torch.Tensor): |
|
|
indices_y = list(map(tuple, y.indices().transpose(0, 1)))
|
|
|
indices_y = list(map(tuple, y.indices().transpose(0, 1)))
|
|
|
order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx])
|
|
|
order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx])
|
|
|
|
|
|
|
|
|
|
|
|
if not indices_x == indices_y:
|
|
|
|
|
|
return torch.tensor(0, dtype=torch.uint8)
|
|
|
|
|
|
|
|
|
return (x.values()[order_x] == y.values()[order_y])
|
|
|
return (x.values()[order_x] == y.values()[order_y])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|