|
- from triacontagon.util import \
- _clear_adjacency_matrix_except_rows, \
- _sparse_diag_cat, \
- _equal
- import torch
- import time
-
-
def test_clear_adjacency_matrix_except_rows_01():
    """Only the kept rows of each diagonal block survive, at the right offsets."""
    # A single 4x5 block; tiled twice along the diagonal it becomes 8x10.
    adj_mat = torch.tensor([
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1],
        [1, 0, 1, 0, 0],
        [1, 1, 0, 0, 0],
    ], dtype=torch.uint8).to_sparse()

    adj_mat = _sparse_diag_cat([adj_mat, adj_mat])

    # Keep per-block rows 1 and 3 (4 rows per block, 2 blocks); the expected
    # result below keeps global rows 1, 3, 5 and 7.
    res = _clear_adjacency_matrix_except_rows(adj_mat,
        torch.tensor([1, 3]), 4, 2)

    res = res.to_dense()

    truth = torch.tensor([
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
    ], dtype=torch.uint8)

    # `==` broadcasts, so a result of the wrong (but broadcastable) shape
    # could otherwise compare equal; check the shape explicitly first.
    assert res.shape == truth.shape, f'unexpected shape: {tuple(res.shape)}'
    assert torch.all(res == truth), f'res: {res}'
-
-
def _check_clear_except_rows(adj_mat, keep_rows, n_blocks):
    """Check ``_clear_adjacency_matrix_except_rows`` against a reference.

    Tiles ``adj_mat`` ``n_blocks`` times along the diagonal with
    ``_sparse_diag_cat`` and clears every per-block row except
    ``keep_rows``. The result is compared with the tiling of a copy of
    ``adj_mat`` whose other rows were zeroed before tiling. The run times
    of both helpers are printed (visible with ``pytest -s``).

    :param adj_mat: dense 2-D block matrix; its row count is passed as the
        per-block row count.
    :param keep_rows: 1-D index tensor of per-block rows to keep.
    :param n_blocks: number of diagonal copies of ``adj_mat``.
    """
    rows_per_block = adj_mat.shape[0]

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t = time.perf_counter()
    res = _sparse_diag_cat([adj_mat.to_sparse()] * n_blocks)
    print('_sparse_diag_cat() took:', time.perf_counter() - t)

    t = time.perf_counter()
    res = _clear_adjacency_matrix_except_rows(res, keep_rows,
        rows_per_block, n_blocks)
    print('_clear_adjacency_matrix_except_rows() took:',
        time.perf_counter() - t)

    # Reference: zero every per-block row not listed in keep_rows, then tile.
    expected = adj_mat.clone()
    cleared = torch.ones(rows_per_block, dtype=torch.bool)
    cleared[keep_rows] = False
    expected[cleared] = 0
    truth = _sparse_diag_cat([expected.to_sparse()] * n_blocks)

    assert _equal(res, truth).all()


def test_clear_adjacency_matrix_except_rows_02():
    """Small random 6x10 block, 130 copies."""
    # Test-local seeded generator: reproducible input without mutating the
    # global RNG state that other tests may rely on.
    gen = torch.Generator().manual_seed(2)
    adj_mat = torch.rand(6, 10, generator=gen).round().to(torch.uint8)
    _check_clear_except_rows(adj_mat, torch.tensor([1, 3, 5]), 130)


def test_clear_adjacency_matrix_except_rows_03():
    """Small random 6x10 block, 1300 copies."""
    gen = torch.Generator().manual_seed(3)
    adj_mat = torch.rand(6, 10, generator=gen).round().to(torch.uint8)
    _check_clear_except_rows(adj_mat, torch.tensor([1, 3, 5]), 1300)


def test_clear_adjacency_matrix_except_rows_04():
    """Large sparse 2000x2000 block (~0.1% density), 1300 copies.

    Only rows 1, 3 and 5 are kept, so all rows from 6 onward are cleared too.
    """
    gen = torch.Generator().manual_seed(4)
    adj_mat = (torch.rand(2000, 2000, generator=gen) < 0.001).to(torch.uint8)
    _check_clear_except_rows(adj_mat, torch.tensor([1, 3, 5]), 1300)
|