IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

147 lines
5.2KB

  1. #!/usr/bin/env python3
  2. from triacontagon.data import Data
  3. from triacontagon.split import split_data
  4. from triacontagon.model import Model
  5. from triacontagon.loop import TrainLoop
  6. from triacontagon.decode import dedicom_decoder
  7. from triacontagon.util import common_one_hot_encoding
  8. import os
  9. import pandas as pd
  10. from bisect import bisect_left
  11. import torch
  12. import sys
  13. def index(a, x):
  14. i = bisect_left(a, x)
  15. if i != len(a) and a[i] == x:
  16. return i
  17. raise ValueError
  18. def load_data(dev):
  19. path = '/pstore/data/data_science/ref/decagon'
  20. df_combo = pd.read_csv(os.path.join(path, 'bio-decagon-combo.csv'))
  21. df_effcat = pd.read_csv(os.path.join(path, 'bio-decagon-effectcategories.csv'))
  22. df_mono = pd.read_csv(os.path.join(path, 'bio-decagon-mono.csv'))
  23. df_ppi = pd.read_csv(os.path.join(path, 'bio-decagon-ppi.csv'))
  24. df_tgtall = pd.read_csv(os.path.join(path, 'bio-decagon-targets-all.csv'))
  25. df_tgt = pd.read_csv(os.path.join(path, 'bio-decagon-targets.csv'))
  26. lst = [ 'df_combo', 'df_effcat', 'df_mono', 'df_ppi', 'df_tgtall', 'df_tgt' ]
  27. for nam in lst:
  28. print(f'len({nam}): {len(locals()[nam])}')
  29. print(f'{nam}.columns: {locals()[nam].columns}')
  30. genes = set()
  31. genes = genes.union(df_ppi['Gene 1']).union(df_ppi['Gene 2']) \
  32. .union(df_tgtall['Gene']).union(df_tgt['Gene'])
  33. genes = sorted(genes)
  34. print('len(genes):', len(genes))
  35. drugs = set()
  36. drugs = drugs.union(df_combo['STITCH 1']).union(df_combo['STITCH 2']) \
  37. .union(df_mono['STITCH']).union(df_tgtall['STITCH']).union(df_tgt['STITCH'])
  38. drugs = sorted(drugs)
  39. print('len(drugs):', len(drugs))
  40. data = Data()
  41. data.add_vertex_type('Gene', len(genes))
  42. data.add_vertex_type('Drug', len(drugs))
  43. print('Preparing PPI...')
  44. print('Indexing rows...')
  45. rows = [index(genes, g) for g in df_ppi['Gene 1']]
  46. print('Indexing cols...')
  47. cols = [index(genes, g) for g in df_ppi['Gene 2']]
  48. indices = list(zip(rows, cols))
  49. indices = torch.tensor(indices).transpose(0, 1)
  50. values = torch.ones(len(rows))
  51. print('indices.shape:', indices.shape, 'values.shape:', values.shape)
  52. adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(genes),) * 2,
  53. device=dev)
  54. adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
  55. print('adj_mat created')
  56. data.add_edge_type('PPI', 0, 0, [ adj_mat ], dedicom_decoder)
  57. print('OK')
  58. print('Preparing Drug-Gene (Target) edges...')
  59. rows = [index(drugs, d) for d in df_tgtall['STITCH']]
  60. cols = [index(genes, g) for g in df_tgtall['Gene']]
  61. indices = list(zip(rows, cols))
  62. indices = torch.tensor(indices).transpose(0, 1)
  63. values = torch.ones(len(rows))
  64. adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(genes)),
  65. device=dev)
  66. data.add_edge_type('Drug-Gene', 1, 0, [ adj_mat ], dedicom_decoder)
  67. data.add_edge_type('Gene-Drug', 0, 1, [ adj_mat.transpose(0, 1) ], dedicom_decoder)
  68. print('OK')
  69. print('Preparing Drug-Drug (Side Effect) edges...')
  70. fam = data.add_relation_family('Drug-Drug (Side Effect)', 1, 1, True)
  71. print('# of side effects:', len(df_combo), 'unique:', len(df_combo['Polypharmacy Side Effect'].unique()))
  72. adjacency_matrices = []
  73. side_effect_names = []
  74. for eff, df in df_combo.groupby('Polypharmacy Side Effect'):
  75. sys.stdout.write('.') # print(eff, '...')
  76. sys.stdout.flush()
  77. rows = [index(drugs, d) for d in df['STITCH 1']]
  78. cols = [index(drugs, d) for d in df['STITCH 2']]
  79. indices = list(zip(rows, cols))
  80. indices = torch.tensor(indices).transpose(0, 1)
  81. values = torch.ones(len(rows))
  82. adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(drugs)),
  83. device=dev)
  84. adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
  85. adjacency_matrices.append(adj_mat)
  86. side_effect_names.append(df['Polypharmacy Side Effect'])
  87. fam.add_edge_type('Drug-Drug', 1, 1, adjacency_matrices, dedicom_decoder)
  88. print()
  89. print('OK')
  90. return data
  91. def _wrap(obj, method_name):
  92. orig_fn = getattr(obj, method_name)
  93. def fn(*args, **kwargs):
  94. print(f'{method_name}() :: ENTER')
  95. res = orig_fn(*args, **kwargs)
  96. print(f'{method_name}() :: EXIT')
  97. return res
  98. setattr(obj, method_name, fn)
  99. def main():
  100. dev = torch.device('cuda:0')
  101. data = load_data(dev)
  102. train_data, val_data, test_data = split_data(data, (.8, .1, .1))
  103. n = sum(vt.count for vt in data.vertex_types)
  104. model = Model(data, [n, 32, 64], keep_prob=.9,
  105. conv_activation=torch.sigmoid,
  106. dec_activation=torch.sigmoid).to(dev)
  107. initial_repr = common_one_hot_encoding([ vt.count \
  108. for vt in data.vertex_types ], device=dev)
  109. loop = TrainLoop(model, val_data, test_data,
  110. initial_repr, max_epochs=50, batch_size=512,
  111. loss=torch.nn.functional.binary_cross_entropy_with_logits,
  112. lr=0.001)
  113. loop.run()
  114. with open('/pstore/data/data_science/year/2020/adaszews/models/triacontagon/basic_run.pth', 'wb') as f:
  115. torch.save(model.state_dict(), f)
  116. if __name__ == '__main__':
  117. main()