#!/usr/bin/env python3

"""Build a triacontagon ``Data`` object from the bio-decagon CSV files and
train a model on it.

The script reads the six ``bio-decagon-*.csv`` tables, registers 'Gene' and
'Drug' vertex types plus PPI, Drug-Gene/Gene-Drug and Drug-Drug edge types,
runs a ``TrainLoop`` and stores the model's state dict on disk.
"""

import os
import sys
from bisect import bisect_left

import pandas as pd
import torch

from triacontagon.data import Data
from triacontagon.split import split_data
from triacontagon.model import Model
from triacontagon.loop import TrainLoop
from triacontagon.decode import dedicom_decoder
from triacontagon.util import common_one_hot_encoding


def index(a, x):
    """Return the position of ``x`` in the sorted sequence ``a``.

    Uses binary search, so ``a`` must already be sorted in ascending order.

    Raises:
        ValueError: if ``x`` does not occur in ``a``.
    """
    i = bisect_left(a, x)
    if i != len(a) and a[i] == x:
        return i
    raise ValueError(f'{x!r} not found in sorted sequence')


def _edge_tensors(rows, cols):
    """Return ``(indices, values)`` for a sparse COO matrix with unit entries.

    ``indices`` has shape ``(2, len(rows))`` (row indices first, column
    indices second) and ``values`` is a vector of ones of the same length.
    Building the index tensor from ``[rows, cols]`` directly also yields a
    well-formed ``(2, 0)`` tensor when there are no edges.
    """
    indices = torch.tensor([rows, cols], dtype=torch.long)
    values = torch.ones(len(rows))
    return indices, values


def load_data(dev, path='/pstore/data/data_science/ref/decagon'):
    """Load the bio-decagon tables and assemble them into a ``Data`` object.

    Args:
        dev: device on which the sparse adjacency matrices are created
            (passed as ``device=`` to ``torch.sparse_coo_tensor``).
        path: directory containing the ``bio-decagon-*.csv`` files.

    Returns:
        The populated ``Data`` instance.

    Raises:
        ValueError: propagated from ``index`` if an identifier in an edge
            table is missing from the sorted gene/drug vocabularies.
    """
    df_combo = pd.read_csv(os.path.join(path, 'bio-decagon-combo.csv'))
    df_effcat = pd.read_csv(os.path.join(path, 'bio-decagon-effectcategories.csv'))
    df_mono = pd.read_csv(os.path.join(path, 'bio-decagon-mono.csv'))
    df_ppi = pd.read_csv(os.path.join(path, 'bio-decagon-ppi.csv'))
    df_tgtall = pd.read_csv(os.path.join(path, 'bio-decagon-targets-all.csv'))
    df_tgt = pd.read_csv(os.path.join(path, 'bio-decagon-targets.csv'))

    # Explicit name -> frame mapping for the diagnostic printout below
    # (instead of looking the variables up through locals()).
    frames = {
        'df_combo': df_combo,
        'df_effcat': df_effcat,
        'df_mono': df_mono,
        'df_ppi': df_ppi,
        'df_tgtall': df_tgtall,
        'df_tgt': df_tgt,
    }
    for nam, frame in frames.items():
        print(f'len({nam}): {len(frame)}')
        print(f'{nam}.columns: {frame.columns}')

    # Sorted vocabularies; their order defines the vertex indices and lets
    # index() use binary search.
    genes = sorted(set().union(
        df_ppi['Gene 1'], df_ppi['Gene 2'],
        df_tgtall['Gene'], df_tgt['Gene']))
    print('len(genes):', len(genes))

    drugs = sorted(set().union(
        df_combo['STITCH 1'], df_combo['STITCH 2'],
        df_mono['STITCH'], df_tgtall['STITCH'], df_tgt['STITCH']))
    print('len(drugs):', len(drugs))

    data = Data()
    data.add_vertex_type('Gene', len(genes))
    data.add_vertex_type('Drug', len(drugs))

    # NOTE(review): the integer vertex-type arguments below (0 / 1) presumably
    # refer to 'Gene' and 'Drug' in registration order - confirm against Data.

    print('Preparing PPI...')
    print('Indexing rows...')
    rows = [index(genes, g) for g in df_ppi['Gene 1']]
    print('Indexing cols...')
    cols = [index(genes, g) for g in df_ppi['Gene 2']]
    indices, values = _edge_tensors(rows, cols)
    print('indices.shape:', indices.shape, 'values.shape:', values.shape)
    adj_mat = torch.sparse_coo_tensor(indices, values,
        size=(len(genes),) * 2, device=dev)
    # Symmetrize by averaging the matrix with its transpose.
    adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
    print('adj_mat created')
    data.add_edge_type('PPI', 0, 0, [adj_mat], dedicom_decoder)
    print('OK')

    print('Preparing Drug-Gene (Target) edges...')
    rows = [index(drugs, d) for d in df_tgtall['STITCH']]
    cols = [index(genes, g) for g in df_tgtall['Gene']]
    indices, values = _edge_tensors(rows, cols)
    adj_mat = torch.sparse_coo_tensor(indices, values,
        size=(len(drugs), len(genes)), device=dev)
    data.add_edge_type('Drug-Gene', 1, 0, [adj_mat], dedicom_decoder)
    # The reverse direction reuses the same matrix, transposed.
    data.add_edge_type('Gene-Drug', 0, 1, [adj_mat.transpose(0, 1)],
        dedicom_decoder)
    print('OK')

    print('Preparing Drug-Drug (Side Effect) edges...')
    fam = data.add_relation_family('Drug-Drug (Side Effect)', 1, 1, True)
    print('# of side effects:', len(df_combo),
        'unique:', len(df_combo['Polypharmacy Side Effect'].unique()))
    adjacency_matrices = []
    # Side-effect identifiers, kept parallel to adjacency_matrices.
    side_effect_names = []
    for eff, df in df_combo.groupby('Polypharmacy Side Effect'):
        # One progress dot per side effect.
        sys.stdout.write('.')
        sys.stdout.flush()
        rows = [index(drugs, d) for d in df['STITCH 1']]
        cols = [index(drugs, d) for d in df['STITCH 2']]
        indices, values = _edge_tensors(rows, cols)
        adj_mat = torch.sparse_coo_tensor(indices, values,
            size=(len(drugs), len(drugs)), device=dev)
        # Symmetrize by averaging the matrix with its transpose.
        adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
        adjacency_matrices.append(adj_mat)
        # Record the group key itself, not the whole (constant) column.
        side_effect_names.append(eff)
    fam.add_edge_type('Drug-Drug', 1, 1, adjacency_matrices, dedicom_decoder)
    print()
    print('OK')

    return data


