| 
							- from icosagon.fastconv import _sparse_diag_cat, \
 -     _cat, \
 -     FastGraphConv, \
 -     FastConvLayer
 - from icosagon.data import _equal
 - import torch
 - import pdb
 - import time
 - from icosagon.data import Data
 - from icosagon.input import OneHotInputLayer
 - from icosagon.convlayer import DecagonLayer
 - 
 - 
 - def _make_symmetric(x: torch.Tensor):
 -     x = (x + x.transpose(0, 1)) / 2
 -     return x
 - 
 - 
 - def _symmetric_random(n_rows, n_columns):
 -     return _make_symmetric(torch.rand((n_rows, n_columns),
 -         dtype=torch.float32).round())
 - 
 - 
 - def _some_data_with_interactions():
 -     d = Data()
 -     d.add_node_type('Gene', 1000)
 -     d.add_node_type('Drug', 100)
 - 
 -     fam = d.add_relation_family('Drug-Gene', 1, 0, True)
 -     fam.add_relation_type('Target',
 -         torch.rand((100, 1000), dtype=torch.float32).round())
 - 
 -     fam = d.add_relation_family('Gene-Gene', 0, 0, True)
 -     fam.add_relation_type('Interaction',
 -         _symmetric_random(1000, 1000))
 - 
 -     fam = d.add_relation_family('Drug-Drug', 1, 1, True)
 -     fam.add_relation_type('Side Effect: Nausea',
 -         _symmetric_random(100, 100))
 -     fam.add_relation_type('Side Effect: Infertility',
 -         _symmetric_random(100, 100))
 -     fam.add_relation_type('Side Effect: Death',
 -         _symmetric_random(100, 100))
 -     return d
 - 
 - 
 - def test_sparse_diag_cat_01():
 -     matrices = [ torch.rand(5, 10).round() for _ in range(7) ]
 -     ground_truth = torch.zeros(35, 70)
 -     ground_truth[0:5, 0:10] = matrices[0]
 -     ground_truth[5:10, 10:20] = matrices[1]
 -     ground_truth[10:15, 20:30] = matrices[2]
 -     ground_truth[15:20, 30:40] = matrices[3]
 -     ground_truth[20:25, 40:50] = matrices[4]
 -     ground_truth[25:30, 50:60] = matrices[5]
 -     ground_truth[30:35, 60:70] = matrices[6]
 -     res = _sparse_diag_cat([ m.to_sparse() for m in matrices ])
 -     res = res.to_dense()
 -     assert torch.all(res == ground_truth)
 - 
 - 
 - def test_sparse_diag_cat_02():
 -     x = [ torch.rand(5, 10).round() for _ in range(7) ]
 -     a = [ m.to_sparse() for m in x ]
 -     a = _sparse_diag_cat(a)
 -     b = torch.rand(70, 64)
 -     res = torch.sparse.mm(a, b)
 - 
 -     ground_truth = torch.zeros(35, 64)
 -     ground_truth[0:5, :] = torch.mm(x[0], b[0:10])
 -     ground_truth[5:10, :] = torch.mm(x[1], b[10:20])
 -     ground_truth[10:15, :] = torch.mm(x[2], b[20:30])
 -     ground_truth[15:20, :] = torch.mm(x[3], b[30:40])
 -     ground_truth[20:25, :] = torch.mm(x[4], b[40:50])
 -     ground_truth[25:30, :] = torch.mm(x[5], b[50:60])
 -     ground_truth[30:35, :] = torch.mm(x[6], b[60:70])
 - 
 -     assert torch.all(res == ground_truth)
 - 
 - 
 - def test_cat_01():
 -     matrices = [ torch.rand(5, 10) for _ in range(7) ]
 -     res = _cat(matrices)
 -     assert res.shape == (35, 10)
 -     assert not res.is_sparse
 -     ground_truth = torch.zeros(35, 10)
 -     for i in range(7):
 -         ground_truth[i*5:(i+1)*5, :] = matrices[i]
 -     assert torch.all(res == ground_truth)
 - 
 - 
 - def test_cat_02():
 -     matrices = [ torch.rand(5, 10) for _ in range(7) ]
 -     ground_truth = torch.zeros(35, 10)
 -     for i in range(7):
 -         ground_truth[i*5:(i+1)*5, :] = matrices[i]
 -     res = _cat([ m.to_sparse() for m in matrices ])
 -     assert res.shape == (35, 10)
 -     assert res.is_sparse
 -     assert torch.all(res.to_dense() == ground_truth)
 - 
 - 
 - def test_fast_graph_conv_01():
 -     # pdb.set_trace()
 -     adj_mats = [ torch.rand(10, 15).round().to_sparse() \
 -         for _ in range(23) ]
 -     fgc = FastGraphConv(32, 64, adj_mats)
 -     in_repr = torch.rand(15, 32)
 -     _ = fgc(in_repr)
 - 
 - 
 - def test_fast_graph_conv_02():
 -     t = time.time()
 -     m = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
 -     adj_mats = [ m for _ in range(1300) ]
 -     print('Generating adj_mats took:', time.time() - t)
 -     t = time.time()
 -     fgc = FastGraphConv(32, 64, adj_mats)
 -     print('FGC constructor took:', time.time() - t)
 -     in_repr = torch.rand(2000, 32)
 - 
 -     for _ in range(3):
 -         t = time.time()
 -         _ = fgc(in_repr)
 -         print('FGC forward pass took:', time.time() - t)
 - 
 - 
 - def test_fast_graph_conv_03():
 -     adj_mat = [
 -         [ 0, 0, 1, 0, 1 ],
 -         [ 0, 1, 0, 1, 0 ],
 -         [ 1, 0, 1, 0, 0 ]
 -     ]
 -     in_repr = torch.rand(5, 32)
 -     adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
 -     fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse() ])
 -     out_repr = fgc(in_repr)
 -     assert out_repr.shape == (1, 3, 64)
 -     assert (torch.mm(adj_mat, torch.mm(in_repr, fgc.weights)).view(1, 3, 64) == out_repr).all()
 - 
 - 
 - def test_fast_graph_conv_04():
 -     adj_mat = [
 -         [ 0, 0, 1, 0, 1 ],
 -         [ 0, 1, 0, 1, 0 ],
 -         [ 1, 0, 1, 0, 0 ]
 -     ]
 -     in_repr = torch.rand(5, 32)
 -     adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
 -     fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse(), adj_mat.to_sparse() ])
 -     out_repr = fgc(in_repr)
 -     assert out_repr.shape == (2, 3, 64)
 -     adj_mat_1 = torch.zeros(adj_mat.shape[0] * 2, adj_mat.shape[1] * 2)
 -     adj_mat_1[0:3, 0:5] = adj_mat
 -     adj_mat_1[3:6, 5:10] = adj_mat
 -     res = torch.mm(in_repr, fgc.weights)
 -     res = torch.split(res, res.shape[1] // 2, dim=1)
 -     res = torch.cat(res)
 -     res = torch.mm(adj_mat_1, res)
 -     assert (res.view(2, 3, 64) == out_repr).all()
 - 
 - 
 - def test_fast_conv_layer_01():
 -     d = _some_data_with_interactions()
 -     in_layer = OneHotInputLayer(d)
 - 
 -     d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
 -     seq_1 = torch.nn.Sequential(in_layer, d_layer)
 -     out_repr_1 = seq_1(None)
 - 
 -     conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
 -     seq_2 = torch.nn.Sequential(in_layer, conv_layer)
 -     out_repr_2 = seq_2(None)
 
 
  |