IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Parcourir la source

Make _equal() much faster.

master
Stanislaw Adaszewski il y a 4 ans
Parent
révision
d4dab52e82
2 fichiers modifiés avec 25 ajouts et 6 suppressions
  1. +20
    -5
      src/icosagon/data.py
  2. +5
    -1
      tests/icosagon/test_data.py

+ 20
- 5
src/icosagon/data.py Voir le fichier

@@ -14,6 +14,7 @@ from typing import List, \
Type
from .decode import DEDICOMDecoder, \
BilinearDecoder
import numpy as np
def _equal(x: torch.Tensor, y: torch.Tensor):
@@ -23,15 +24,29 @@ def _equal(x: torch.Tensor, y: torch.Tensor):
if not x.is_sparse:
return (x == y)
# if x.shape != y.shape:
# return torch.tensor(0, dtype=torch.uint8)
return ((x - y).coalesce().values() == 0)
x = x.coalesce()
indices_x = list(map(tuple, x.indices().transpose(0, 1)))
order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx])
indices_x = np.empty(x.indices().shape[1], dtype=np.object)
indices_x[:] = list(map(tuple, x.indices().transpose(0, 1)))
order_x = np.argsort(indices_x)
#order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx])
y = y.coalesce()
indices_y = list(map(tuple, y.indices().transpose(0, 1)))
order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx])
indices_y = np.empty(y.indices().shape[1], dtype=np.object)
indices_y[:] = list(map(tuple, y.indices().transpose(0, 1)))
order_y = np.argsort(indices_y)
# order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx])
# print(indices_x.shape, indices_y.shape)
if not len(indices_x) == len(indices_y):
return torch.tensor(0, dtype=torch.uint8)
if not indices_x == indices_y:
if not np.all(indices_x[order_x] == indices_y[order_y]):
return torch.tensor(0, dtype=torch.uint8)
return (x.values()[order_x] == y.values()[order_y])


+ 5
- 1
tests/icosagon/test_data.py Voir le fichier

@@ -16,11 +16,15 @@ def test_equal_01():
x = torch.rand((10, 10))
y = torch.rand((10, 10)).round().to_sparse()
print('x == x ?')
assert torch.all(_equal(x, x))
print('y == y ?')
assert torch.all(_equal(y, y))
print('x == y ?')
with pytest.raises(ValueError):
_equal(x, y)
print('y == z ?')
z = torch.rand((10, 10)).round().to_sparse()
assert not torch.all(_equal(y, z))
@@ -71,7 +75,7 @@ def test_relation_family_03():
d.add_node_type('B', 5)
fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder)
fam.add_relation_type('A-B', torch.rand((10, 5)).round())
assert torch.all(fam.relation_types[0].adjacency_matrix.transpose(0, 1) == \


Chargement…
Annuler
Enregistrer