|
- #!/usr/bin/env python3
-
- from icosagon.data import Data
- from icosagon.trainprep import TrainValTest, \
- prepare_training
- from icosagon.model import Model
- from icosagon.trainloop import TrainLoop
- import os
- import pandas as pd
- from bisect import bisect_left
- import torch
- import sys
-
-
def index(a, x):
    """Return the position of ``x`` in the sorted sequence ``a``.

    The lookup is a binary search via ``bisect_left``, so ``a`` must already
    be sorted in ascending order for the result to be meaningful.

    Args:
        a: sorted sequence to search.
        x: value to locate.

    Returns:
        The index of the leftmost element of ``a`` equal to ``x``.

    Raises:
        ValueError: if ``x`` does not occur in ``a``.
    """
    i = bisect_left(a, x)
    if i < len(a) and a[i] == x:
        return i
    # Name the missing value so a failed lookup can be traced to its input.
    raise ValueError(f'{x!r} not found in sorted sequence')
-
-
def _edge_tensors(rows, cols):
    """Return the COO ``(indices, values)`` pair for unit-weight edges."""
    # torch.sparse_coo_tensor wants a (2, n) integer index tensor; forcing
    # int64 keeps the result valid even when there are no edges at all.
    indices = torch.tensor([rows, cols], dtype=torch.long)
    values = torch.ones(len(rows))
    return indices, values


def load_data(path='/pstore/data/data_science/ref/decagon'):
    """Read the bio-decagon CSV files and assemble them into a ``Data`` object.

    Six CSV files are read from ``path``. Genes and drugs are collected from
    the PPI, target, combo and mono tables, sorted, and mapped to integer
    positions with ``index``. Three relation families are then added:

    * ``'PPI'`` (Gene-Gene), from ``bio-decagon-ppi.csv``;
    * ``'Drug-Gene (Target)'``, from ``bio-decagon-targets-all.csv``;
    * ``'Drug-Drug (Side Effect)'``, one relation type per distinct value of
      the ``'Polypharmacy Side Effect'`` column in ``bio-decagon-combo.csv``.

    Progress is printed to stdout throughout.

    Args:
        path: directory containing the ``bio-decagon-*.csv`` files.

    Returns:
        The populated ``Data`` instance.
    """
    frames = {
        'df_combo': pd.read_csv(os.path.join(path, 'bio-decagon-combo.csv')),
        'df_effcat': pd.read_csv(os.path.join(path, 'bio-decagon-effectcategories.csv')),
        'df_mono': pd.read_csv(os.path.join(path, 'bio-decagon-mono.csv')),
        'df_ppi': pd.read_csv(os.path.join(path, 'bio-decagon-ppi.csv')),
        'df_tgtall': pd.read_csv(os.path.join(path, 'bio-decagon-targets-all.csv')),
        'df_tgt': pd.read_csv(os.path.join(path, 'bio-decagon-targets.csv')),
    }
    for nam, frame in frames.items():
        print(f'len({nam}): {len(frame)}')
        print(f'{nam}.columns: {frame.columns}')

    df_combo = frames['df_combo']
    df_mono = frames['df_mono']
    df_ppi = frames['df_ppi']
    df_tgtall = frames['df_tgtall']
    df_tgt = frames['df_tgt']

    # Sorted vocabularies: index() relies on binary search over these lists.
    genes = set()
    genes = genes.union(df_ppi['Gene 1']).union(df_ppi['Gene 2']) \
        .union(df_tgtall['Gene']).union(df_tgt['Gene'])
    genes = sorted(genes)
    print('len(genes):', len(genes))

    drugs = set()
    drugs = drugs.union(df_combo['STITCH 1']).union(df_combo['STITCH 2']) \
        .union(df_mono['STITCH']).union(df_tgtall['STITCH']).union(df_tgt['STITCH'])
    drugs = sorted(drugs)
    print('len(drugs):', len(drugs))

    # NOTE(review): the integer node-type arguments passed to
    # add_relation_family below presumably refer to the order of these
    # add_node_type calls (0 = Gene, 1 = Drug) -- confirm in icosagon.data.
    data = Data()
    data.add_node_type('Gene', len(genes))
    data.add_node_type('Drug', len(drugs))

    print('Preparing PPI...')
    print('Indexing rows...')
    rows = [index(genes, g) for g in df_ppi['Gene 1']]
    print('Indexing cols...')
    cols = [index(genes, g) for g in df_ppi['Gene 2']]
    indices, values = _edge_tensors(rows, cols)
    print('indices.shape:', indices.shape, 'values.shape:', values.shape)
    adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(genes),) * 2)
    # Average with the transpose so the adjacency matrix is symmetric.
    adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
    print('adj_mat created')
    fam = data.add_relation_family('PPI', 0, 0, True)
    fam.add_relation_type('PPI', adj_mat)
    print('OK')

    print('Preparing Drug-Gene (Target) edges...')
    # Only the "targets-all" table contributes edges; df_tgt is used solely
    # for the gene/drug vocabularies above.
    rows = [index(drugs, d) for d in df_tgtall['STITCH']]
    cols = [index(genes, g) for g in df_tgtall['Gene']]
    indices, values = _edge_tensors(rows, cols)
    adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(genes)))
    fam = data.add_relation_family('Drug-Gene (Target)', 1, 0, True)
    fam.add_relation_type('Drug-Gene (Target)', adj_mat)
    print('OK')

    print('Preparing Drug-Drug (Side Effect) edges...')
    fam = data.add_relation_family('Drug-Drug (Side Effect)', 1, 1, True)
    print('# of side effects:', len(df_combo), 'unique:', len(df_combo['Polypharmacy Side Effect'].unique()))
    for eff, df in df_combo.groupby('Polypharmacy Side Effect'):
        sys.stdout.write('.')  # one progress dot per side effect
        sys.stdout.flush()
        rows = [index(drugs, d) for d in df['STITCH 1']]
        cols = [index(drugs, d) for d in df['STITCH 2']]
        indices, values = _edge_tensors(rows, cols)
        adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(drugs)))
        adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
        # The relation is named by the scalar group key, matching the string
        # names used for the other relation types (not the column Series).
        fam.add_relation_type(eff, adj_mat)
    print()
    print('OK')

    return data
