|
- from icosagon.fastconv import _sparse_diag_cat, \
- _cat, \
- FastGraphConv
- from icosagon.data import _equal
- import torch
- import pdb
- import time
-
-
def test_sparse_diag_cat_01():
    """Sparse block-diagonal concatenation matches a dense reference.

    Seven random 0/1 matrices of shape (5, 10) are laid along the diagonal of
    a (35, 70) zero matrix. ``_sparse_diag_cat`` applied to their sparse
    versions must reproduce that matrix exactly once densified.
    """
    blocks = [torch.rand(5, 10).round() for _ in range(7)]
    expected = torch.zeros(35, 70)
    for k, block in enumerate(blocks):
        # Block k occupies rows 5k..5k+5 and columns 10k..10k+10.
        expected[k * 5:(k + 1) * 5, k * 10:(k + 1) * 10] = block
    dense = _sparse_diag_cat([blk.to_sparse() for blk in blocks]).to_dense()
    assert torch.all(dense == expected)
-
-
def test_sparse_diag_cat_02():
    """Multiplying by a sparse block-diagonal matrix equals per-block products.

    The block-diagonal matrix built by ``_sparse_diag_cat`` from seven (5, 10)
    0/1 blocks is multiplied by a dense (70, 64) matrix with
    ``torch.sparse.mm``. Each 5-row band of the result must equal the dense
    product of the corresponding block with its 10-row slice of the right
    operand.

    Results are compared with ``torch.allclose`` instead of exact equality.
    The sparse and dense matmul kernels may accumulate products in a
    different order, so bit-for-bit equality of float results is not
    guaranteed and an exact comparison makes the test flaky.
    """
    n_blocks, n_rows, n_cols, n_feat = 7, 5, 10, 64
    x = [torch.rand(n_rows, n_cols).round() for _ in range(n_blocks)]
    a = _sparse_diag_cat([m.to_sparse() for m in x])
    b = torch.rand(n_blocks * n_cols, n_feat)
    res = torch.sparse.mm(a, b)

    ground_truth = torch.zeros(n_blocks * n_rows, n_feat)
    for i, block in enumerate(x):
        ground_truth[i * n_rows:(i + 1) * n_rows, :] = \
            torch.mm(block, b[i * n_cols:(i + 1) * n_cols])

    assert res.shape == ground_truth.shape
    assert torch.allclose(res, ground_truth)
-
-
def test_cat_01():
    """``_cat`` stacks dense matrices row-wise and returns a dense tensor."""
    parts = [torch.rand(5, 10) for _ in range(7)]
    stacked = _cat(parts)
    assert stacked.shape == (35, 10)
    assert not stacked.is_sparse
    expected = torch.zeros(35, 10)
    for idx, part in enumerate(parts):
        # Part idx fills rows 5*idx .. 5*idx + 5.
        expected[idx * 5:(idx + 1) * 5, :] = part
    assert torch.all(stacked == expected)
-
-
def test_cat_02():
    """``_cat`` stacks sparse matrices row-wise and keeps the result sparse."""
    parts = [torch.rand(5, 10) for _ in range(7)]
    expected = torch.zeros(35, 10)
    for idx, part in enumerate(parts):
        expected[idx * 5:(idx + 1) * 5, :] = part
    sparse_parts = [part.to_sparse() for part in parts]
    stacked = _cat(sparse_parts)
    assert stacked.shape == (35, 10)
    assert stacked.is_sparse
    assert torch.all(stacked.to_dense() == expected)
-
-
def test_fast_graph_conv_01():
    """Smoke test: a ``FastGraphConv`` forward pass runs without raising.

    The layer is built from 23 random 0/1 sparse adjacency matrices of shape
    (10, 15) and applied to a random (15, 32) input. The output is not
    inspected.
    """
    adjacency = [torch.rand(10, 15).round().to_sparse() for _ in range(23)]
    layer = FastGraphConv(32, 64, adjacency)
    features = torch.rand(15, 32)
    layer(features)
-
-
def test_fast_graph_conv_02():
    """Rough timing benchmark for ``FastGraphConv`` on a larger input.

    Builds 1300 references to one random sparse 2000x2000 matrix (about 0.1%
    non-zeros), constructs the layer, and runs three forward passes. The
    wall-clock duration of each phase is printed. Nothing is asserted about
    the output: this is a benchmark rather than a correctness check, and it
    is likely to be comparatively slow.

    Durations are measured with ``time.perf_counter``, which is monotonic
    and high-resolution. ``time.time`` follows the system clock, can jump
    when that clock is adjusted, and has coarser resolution on some
    platforms.
    """
    t = time.perf_counter()
    m = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
    # The same sparse tensor object is reused for every adjacency matrix.
    adj_mats = [m for _ in range(1300)]
    print('Generating adj_mats took:', time.perf_counter() - t)

    t = time.perf_counter()
    fgc = FastGraphConv(32, 64, adj_mats)
    print('FGC constructor took:', time.perf_counter() - t)

    in_repr = torch.rand(2000, 32)
    for _ in range(3):
        t = time.perf_counter()
        _ = fgc(in_repr)
        print('FGC forward pass took:', time.perf_counter() - t)
|