|
@@ -0,0 +1,146 @@ |
|
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
|
|
|
|
from triacontagon.data import Data
|
|
|
|
|
|
from triacontagon.split import split_data
|
|
|
|
|
|
from triacontagon.model import Model
|
|
|
|
|
|
from triacontagon.loop import TrainLoop
|
|
|
|
|
|
from triacontagon.decode import dedicom_decoder
|
|
|
|
|
|
from triacontagon.util import common_one_hot_encoding
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
from bisect import bisect_left
|
|
|
|
|
|
import torch
|
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def index(a, x):
|
|
|
|
|
|
i = bisect_left(a, x)
|
|
|
|
|
|
if i != len(a) and a[i] == x:
|
|
|
|
|
|
return i
|
|
|
|
|
|
raise ValueError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_data(dev):
|
|
|
|
|
|
path = '/pstore/data/data_science/ref/decagon'
|
|
|
|
|
|
|
|
|
|
|
|
df_combo = pd.read_csv(os.path.join(path, 'bio-decagon-combo.csv'))
|
|
|
|
|
|
df_effcat = pd.read_csv(os.path.join(path, 'bio-decagon-effectcategories.csv'))
|
|
|
|
|
|
df_mono = pd.read_csv(os.path.join(path, 'bio-decagon-mono.csv'))
|
|
|
|
|
|
df_ppi = pd.read_csv(os.path.join(path, 'bio-decagon-ppi.csv'))
|
|
|
|
|
|
df_tgtall = pd.read_csv(os.path.join(path, 'bio-decagon-targets-all.csv'))
|
|
|
|
|
|
df_tgt = pd.read_csv(os.path.join(path, 'bio-decagon-targets.csv'))
|
|
|
|
|
|
|
|
|
|
|
|
lst = [ 'df_combo', 'df_effcat', 'df_mono', 'df_ppi', 'df_tgtall', 'df_tgt' ]
|
|
|
|
|
|
|
|
|
|
|
|
for nam in lst:
|
|
|
|
|
|
print(f'len({nam}): {len(locals()[nam])}')
|
|
|
|
|
|
print(f'{nam}.columns: {locals()[nam].columns}')
|
|
|
|
|
|
|
|
|
|
|
|
genes = set()
|
|
|
|
|
|
genes = genes.union(df_ppi['Gene 1']).union(df_ppi['Gene 2']) \
|
|
|
|
|
|
.union(df_tgtall['Gene']).union(df_tgt['Gene'])
|
|
|
|
|
|
genes = sorted(genes)
|
|
|
|
|
|
print('len(genes):', len(genes))
|
|
|
|
|
|
|
|
|
|
|
|
drugs = set()
|
|
|
|
|
|
drugs = drugs.union(df_combo['STITCH 1']).union(df_combo['STITCH 2']) \
|
|
|
|
|
|
.union(df_mono['STITCH']).union(df_tgtall['STITCH']).union(df_tgt['STITCH'])
|
|
|
|
|
|
drugs = sorted(drugs)
|
|
|
|
|
|
print('len(drugs):', len(drugs))
|
|
|
|
|
|
|
|
|
|
|
|
data = Data()
|
|
|
|
|
|
data.add_vertex_type('Gene', len(genes))
|
|
|
|
|
|
data.add_vertex_type('Drug', len(drugs))
|
|
|
|
|
|
|
|
|
|
|
|
print('Preparing PPI...')
|
|
|
|
|
|
print('Indexing rows...')
|
|
|
|
|
|
rows = [index(genes, g) for g in df_ppi['Gene 1']]
|
|
|
|
|
|
print('Indexing cols...')
|
|
|
|
|
|
cols = [index(genes, g) for g in df_ppi['Gene 2']]
|
|
|
|
|
|
indices = list(zip(rows, cols))
|
|
|
|
|
|
indices = torch.tensor(indices).transpose(0, 1)
|
|
|
|
|
|
values = torch.ones(len(rows))
|
|
|
|
|
|
print('indices.shape:', indices.shape, 'values.shape:', values.shape)
|
|
|
|
|
|
adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(genes),) * 2,
|
|
|
|
|
|
device=dev)
|
|
|
|
|
|
adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
|
|
|
|
|
|
print('adj_mat created')
|
|
|
|
|
|
data.add_edge_type('PPI', 0, 0, [ adj_mat ], dedicom_decoder)
|
|
|
|
|
|
print('OK')
|
|
|
|
|
|
|
|
|
|
|
|
print('Preparing Drug-Gene (Target) edges...')
|
|
|
|
|
|
rows = [index(drugs, d) for d in df_tgtall['STITCH']]
|
|
|
|
|
|
cols = [index(genes, g) for g in df_tgtall['Gene']]
|
|
|
|
|
|
indices = list(zip(rows, cols))
|
|
|
|
|
|
indices = torch.tensor(indices).transpose(0, 1)
|
|
|
|
|
|
values = torch.ones(len(rows))
|
|
|
|
|
|
adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(genes)),
|
|
|
|
|
|
device=dev)
|
|
|
|
|
|
data.add_edge_type('Drug-Gene', 1, 0, [ adj_mat ], dedicom_decoder)
|
|
|
|
|
|
data.add_edge_type('Gene-Drug', 0, 1, [ adj_mat.transpose(0, 1) ], dedicom_decoder)
|
|
|
|
|
|
print('OK')
|
|
|
|
|
|
|
|
|
|
|
|
print('Preparing Drug-Drug (Side Effect) edges...')
|
|
|
|
|
|
fam = data.add_relation_family('Drug-Drug (Side Effect)', 1, 1, True)
|
|
|
|
|
|
print('# of side effects:', len(df_combo), 'unique:', len(df_combo['Polypharmacy Side Effect'].unique()))
|
|
|
|
|
|
adjacency_matrices = []
|
|
|
|
|
|
side_effect_names = []
|
|
|
|
|
|
for eff, df in df_combo.groupby('Polypharmacy Side Effect'):
|
|
|
|
|
|
sys.stdout.write('.') # print(eff, '...')
|
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
|
rows = [index(drugs, d) for d in df['STITCH 1']]
|
|
|
|
|
|
cols = [index(drugs, d) for d in df['STITCH 2']]
|
|
|
|
|
|
indices = list(zip(rows, cols))
|
|
|
|
|
|
indices = torch.tensor(indices).transpose(0, 1)
|
|
|
|
|
|
values = torch.ones(len(rows))
|
|
|
|
|
|
adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(drugs)),
|
|
|
|
|
|
device=dev)
|
|
|
|
|
|
adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
|
|
|
|
|
|
adjacency_matrices.append(adj_mat)
|
|
|
|
|
|
side_effect_names.append(df['Polypharmacy Side Effect'])
|
|
|
|
|
|
fam.add_edge_type('Drug-Drug', 1, 1, adjacency_matrices, dedicom_decoder)
|
|
|
|
|
|
print()
|
|
|
|
|
|
print('OK')
|
|
|
|
|
|
|
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _wrap(obj, method_name):
|
|
|
|
|
|
orig_fn = getattr(obj, method_name)
|
|
|
|
|
|
def fn(*args, **kwargs):
|
|
|
|
|
|
print(f'{method_name}() :: ENTER')
|
|
|
|
|
|
res = orig_fn(*args, **kwargs)
|
|
|
|
|
|
print(f'{method_name}() :: EXIT')
|
|
|
|
|
|
return res
|
|
|
|
|
|
setattr(obj, method_name, fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
|
dev = torch.device('cuda:0')
|
|
|
|
|
|
data = load_data(dev)
|
|
|
|
|
|
|
|
|
|
|
|
train_data, val_data, test_data = split_data(data, (.8, .1, .1))
|
|
|
|
|
|
|
|
|
|
|
|
n = sum(vt.count for vt in data.vertex_types)
|
|
|
|
|
|
|
|
|
|
|
|
model = Model(data, [n, 32, 64], keep_prob=.9,
|
|
|
|
|
|
conv_activation=torch.sigmoid,
|
|
|
|
|
|
dec_activation=torch.sigmoid).to(dev)
|
|
|
|
|
|
|
|
|
|
|
|
initial_repr = common_one_hot_encoding([ vt.count \
|
|
|
|
|
|
for vt in data.vertex_types ], device=dev)
|
|
|
|
|
|
|
|
|
|
|
|
loop = TrainLoop(model, val_data, test_data,
|
|
|
|
|
|
initial_repr, max_epochs=50, batch_size=512,
|
|
|
|
|
|
loss=torch.nn.functional.binary_cross_entropy_with_logits,
|
|
|
|
|
|
lr=0.001)
|
|
|
|
|
|
|
|
|
|
|
|
loop.run()
|
|
|
|
|
|
|
|
|
|
|
|
with open('/pstore/data/data_science/year/2020/adaszews/models/triacontagon/basic_run.pth', 'wb') as f:
|
|
|
|
|
|
torch.save(model.state_dict(), f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
main()
|