-
-
def _wrap(obj, method_name):
    """Replace ``obj.<method_name>`` with a version that logs entry and exit.

    The current attribute is looked up once, wrapped, and the wrapper is
    stored back on ``obj`` under the same name. Each call prints
    ``'<method_name>() :: ENTER'`` before delegating and
    ``'<method_name>() :: EXIT'`` after the original returns; if the original
    raises, the exception propagates and no EXIT line is printed.

    Args:
        obj: object (e.g. a class) whose attribute is patched in place.
        method_name: name of the callable attribute to wrap.
    """
    # Local import: functools is not among the module-level imports.
    import functools

    orig_fn = getattr(obj, method_name)

    # Preserve the original's __name__/__doc__ and expose it as __wrapped__.
    @functools.wraps(orig_fn)
    def fn(*args, **kwargs):
        print(f'{method_name}() :: ENTER')
        res = orig_fn(*args, **kwargs)
        print(f'{method_name}() :: EXIT')
        return res

    setattr(obj, method_name, fn)
-
-
def main():
    """Load the data, build the model and run a single training epoch.

    ``Model.build``, ``TrainLoop.build`` and ``TrainLoop.run_epoch`` are
    patched with ``_wrap`` so that their entry and exit are printed.
    """
    dataset = load_data()
    # NOTE(review): the three values are presumably the train/validation/test
    # ratios -- confirm against icosagon.trainprep.
    split = TrainValTest(.8, .1, .1)
    prepared = prepare_training(dataset, split)

    # Patch before instantiating, presumably so that a build() call made
    # during construction is traced as well.
    _wrap(Model, 'build')
    net = Model(prepared)

    for traced in ('build', 'run_epoch'):
        _wrap(TrainLoop, traced)
    trainer = TrainLoop(net, batch_size=1000000)
    trainer.run_epoch()


# Run the experiment only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|