diff --git a/src/icosagon/data.py b/src/icosagon/data.py index 7e39db7..70963ec 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -14,6 +14,7 @@ from typing import List, \ Type from .decode import DEDICOMDecoder, \ BilinearDecoder +import numpy as np def _equal(x: torch.Tensor, y: torch.Tensor): @@ -23,15 +24,29 @@ def _equal(x: torch.Tensor, y: torch.Tensor): if not x.is_sparse: return (x == y) + # if x.shape != y.shape: + # return torch.tensor(0, dtype=torch.uint8) + + return ((x - y).coalesce().values() == 0) + x = x.coalesce() - indices_x = list(map(tuple, x.indices().transpose(0, 1))) - order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx]) + indices_x = np.empty(x.indices().shape[1], dtype=np.object) + indices_x[:] = list(map(tuple, x.indices().transpose(0, 1))) + order_x = np.argsort(indices_x) + #order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx]) y = y.coalesce() - indices_y = list(map(tuple, y.indices().transpose(0, 1))) - order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx]) + indices_y = np.empty(y.indices().shape[1], dtype=np.object) + indices_y[:] = list(map(tuple, y.indices().transpose(0, 1))) + order_y = np.argsort(indices_y) + # order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx]) + + # print(indices_x.shape, indices_y.shape) + + if not len(indices_x) == len(indices_y): + return torch.tensor(0, dtype=torch.uint8) - if not indices_x == indices_y: + if not np.all(indices_x[order_x] == indices_y[order_y]): return torch.tensor(0, dtype=torch.uint8) return (x.values()[order_x] == y.values()[order_y]) diff --git a/tests/icosagon/test_data.py b/tests/icosagon/test_data.py index 08ded89..57060e9 100644 --- a/tests/icosagon/test_data.py +++ b/tests/icosagon/test_data.py @@ -16,11 +16,15 @@ def test_equal_01(): x = torch.rand((10, 10)) y = torch.rand((10, 10)).round().to_sparse() + print('x == x ?') assert torch.all(_equal(x, x)) + print('y == y ?') assert torch.all(_equal(y, y)) + print('x == y ?') with pytest.raises(ValueError): _equal(x, y) + print('y == z ?') z = torch.rand((10, 10)).round().to_sparse() assert not torch.all(_equal(y, z)) @@ -71,7 +75,7 @@ def test_relation_family_03(): d.add_node_type('B', 5) fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder) - + fam.add_relation_type('A-B', torch.rand((10, 5)).round()) assert torch.all(fam.relation_types[0].adjacency_matrix.transpose(0, 1) == \