"""Tests for the helpers and layers exposed by ``icosagon.fastconv``."""

import pdb
import time

import torch

from icosagon.convlayer import DecagonLayer
from icosagon.data import Data
from icosagon.data import _equal
from icosagon.fastconv import (
    _sparse_diag_cat,
    _cat,
    FastGraphConv,
    FastConvLayer,
)
from icosagon.input import OneHotInputLayer


def _make_symmetric(x: torch.Tensor):
    """Return the average of ``x`` and its transpose over dims 0 and 1."""
    return (x + x.transpose(0, 1)) / 2


def _symmetric_random(n_rows, n_columns):
    """Build a sparse, symmetrized matrix from rounded uniform noise.

    A dense ``float32`` tensor of shape ``(n_rows, n_columns)`` is drawn
    with ``torch.rand``, rounded, converted to sparse and then passed
    through ``_make_symmetric``.  Every caller in this file passes equal
    ``n_rows`` and ``n_columns``.
    """
    rounded = torch.rand((n_rows, n_columns), dtype=torch.float32).round()
    return _make_symmetric(rounded.to_sparse())


def _some_data_with_interactions():
    """Create a ``Data`` fixture with two node types and three families.

    Node types are 'Gene' (1000 nodes) and 'Drug' (100 nodes).  The
    relation families are 'Drug-Gene' (one relation type), 'Gene-Gene'
    (one relation type) and 'Drug-Drug' (three relation types), each
    filled with random sparse adjacency matrices.
    """
    data = Data()
    data.add_node_type('Gene', 1000)
    data.add_node_type('Drug', 100)

    drug_gene = data.add_relation_family('Drug-Gene', 1, 0, True)
    target_adj = torch.rand((100, 1000), dtype=torch.float32).round()
    drug_gene.add_relation_type('Target', target_adj.to_sparse())

    gene_gene = data.add_relation_family('Gene-Gene', 0, 0, True)
    gene_gene.add_relation_type('Interaction', _symmetric_random(1000, 1000))

    drug_drug = data.add_relation_family('Drug-Drug', 1, 1, True)
    side_effects = (
        'Side Effect: Nausea',
        'Side Effect: Infertility',
        'Side Effect: Death',
    )
    for relation_name in side_effects:
        drug_drug.add_relation_type(relation_name,
                                    _symmetric_random(100, 100))

    return data


def test_sparse_diag_cat_01():
    """``_sparse_diag_cat`` lays 7 blocks of 5x10 along the diagonal."""
    blocks = [torch.rand(5, 10).round() for _ in range(7)]

    expected = torch.zeros(35, 70)
    for idx, block in enumerate(blocks):
        rows = slice(5 * idx, 5 * (idx + 1))
        cols = slice(10 * idx, 10 * (idx + 1))
        expected[rows, cols] = block

    actual = _sparse_diag_cat([block.to_sparse() for block in blocks])
    actual = actual.to_dense()

    assert torch.all(actual == expected)


def test_sparse_diag_cat_02():
    """Multiplying by the block-diagonal matrix equals per-block products."""
    blocks = [torch.rand(5, 10).round() for _ in range(7)]
    diag = _sparse_diag_cat([block.to_sparse() for block in blocks])
    rhs = torch.rand(70, 64)
    actual = torch.sparse.mm(diag, rhs)

    expected = torch.zeros(35, 64)
    for idx, block in enumerate(blocks):
        # Block ``idx`` only sees its own 10-row slice of ``rhs``.
        rhs_slice = rhs[10 * idx:10 * (idx + 1)]
        expected[5 * idx:5 * (idx + 1), :] = torch.mm(block, rhs_slice)

    assert torch.all(actual == expected)


def test_cat_01():
    """``_cat`` stacks dense matrices row-wise and returns a dense tensor."""
    parts = [torch.rand(5, 10) for _ in range(7)]
    actual = _cat(parts)
    assert actual.shape == (35, 10)
    assert not actual.is_sparse

    expected = torch.zeros(35, 10)
    for idx, part in enumerate(parts):
        expected[idx * 5:(idx + 1) * 5, :] = part

    assert torch.all(actual == expected)


def test_cat_02():
    """``_cat`` stacks sparse matrices row-wise and returns a sparse tensor."""
    parts = [torch.rand(5, 10) for _ in range(7)]

    expected = torch.zeros(35, 10)
    for idx, part in enumerate(parts):
        expected[idx * 5:(idx + 1) * 5, :] = part

    actual = _cat([part.to_sparse() for part in parts])
    assert actual.shape == (35, 10)
    assert actual.is_sparse
    assert torch.all(actual.to_dense() == expected)


def test_fast_graph_conv_01():
    """Smoke test: ``FastGraphConv`` runs forward on 23 random matrices."""
    adj_mats = [torch.rand(10, 15).round().to_sparse() for _ in range(23)]
    conv = FastGraphConv(32, 64, adj_mats)
    features = torch.rand(15, 32)
    _ = conv(features)


