From 915c5684900a4dc1d9c1c7afa2a4690d8bf2a4c2 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 11 Jun 2020 14:23:16 +0200 Subject: [PATCH] Small fix. --- src/icosagon/data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/icosagon/data.py b/src/icosagon/data.py index b04432e..0690d29 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -31,6 +31,9 @@ def _equal(x: torch.Tensor, y: torch.Tensor): indices_y = list(map(tuple, y.indices().transpose(0, 1))) order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx]) + if not indices_x == indices_y: + return torch.tensor(0, dtype=torch.uint8) + return (x.values()[order_x] == y.values()[order_y])