def _wrap(obj, method_name):
    """Replace ``obj.<method_name>`` with a version that logs entry and exit.

    The replacement prints ``<method_name>() :: ENTER`` before delegating to
    the original attribute and ``<method_name>() :: EXIT`` after it returns,
    and passes the original return value through.
    """
    orig_fn = getattr(obj, method_name)

    def fn(*args, **kwargs):
        print(f'{method_name}() :: ENTER')
        res = orig_fn(*args, **kwargs)
        print(f'{method_name}() :: EXIT')
        return res

    setattr(obj, method_name, fn)


def main():
    """Load the data, train the model and save its state dict to disk."""
    dev = torch.device('cuda:0')
    data = load_data(dev)

    # NOTE(review): train_data is never used below - the model is built from
    # the full `data`. Confirm whether Model should receive train_data instead.
    train_data, val_data, test_data = split_data(data, (.8, .1, .1))

    # Input dimension of the first layer: total vertex count over all types.
    n = sum(vt.count for vt in data.vertex_types)

    model = Model(data, [n, 32, 64],
        keep_prob=.9,
        conv_activation=torch.sigmoid,
        dec_activation=torch.sigmoid).to(dev)

    initial_repr = common_one_hot_encoding(
        [vt.count for vt in data.vertex_types], device=dev)

    loop = TrainLoop(model, val_data, test_data,
        initial_repr, max_epochs=50, batch_size=512,
        loss=torch.nn.functional.binary_cross_entropy_with_logits,
        lr=0.001)

    loop.run()

    with open('/pstore/data/data_science/year/2020/adaszews/models/triacontagon/basic_run.pth', 'wb') as f:
        torch.save(model.state_dict(), f)


if __name__ == '__main__':
    main()