|
@@ -112,3 +112,16 @@ def test_timing_04(): |
|
|
for _ in range(1300):
|
|
|
for _ in range(1300):
|
|
|
_ = torch.sparse.mm(adj_mat, rep)
|
|
|
_ = torch.sparse.mm(adj_mat, rep)
|
|
|
print('Elapsed:', time.time() - t)
|
|
|
print('Elapsed:', time.time() - t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_timing_05():
|
|
|
|
|
|
if torch.cuda.device_count() == 0:
|
|
|
|
|
|
pytest.skip('Test requires CUDA')
|
|
|
|
|
|
dev = torch.device('cuda:0')
|
|
|
|
|
|
adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse().to(dev)
|
|
|
|
|
|
rep = torch.eye(2000).requires_grad_(True).to(dev)
|
|
|
|
|
|
t = time.time()
|
|
|
|
|
|
for _ in range(1300):
|
|
|
|
|
|
_ = torch.sparse.mm(adj_mat, rep)
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
|
|
print('Elapsed:', time.time() - t)
|