From 8edf8ce4f99d0df92b919d19957d1758665cdeaa Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Tue, 28 Jul 2020 12:17:00 +0200 Subject: [PATCH] Add test_fast_model_01(). --- src/icosagon/fastmodel.py | 9 ++++-- tests/icosagon/test_fastmodel.py | 50 ++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 tests/icosagon/test_fastmodel.py diff --git a/src/icosagon/fastmodel.py b/src/icosagon/fastmodel.py index 7c30906..a68fe58 100644 --- a/src/icosagon/fastmodel.py +++ b/src/icosagon/fastmodel.py @@ -1,12 +1,16 @@ from .fastconv import FastConvLayer from .bulkdec import BulkDecodeLayer from .input import OneHotInputLayer +from .trainprep import PreparedData import torch import types +from typing import List, \ + Union, \ + Callable class FastModel(torch.nn.Module): - def __init(self, prep_d: PreparedData, + def __init__(self, prep_d: PreparedData, layer_dimensions: List[int] = [32, 64], keep_prob: float = 1., rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, @@ -20,6 +24,7 @@ class FastModel(torch.nn.Module): layer_activation, dec_activation) self.prep_d = prep_d + self.layer_dimensions = layer_dimensions self.keep_prob = float(keep_prob) self.rel_activation = rel_activation self.layer_activation = layer_activation @@ -55,7 +60,7 @@ class FastModel(torch.nn.Module): def forward(self, _): return self.seq(None) - def self._check_params(self, prep_d, layer_dimensions, rel_activation, + def _check_params(self, prep_d, layer_dimensions, rel_activation, layer_activation, dec_activation): if not isinstance(prep_d, PreparedData): diff --git a/tests/icosagon/test_fastmodel.py b/tests/icosagon/test_fastmodel.py new file mode 100644 index 0000000..6d5948a --- /dev/null +++ b/tests/icosagon/test_fastmodel.py @@ -0,0 +1,50 @@ +from icosagon.fastmodel import FastModel +from icosagon.data import Data +from icosagon.trainprep import prepare_training, \ + TrainValTest +import torch +import time + + +def _make_symmetric(x: torch.Tensor): + x = (x + x.transpose(0, 1)) / 2 + return x + + +def _symmetric_random(n_rows, n_columns): + return _make_symmetric(torch.rand((n_rows, n_columns), + dtype=torch.float32).round().to_sparse()) + + +def _some_data_with_interactions(): + d = Data() + d.add_node_type('Gene', 1000) + d.add_node_type('Drug', 100) + + fam = d.add_relation_family('Drug-Gene', 1, 0, True) + fam.add_relation_type('Target', + torch.rand((100, 1000), dtype=torch.float32).round().to_sparse()) + + fam = d.add_relation_family('Gene-Gene', 0, 0, True) + fam.add_relation_type('Interaction', + _symmetric_random(1000, 1000)) + + fam = d.add_relation_family('Drug-Drug', 1, 1, True) + for i in range(500): + fam.add_relation_type('Side Effect: Nausea %d' % i, + _symmetric_random(100, 100)) + fam.add_relation_type('Side Effect: Infertility %d' % i, + _symmetric_random(100, 100)) + fam.add_relation_type('Side Effect: Death %d' % i, + _symmetric_random(100, 100)) + return d + + +def test_fast_model_01(): + d = _some_data_with_interactions() + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) + model = FastModel(prep_d) + for i in range(10): + t = time.time() + _ = model(None) + print('Model forward took:', time.time() - t)