IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add test_fast_graph_conv_03() and test_fast_graph_conv_04().

master
Stanislaw Adaszewski 3 years ago
parent
commit
c6f4b8779d
2 changed files with 97 additions and 6 deletions
  1. +11
    -5
      src/icosagon/fastconv.py
  2. +86
    -1
      tests/icosagon/test_fastconv.py

+ 11
- 5
src/icosagon/fastconv.py View File

@@ -195,7 +195,7 @@ class FastConvLayer(torch.nn.Module):
self.adjacency_matrix.append(adj_mat)
self.adjacency_matrix_backward.append(adj_mat_back)
self.weight.append(weight)
self.weight_back.append(weight_back)
self.weight_backward.append(weight_back)
def forward(self, prev_layer_repr):
for i, fam in enumerate(self.data.relation_families):
@@ -210,16 +210,22 @@ class FastConvLayer(torch.nn.Module):
x = torch.sparse.mm(x, self.weight[i]) \
if x.is_sparse \
else torch.mm(x, self.weight[i])
x = torch.sparse.mm(adj_mat, repr_row) \
x = torch.sparse.mm(adj_mat, x) \
if adj_mat.is_sparse \
else torch.mm(adj_mat, repr_row)
else torch.mm(adj_mat, x)
x = self.rel_activation(x)
x = x.view(len(fam.relation_types), len(repr_row), -1)
if adj_mat_back is not None:
x = torch.sparse.mm(adj_mat_back, repr_row) \
x = dropout(repr_column, keep_prob=self.keep_prob)
x = torch.sparse.mm(x, self.weight_backward[i]) \
if x.is_sparse \
else torch.mm(x, self.weight_backward[i])
x = torch.sparse.mm(adj_mat_back, x) \
if adj_mat_back.is_sparse \
else torch.mm(adj_mat_back, repr_row)
else torch.mm(adj_mat_back, x)
x = self.rel_activation(x)
x = x.view(len(fam.relation_types), len(repr_row), -1)
@staticmethod
def _check_params(input_dim, output_dim, data, keep_prob,


+ 86
- 1
tests/icosagon/test_fastconv.py View File

@@ -1,10 +1,47 @@
from icosagon.fastconv import _sparse_diag_cat, \
_cat, \
FastGraphConv
FastGraphConv, \
FastConvLayer
from icosagon.data import _equal
import torch
import pdb
import time
from icosagon.data import Data
from icosagon.input import OneHotInputLayer
from icosagon.convlayer import DecagonLayer
def _make_symmetric(x: torch.Tensor):
x = (x + x.transpose(0, 1)) / 2
return x
def _symmetric_random(n_rows, n_columns):
return _make_symmetric(torch.rand((n_rows, n_columns),
dtype=torch.float32).round())
def _some_data_with_interactions():
d = Data()
d.add_node_type('Gene', 1000)
d.add_node_type('Drug', 100)
fam = d.add_relation_family('Drug-Gene', 1, 0, True)
fam.add_relation_type('Target',
torch.rand((100, 1000), dtype=torch.float32).round())
fam = d.add_relation_family('Gene-Gene', 0, 0, True)
fam.add_relation_type('Interaction',
_symmetric_random(1000, 1000))
fam = d.add_relation_family('Drug-Drug', 1, 1, True)
fam.add_relation_type('Side Effect: Nausea',
_symmetric_random(100, 100))
fam.add_relation_type('Side Effect: Infertility',
_symmetric_random(100, 100))
fam.add_relation_type('Side Effect: Death',
_symmetric_random(100, 100))
return d
def test_sparse_diag_cat_01():
@@ -86,3 +123,51 @@ def test_fast_graph_conv_02():
t = time.time()
_ = fgc(in_repr)
print('FGC forward pass took:', time.time() - t)
def test_fast_graph_conv_03():
adj_mat = [
[ 0, 0, 1, 0, 1 ],
[ 0, 1, 0, 1, 0 ],
[ 1, 0, 1, 0, 0 ]
]
in_repr = torch.rand(5, 32)
adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse() ])
out_repr = fgc(in_repr)
assert out_repr.shape == (1, 3, 64)
assert (torch.mm(adj_mat, torch.mm(in_repr, fgc.weights)).view(1, 3, 64) == out_repr).all()
def test_fast_graph_conv_04():
adj_mat = [
[ 0, 0, 1, 0, 1 ],
[ 0, 1, 0, 1, 0 ],
[ 1, 0, 1, 0, 0 ]
]
in_repr = torch.rand(5, 32)
adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse(), adj_mat.to_sparse() ])
out_repr = fgc(in_repr)
assert out_repr.shape == (2, 3, 64)
adj_mat_1 = torch.zeros(adj_mat.shape[0] * 2, adj_mat.shape[1] * 2)
adj_mat_1[0:3, 0:5] = adj_mat
adj_mat_1[3:6, 5:10] = adj_mat
res = torch.mm(in_repr, fgc.weights)
res = torch.split(res, res.shape[1] // 2, dim=1)
res = torch.cat(res)
res = torch.mm(adj_mat_1, res)
assert (res.view(2, 3, 64) == out_repr).all()
def test_fast_conv_layer_01():
d = _some_data_with_interactions()
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
seq_1 = torch.nn.Sequential(in_layer, d_layer)
out_repr_1 = seq_1(None)
conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
seq_2 = torch.nn.Sequential(in_layer, conv_layer)
out_repr_2 = seq_2(None)

Loading…
Cancel
Save