"""Tests for the ``dfill``, ``argunsort`` and ``cumcount`` helpers."""

import numpy as np
import torch

from triacontagon.cumcount import (
    argunsort,
    cumcount,
    dfill,
)


def test_dfill_01():
    """Check ``dfill`` on a tensor whose equal values form contiguous runs.

    The expected tensor holds, for every position, the index at which the
    run containing that position begins.
    """
    values = torch.tensor([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])

    output = dfill(values)

    expected = torch.tensor([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
    assert torch.all(output == expected)


def test_argunsort_01():
    """Check ``argunsort`` on the argsort permutation of an unsorted tensor.

    The expected tensor gives, for every original position, its rank in the
    sorted order with ties kept in their original relative order.
    """
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])

    # The input contains many ties and the expected permutation below assumes
    # they are ordered by original position.  NumPy's default sort kind does
    # not promise that, so a stable sort is requested explicitly.
    order = np.argsort(values.numpy(), kind="stable")
    output = argunsort(torch.tensor(order))

    expected = torch.tensor([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
    assert torch.all(output == expected)


def test_cumcount_01():
    """Check ``cumcount`` on an unsorted tensor with repeated values.

    The expected tensor holds, for every position, how many times the value
    at that position has already appeared earlier in the input.
    """
    values = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])

    output = cumcount(values)

    expected = torch.tensor([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
    assert torch.all(output == expected)