diff --git a/src/triacontagon/util.py b/src/triacontagon/util.py index 2367b06..e6fefc8 100644 --- a/src/triacontagon/util.py +++ b/src/triacontagon/util.py @@ -63,7 +63,7 @@ def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor, rows = torch.cat(rows) print('cat took:', time.time() - t) # print('rows:', rows) - rows = set(rows.tolist()) + # rows = set(rows.tolist()) # print('rows:', rows) t = time.time() @@ -71,7 +71,23 @@ def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor, indices = adj_mat.indices() values = adj_mat.values() print('indices[0]:', indices[0]) - print('indices[0][1]:', indices[0][1], indices[0][1] in rows) + # print('indices[0][1]:', indices[0][1], indices[0][1] in rows) + + lookup = torch.zeros(row_vertex_count * num_relation_types, + dtype=torch.uint8, device=adj_mat.device) + lookup[rows] = 1 + values = values * lookup[indices[0]] + mask = torch.nonzero(values > 0, as_tuple=True)[0] + indices = indices[:, mask] + values = values[mask] + + res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape) + # res = res.coalesce() + print('res:', res) + print('"index_select()" took:', time.time() - t) + + return res + selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ]) # print('selection:', selection) selection = torch.nonzero(selection, as_tuple=True)[0] diff --git a/tests/triacontagon/test_util.py b/tests/triacontagon/test_util.py index 5937535..e4f7f9d 100644 --- a/tests/triacontagon/test_util.py +++ b/tests/triacontagon/test_util.py @@ -78,6 +78,8 @@ def test_clear_adjacency_matrix_except_rows_03(): def test_clear_adjacency_matrix_except_rows_04(): adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8) + print('adj_mat.to_sparse():', adj_mat.to_sparse()) + t = time.time() res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) print('_sparse_diag_cat() took:', time.time() - t) @@ -93,3 +95,29 @@ def test_clear_adjacency_matrix_except_rows_04(): truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) assert _equal(res, truth).all() + + +def test_clear_adjacency_matrix_except_rows_05(): + if torch.cuda.device_count() == 0: + pytest.skip('Test requires CUDA') + + device = torch.device('cuda:0') + adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8).to(device) + + print('adj_mat.to_sparse():', adj_mat.to_sparse()) + + t = time.time() + res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) + print('_sparse_diag_cat() took:', time.time() - t) + + rows = torch.tensor(list(range(512)), device=device) + + t = time.time() + res = _clear_adjacency_matrix_except_rows(res, rows, + 2000, 1300) + print('_clear_adjacency_matrix_except_rows() took:', time.time() - t) + + adj_mat[512:] = torch.zeros(2000) + truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) + + assert _equal(res, truth).all()