@@ -63,7 +63,7 @@ def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor, | |||||
rows = torch.cat(rows) | rows = torch.cat(rows) | ||||
print('cat took:', time.time() - t) | print('cat took:', time.time() - t) | ||||
# print('rows:', rows) | # print('rows:', rows) | ||||
rows = set(rows.tolist()) | |||||
# rows = set(rows.tolist()) | |||||
# print('rows:', rows) | # print('rows:', rows) | ||||
t = time.time() | t = time.time() | ||||
@@ -71,7 +71,23 @@ def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor, | |||||
indices = adj_mat.indices() | indices = adj_mat.indices() | ||||
values = adj_mat.values() | values = adj_mat.values() | ||||
print('indices[0]:', indices[0]) | print('indices[0]:', indices[0]) | ||||
print('indices[0][1]:', indices[0][1], indices[0][1] in rows) | |||||
# print('indices[0][1]:', indices[0][1], indices[0][1] in rows) | |||||
lookup = torch.zeros(row_vertex_count * num_relation_types, | |||||
dtype=torch.uint8, device=adj_mat.device) | |||||
lookup[rows] = 1 | |||||
values = values * lookup[indices[0]] | |||||
mask = torch.nonzero(values > 0, as_tuple=True)[0] | |||||
indices = indices[:, mask] | |||||
values = values[mask] | |||||
res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape) | |||||
# res = res.coalesce() | |||||
print('res:', res) | |||||
print('"index_select()" took:', time.time() - t) | |||||
return res | |||||
selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ]) | selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ]) | ||||
# print('selection:', selection) | # print('selection:', selection) | ||||
selection = torch.nonzero(selection, as_tuple=True)[0] | selection = torch.nonzero(selection, as_tuple=True)[0] | ||||
@@ -78,6 +78,8 @@ def test_clear_adjacency_matrix_except_rows_03(): | |||||
def test_clear_adjacency_matrix_except_rows_04(): | def test_clear_adjacency_matrix_except_rows_04(): | ||||
adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8) | adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8) | ||||
print('adj_mat.to_sparse():', adj_mat.to_sparse()) | |||||
t = time.time() | t = time.time() | ||||
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | ||||
print('_sparse_diag_cat() took:', time.time() - t) | print('_sparse_diag_cat() took:', time.time() - t) | ||||
@@ -93,3 +95,29 @@ def test_clear_adjacency_matrix_except_rows_04(): | |||||
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | ||||
assert _equal(res, truth).all() | assert _equal(res, truth).all() | ||||
def test_clear_adjacency_matrix_except_rows_05(): | |||||
if torch.cuda.device_count() == 0: | |||||
pytest.skip('Test requires CUDA') | |||||
device = torch.device('cuda:0') | |||||
adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8).to(device) | |||||
print('adj_mat.to_sparse():', adj_mat.to_sparse()) | |||||
t = time.time() | |||||
res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | |||||
print('_sparse_diag_cat() took:', time.time() - t) | |||||
rows = torch.tensor(list(range(512)), device=device) | |||||
t = time.time() | |||||
res = _clear_adjacency_matrix_except_rows(res, rows, | |||||
2000, 1300) | |||||
print('_clear_adjacency_matrix_except_rows() took:', time.time() - t) | |||||
adj_mat[512:] = torch.zeros(2000) | |||||
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | |||||
assert _equal(res, truth).all() |