From ebcd38c91020e26421b7586bea04622c695c2111 Mon Sep 17 00:00:00 2001
From: Stanislaw Adaszewski
Date: Tue, 18 Aug 2020 17:44:57 +0200
Subject: [PATCH] Add first triacontagon experiment.

---
Review notes: line breaks and indentation were restored from a
whitespace-collapsed copy of this patch, and the new script was revised
(shared _adjacency() helper, side_effect_names stores the group key,
explicit frame mapping instead of locals(), CPU fallback). The object ids
on the "From" and "index" lines are those of the original submission.
Context indentation in the util.py hunks assumes 4-space indents and the
second hunk carries only two trailing context lines; use
"git apply --ignore-whitespace" if the indentation does not match.

 .../triacontagon_run/triacontagon_run.py      | 172 ++++++++++++++++++
 src/triacontagon/util.py                      |   3 +-
 2 files changed, 174 insertions(+), 1 deletion(-)
 create mode 100644 experiments/triacontagon_run/triacontagon_run.py

diff --git a/experiments/triacontagon_run/triacontagon_run.py b/experiments/triacontagon_run/triacontagon_run.py
new file mode 100644
index 0000000..cfa81f0
--- /dev/null
+++ b/experiments/triacontagon_run/triacontagon_run.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env python3
+
+"""First triacontagon experiment, run on the Decagon dataset.
+
+Reads the bio-decagon CSV files, builds a ``Data`` graph of genes and
+drugs, runs a ``TrainLoop`` on a ``Model`` and saves the state dict.
+"""
+
+import functools
+import os
+import sys
+from bisect import bisect_left
+
+import pandas as pd
+import torch
+
+from triacontagon.data import Data
+from triacontagon.decode import dedicom_decoder
+from triacontagon.loop import TrainLoop
+from triacontagon.model import Model
+from triacontagon.split import split_data
+from triacontagon.util import common_one_hot_encoding
+
+
+def index(a, x):
+    """Return the position of ``x`` in the sorted sequence ``a``.
+
+    Raises ``ValueError`` when ``x`` is not present.
+    """
+    i = bisect_left(a, x)
+    if i != len(a) and a[i] == x:
+        return i
+    raise ValueError(f'{x!r} not found')
+
+
+def _adjacency(rows, cols, shape, dev, symmetric=False):
+    """Return a sparse COO matrix of ``shape`` with ones at (rows, cols).
+
+    With ``symmetric=True`` the matrix is averaged with its transpose.
+    """
+    # Explicit integer dtype keeps the index tensor valid for zero edges.
+    indices = torch.tensor([rows, cols], dtype=torch.long)
+    values = torch.ones(len(rows))
+    adj_mat = torch.sparse_coo_tensor(indices, values, size=shape,
+        device=dev)
+    if symmetric:
+        adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
+    return adj_mat
+
+
+def load_data(dev):
+    """Load the Decagon CSV files and assemble the ``Data`` graph.
+
+    ``dev`` is the torch device on which adjacency matrices are created.
+    """
+    path = '/pstore/data/data_science/ref/decagon'
+
+    df_combo = pd.read_csv(os.path.join(path, 'bio-decagon-combo.csv'))
+    df_effcat = pd.read_csv(os.path.join(path, 'bio-decagon-effectcategories.csv'))
+    df_mono = pd.read_csv(os.path.join(path, 'bio-decagon-mono.csv'))
+    df_ppi = pd.read_csv(os.path.join(path, 'bio-decagon-ppi.csv'))
+    df_tgtall = pd.read_csv(os.path.join(path, 'bio-decagon-targets-all.csv'))
+    df_tgt = pd.read_csv(os.path.join(path, 'bio-decagon-targets.csv'))
+
+    # Explicit mapping instead of looking the names up in locals().
+    frames = {
+        'df_combo': df_combo, 'df_effcat': df_effcat, 'df_mono': df_mono,
+        'df_ppi': df_ppi, 'df_tgtall': df_tgtall, 'df_tgt': df_tgt,
+    }
+
+    for nam, frame in frames.items():
+        print(f'len({nam}): {len(frame)}')
+        print(f'{nam}.columns: {frame.columns}')
+
+    genes = sorted(set().union(df_ppi['Gene 1'], df_ppi['Gene 2'],
+        df_tgtall['Gene'], df_tgt['Gene']))
+    print('len(genes):', len(genes))
+
+    drugs = sorted(set().union(df_combo['STITCH 1'], df_combo['STITCH 2'],
+        df_mono['STITCH'], df_tgtall['STITCH'], df_tgt['STITCH']))
+    print('len(drugs):', len(drugs))
+
+    data = Data()
+    data.add_vertex_type('Gene', len(genes))
+    data.add_vertex_type('Drug', len(drugs))
+
+    print('Preparing PPI...')
+    print('Indexing rows...')
+    rows = [index(genes, g) for g in df_ppi['Gene 1']]
+    print('Indexing cols...')
+    cols = [index(genes, g) for g in df_ppi['Gene 2']]
+    adj_mat = _adjacency(rows, cols, (len(genes),) * 2, dev, symmetric=True)
+    print('adj_mat created')
+    data.add_edge_type('PPI', 0, 0, [ adj_mat ], dedicom_decoder)
+    print('OK')
+
+    print('Preparing Drug-Gene (Target) edges...')
+    rows = [index(drugs, d) for d in df_tgtall['STITCH']]
+    cols = [index(genes, g) for g in df_tgtall['Gene']]
+    adj_mat = _adjacency(rows, cols, (len(drugs), len(genes)), dev)
+    data.add_edge_type('Drug-Gene', 1, 0, [ adj_mat ], dedicom_decoder)
+    data.add_edge_type('Gene-Drug', 0, 1, [ adj_mat.transpose(0, 1) ], dedicom_decoder)
+    print('OK')
+
+    print('Preparing Drug-Drug (Side Effect) edges...')
+    fam = data.add_relation_family('Drug-Drug (Side Effect)', 1, 1, True)
+    print('# of side effects:', len(df_combo), 'unique:', len(df_combo['Polypharmacy Side Effect'].unique()))
+    adjacency_matrices = []
+    side_effect_names = []
+    for eff, df in df_combo.groupby('Polypharmacy Side Effect'):
+        sys.stdout.write('.')
+        sys.stdout.flush()
+        rows = [index(drugs, d) for d in df['STITCH 1']]
+        cols = [index(drugs, d) for d in df['STITCH 2']]
+        adj_mat = _adjacency(rows, cols, (len(drugs),) * 2, dev,
+            symmetric=True)
+        adjacency_matrices.append(adj_mat)
+        # Record the group key itself, not the group's whole column.
+        side_effect_names.append(eff)
+    fam.add_edge_type('Drug-Drug', 1, 1, adjacency_matrices, dedicom_decoder)
+    print()
+    print('OK')
+
+    return data
+
+
+def _wrap(obj, method_name):
+    """Replace ``obj.method_name`` with a version that logs enter/exit."""
+    orig_fn = getattr(obj, method_name)
+
+    @functools.wraps(orig_fn)
+    def fn(*args, **kwargs):
+        print(f'{method_name}() :: ENTER')
+        res = orig_fn(*args, **kwargs)
+        print(f'{method_name}() :: EXIT')
+        return res
+
+    setattr(obj, method_name, fn)
+
+
+def main():
+    """Build the dataset, run the training loop and save the model."""
+    # Use the first GPU when present; otherwise run on the CPU.
+    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    data = load_data(dev)
+
+    # NOTE(review): train_data is not used below and Model receives the
+    # unsplit data - confirm this is intended.
+    train_data, val_data, test_data = split_data(data, (.8, .1, .1))
+
+    n = sum(vt.count for vt in data.vertex_types)
+
+    model = Model(data, [n, 32, 64], keep_prob=.9,
+        conv_activation=torch.sigmoid,
+        dec_activation=torch.sigmoid).to(dev)
+
+    initial_repr = common_one_hot_encoding(
+        [vt.count for vt in data.vertex_types], device=dev)
+
+    loop = TrainLoop(model, val_data, test_data,
+        initial_repr, max_epochs=50, batch_size=512,
+        loss=torch.nn.functional.binary_cross_entropy_with_logits,
+        lr=0.001)
+
+    loop.run()
+
+    with open('/pstore/data/data_science/year/2020/adaszews/models/triacontagon/basic_run.pth', 'wb') as f:
+        torch.save(model.state_dict(), f)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/src/triacontagon/util.py b/src/triacontagon/util.py
index f268f44..70067f8 100644
--- a/src/triacontagon/util.py
+++ b/src/triacontagon/util.py
@@ -225,7 +225,7 @@ def _select_rows(a: torch.Tensor, rows: torch.Tensor):
     return res
 
 
-def common_one_hot_encoding(vertex_type_counts: List[int]) -> \
+def common_one_hot_encoding(vertex_type_counts: List[int], device=None) -> \
     List[torch.Tensor]:
 
     tot = sum(vertex_type_counts)
@@ -241,5 +241,6 @@ def common_one_hot_encoding(vertex_type_counts: List[int]) -> \
         ])
         val = torch.ones(cnt)
         x = _sparse_coo_tensor(ind, val, size=(cnt, tot))
+        x = x.to(device)
         res.append(x)
         ofs += cnt