IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

_clear_adjacency_matrix_except_rows() might yet be fast enough on the GPU.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
f8a02191cb
共有 2 個文件被更改,包括 46 次插入2 次删除
  1. +18
    -2
      src/triacontagon/util.py
  2. +28
    -0
      tests/triacontagon/test_util.py

+ 18
- 2
src/triacontagon/util.py 查看文件

@@ -63,7 +63,7 @@ def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
rows = torch.cat(rows)
print('cat took:', time.time() - t)
# print('rows:', rows)
rows = set(rows.tolist())
# rows = set(rows.tolist())
# print('rows:', rows)
t = time.time()
@@ -71,7 +71,23 @@ def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
indices = adj_mat.indices()
values = adj_mat.values()
print('indices[0]:', indices[0])
print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
# print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
lookup = torch.zeros(row_vertex_count * num_relation_types,
dtype=torch.uint8, device=adj_mat.device)
lookup[rows] = 1
values = values * lookup[indices[0]]
mask = torch.nonzero(values > 0, as_tuple=True)[0]
indices = indices[:, mask]
values = values[mask]
res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
# res = res.coalesce()
print('res:', res)
print('"index_select()" took:', time.time() - t)
return res
selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
# print('selection:', selection)
selection = torch.nonzero(selection, as_tuple=True)[0]


+ 28
- 0
tests/triacontagon/test_util.py 查看文件

@@ -78,6 +78,8 @@ def test_clear_adjacency_matrix_except_rows_03():
def test_clear_adjacency_matrix_except_rows_04():
adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8)
print('adj_mat.to_sparse():', adj_mat.to_sparse())
t = time.time()
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
print('_sparse_diag_cat() took:', time.time() - t)
@@ -93,3 +95,29 @@ def test_clear_adjacency_matrix_except_rows_04():
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
assert _equal(res, truth).all()
def test_clear_adjacency_matrix_except_rows_05():
if torch.cuda.device_count() == 0:
pytest.skip('Test requires CUDA')
device = torch.device('cuda:0')
adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8).to(device)
print('adj_mat.to_sparse():', adj_mat.to_sparse())
t = time.time()
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
print('_sparse_diag_cat() took:', time.time() - t)
rows = torch.tensor(list(range(512)), device=device)
t = time.time()
res = _clear_adjacency_matrix_except_rows(res, rows,
2000, 1300)
print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
adj_mat[512:] = torch.zeros(2000)
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
assert _equal(res, truth).all()

Loading…
取消
儲存