diff --git a/src/icosagon/data.py b/src/icosagon/data.py index 70963ec..4505adf 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -24,33 +24,8 @@ def _equal(x: torch.Tensor, y: torch.Tensor): if not x.is_sparse: return (x == y) - # if x.shape != y.shape: - # return torch.tensor(0, dtype=torch.uint8) - return ((x - y).coalesce().values() == 0) - x = x.coalesce() - indices_x = np.empty(x.indices().shape[1], dtype=np.object) - indices_x[:] = list(map(tuple, x.indices().transpose(0, 1))) - order_x = np.argsort(indices_x) - #order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx]) - - y = y.coalesce() - indices_y = np.empty(y.indices().shape[1], dtype=np.object) - indices_y[:] = list(map(tuple, y.indices().transpose(0, 1))) - order_y = np.argsort(indices_y) - # order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx]) - - # print(indices_x.shape, indices_y.shape) - - if not len(indices_x) == len(indices_y): - return torch.tensor(0, dtype=torch.uint8) - - if not np.all(indices_x[order_x] == indices_y[order_y]): - return torch.tensor(0, dtype=torch.uint8) - - return (x.values()[order_x] == y.values()[order_y]) - @dataclass class NodeType(object):