def test_fast_graph_conv_02():
    """Print timings for building and running a large ``FastGraphConv``.

    The same 2000x2000 sparse matrix object is reused 1300 times; nothing
    is asserted, the elapsed wall-clock times are only printed.
    """
    started = time.time()
    shared = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
    adj_mats = [shared] * 1300
    print('Generating adj_mats took:', time.time() - started)

    started = time.time()
    conv = FastGraphConv(32, 64, adj_mats)
    print('FGC constructor took:', time.time() - started)

    features = torch.rand(2000, 32)

    for _ in range(3):
        started = time.time()
        _ = conv(features)
        print('FGC forward pass took:', time.time() - started)


def test_fast_graph_conv_03():
    """One adjacency matrix: output equals ``adj @ (input @ weights)``."""
    rows = [
        [0, 0, 1, 0, 1],
        [0, 1, 0, 1, 0],
        [1, 0, 1, 0, 0],
    ]
    features = torch.rand(5, 32)
    adj = torch.tensor(rows, dtype=torch.float32)
    conv = FastGraphConv(32, 64, [adj.to_sparse()])
    output = conv(features)
    assert output.shape == (1, 3, 64)

    projected = torch.mm(features, conv.weights)
    expected = torch.mm(adj, projected).view(1, 3, 64)
    assert (expected == output).all()


def test_fast_graph_conv_04():
    """Two copies of one adjacency matrix, checked against a dense model.

    The reference computation splits ``input @ weights`` into two column
    halves, stacks them row-wise and multiplies by a dense block-diagonal
    matrix holding the adjacency matrix twice.
    """
    rows = [
        [0, 0, 1, 0, 1],
        [0, 1, 0, 1, 0],
        [1, 0, 1, 0, 0],
    ]
    features = torch.rand(5, 32)
    adj = torch.tensor(rows, dtype=torch.float32)
    conv = FastGraphConv(32, 64, [adj.to_sparse(), adj.to_sparse()])
    output = conv(features)
    assert output.shape == (2, 3, 64)

    block_diag = torch.zeros(adj.shape[0] * 2, adj.shape[1] * 2)
    block_diag[0:3, 0:5] = adj
    block_diag[3:6, 5:10] = adj

    projected = torch.mm(features, conv.weights)
    halves = torch.split(projected, projected.shape[1] // 2, dim=1)
    stacked = torch.cat(halves)
    expected = torch.mm(block_diag, stacked)
    assert (expected.view(2, 3, 64) == output).all()


def test_fast_conv_layer_01():
    """Smoke test: ``DecagonLayer`` and ``FastConvLayer`` both run forward."""
    data = _some_data_with_interactions()
    in_layer = OneHotInputLayer(data)

    d_layer = DecagonLayer(in_layer.output_dim, [32, 32], data)
    decagon_seq = torch.nn.Sequential(in_layer, d_layer)
    _ = decagon_seq(None)

    conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], data)
    fast_seq = torch.nn.Sequential(in_layer, conv_layer)
    _ = fast_seq(None)


def test_fast_conv_layer_02():
    """``FastConvLayer`` reproduces ``DecagonLayer`` given the same weights.

    For each entry of ``next_layer_repr`` the per-convolution
    ``graph_conv.weight`` tensors of the ``DecagonLayer`` are concatenated
    along dim 1 and assigned to ``weights`` of the matching
    ``FastConvLayer`` entry; the two outputs must then be equal elementwise.
    """
    data = _some_data_with_interactions()
    in_layer = OneHotInputLayer(data)

    d_layer = DecagonLayer(in_layer.output_dim, [32, 32], data)
    decagon_seq = torch.nn.Sequential(in_layer, d_layer)
    out_repr_1 = decagon_seq(None)

    assert len(d_layer.next_layer_repr[0]) == 2
    assert len(d_layer.next_layer_repr[1]) == 2

    conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], data)

    # (first index into next_layer_repr, number of convolutions to copy
    # for each of its two entries), processed in the order 1 then 0.
    copy_plan = (
        (1, (1, 3)),
        (0, (1, 1)),
    )
    for outer_idx, conv_counts in copy_plan:
        assert len(conv_layer.next_layer_repr[outer_idx]) == 2
        for inner_idx, n_convs in enumerate(conv_counts):
            source = d_layer.next_layer_repr[outer_idx][inner_idx]
            target = conv_layer.next_layer_repr[outer_idx][inner_idx]
            target.weights = torch.cat(
                [source.convolutions[k].graph_conv.weight
                 for k in range(n_convs)],
                dim=1)

    fast_seq = torch.nn.Sequential(in_layer, conv_layer)
    out_repr_2 = fast_seq(None)

    assert len(out_repr_1) == len(out_repr_2)
    for idx in range(len(out_repr_1)):
        assert torch.all(out_repr_1[idx] == out_repr_2[idx])