|
@@ -1,10 +1,47 @@ |
|
|
from icosagon.fastconv import _sparse_diag_cat, \
|
|
|
from icosagon.fastconv import _sparse_diag_cat, \
|
|
|
_cat, \
|
|
|
_cat, \
|
|
|
FastGraphConv
|
|
|
|
|
|
|
|
|
FastGraphConv, \
|
|
|
|
|
|
FastConvLayer
|
|
|
from icosagon.data import _equal
|
|
|
from icosagon.data import _equal
|
|
|
import torch
|
|
|
import torch
|
|
|
import pdb
|
|
|
import pdb
|
|
|
import time
|
|
|
import time
|
|
|
|
|
|
from icosagon.data import Data
|
|
|
|
|
|
from icosagon.input import OneHotInputLayer
|
|
|
|
|
|
from icosagon.convlayer import DecagonLayer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _make_symmetric(x: torch.Tensor):
|
|
|
|
|
|
x = (x + x.transpose(0, 1)) / 2
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _symmetric_random(n_rows, n_columns):
|
|
|
|
|
|
return _make_symmetric(torch.rand((n_rows, n_columns),
|
|
|
|
|
|
dtype=torch.float32).round())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _some_data_with_interactions():
|
|
|
|
|
|
d = Data()
|
|
|
|
|
|
d.add_node_type('Gene', 1000)
|
|
|
|
|
|
d.add_node_type('Drug', 100)
|
|
|
|
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
|
|
|
|
|
|
fam.add_relation_type('Target',
|
|
|
|
|
|
torch.rand((100, 1000), dtype=torch.float32).round())
|
|
|
|
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Gene-Gene', 0, 0, True)
|
|
|
|
|
|
fam.add_relation_type('Interaction',
|
|
|
|
|
|
_symmetric_random(1000, 1000))
|
|
|
|
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Drug-Drug', 1, 1, True)
|
|
|
|
|
|
fam.add_relation_type('Side Effect: Nausea',
|
|
|
|
|
|
_symmetric_random(100, 100))
|
|
|
|
|
|
fam.add_relation_type('Side Effect: Infertility',
|
|
|
|
|
|
_symmetric_random(100, 100))
|
|
|
|
|
|
fam.add_relation_type('Side Effect: Death',
|
|
|
|
|
|
_symmetric_random(100, 100))
|
|
|
|
|
|
return d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sparse_diag_cat_01():
|
|
|
def test_sparse_diag_cat_01():
|
|
@@ -86,3 +123,51 @@ def test_fast_graph_conv_02(): |
|
|
t = time.time()
|
|
|
t = time.time()
|
|
|
_ = fgc(in_repr)
|
|
|
_ = fgc(in_repr)
|
|
|
print('FGC forward pass took:', time.time() - t)
|
|
|
print('FGC forward pass took:', time.time() - t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_fast_graph_conv_03():
|
|
|
|
|
|
adj_mat = [
|
|
|
|
|
|
[ 0, 0, 1, 0, 1 ],
|
|
|
|
|
|
[ 0, 1, 0, 1, 0 ],
|
|
|
|
|
|
[ 1, 0, 1, 0, 0 ]
|
|
|
|
|
|
]
|
|
|
|
|
|
in_repr = torch.rand(5, 32)
|
|
|
|
|
|
adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
|
|
|
|
|
|
fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse() ])
|
|
|
|
|
|
out_repr = fgc(in_repr)
|
|
|
|
|
|
assert out_repr.shape == (1, 3, 64)
|
|
|
|
|
|
assert (torch.mm(adj_mat, torch.mm(in_repr, fgc.weights)).view(1, 3, 64) == out_repr).all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_fast_graph_conv_04():
|
|
|
|
|
|
adj_mat = [
|
|
|
|
|
|
[ 0, 0, 1, 0, 1 ],
|
|
|
|
|
|
[ 0, 1, 0, 1, 0 ],
|
|
|
|
|
|
[ 1, 0, 1, 0, 0 ]
|
|
|
|
|
|
]
|
|
|
|
|
|
in_repr = torch.rand(5, 32)
|
|
|
|
|
|
adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
|
|
|
|
|
|
fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse(), adj_mat.to_sparse() ])
|
|
|
|
|
|
out_repr = fgc(in_repr)
|
|
|
|
|
|
assert out_repr.shape == (2, 3, 64)
|
|
|
|
|
|
adj_mat_1 = torch.zeros(adj_mat.shape[0] * 2, adj_mat.shape[1] * 2)
|
|
|
|
|
|
adj_mat_1[0:3, 0:5] = adj_mat
|
|
|
|
|
|
adj_mat_1[3:6, 5:10] = adj_mat
|
|
|
|
|
|
res = torch.mm(in_repr, fgc.weights)
|
|
|
|
|
|
res = torch.split(res, res.shape[1] // 2, dim=1)
|
|
|
|
|
|
res = torch.cat(res)
|
|
|
|
|
|
res = torch.mm(adj_mat_1, res)
|
|
|
|
|
|
assert (res.view(2, 3, 64) == out_repr).all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_fast_conv_layer_01():
|
|
|
|
|
|
d = _some_data_with_interactions()
|
|
|
|
|
|
in_layer = OneHotInputLayer(d)
|
|
|
|
|
|
|
|
|
|
|
|
d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
|
|
|
|
|
|
seq_1 = torch.nn.Sequential(in_layer, d_layer)
|
|
|
|
|
|
out_repr_1 = seq_1(None)
|
|
|
|
|
|
|
|
|
|
|
|
conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
|
|
|
|
|
|
seq_2 = torch.nn.Sequential(in_layer, conv_layer)
|
|
|
|
|
|
out_repr_2 = seq_2(None)
|