├── .gitignore
├── LICENSE
├── README.md
├── images
│   └── pipeline_updated_kang2.png
├── load_embed.py
├── load_file.py
├── main.py
├── models
│   ├── drug_model.py
│   └── root
├── predict.sh
├── predict_example.sh
├── predict_zinc.sh
├── tasks
│   ├── .ipynb_checkpoints
│   │   └── plot-checkpoint.ipynb
│   ├── drug_run.py
│   ├── drug_task.py
│   ├── plot.ipynb
│   ├── plot.py
│   └── run_plot.py
├── test.sh
├── train.sh
├── train_ensemble.sh
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | tasks/data/
2 | results/
3 | __pycache__/
4 |
5 | *.pyc
6 | *.swp
7 | tags
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReSimNet
2 | A PyTorch implementation of the paper
3 | > ReSimNet: Drug Response Similarity Prediction using Siamese Neural Networks
4 | > Jeon and Park et al., 2018
5 |
6 | ## Abstract
7 | Traditional drug discovery approaches identify a target for a disease and find a compound that binds to the target. In this approach, the structures of compounds are considered the most important features because it is assumed that similar structures will bind to the same target. Therefore, structural analogs of the drugs that bind to the target are selected as drug candidates. However, even compounds that are not structural analogs may achieve the desired response. A new drug discovery method based on drug response, which can complement the structure-based methods, is needed.
8 |
9 | We implemented a Siamese neural network called ReSimNet that takes two chemical compounds as input and predicts the CMap score of the pair, which we use to measure the transcriptional response similarity of the two compounds. ReSimNet learns the embedding vector of a chemical compound in a transcriptional response space. ReSimNet is trained to minimize the difference between the cosine similarity of the embedding vectors of the two compounds and the CMap score of the two compounds. ReSimNet can find pairs of compounds that are similar in response even though they may have dissimilar structures. In our quantitative evaluation, ReSimNet outperformed the baseline machine learning models. The ReSimNet ensemble model achieves a Pearson correlation of 0.518 and a precision@1% of 0.989. In addition, in the qualitative analysis, we tested ReSimNet on the ZINC15 database and showed that ReSimNet successfully identifies chemical compounds that are relevant to a prototype drug whose mechanism of action is known.
10 |
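The training objective in a nutshell: embed both compounds with a shared encoder, take the cosine similarity of the two embeddings, and regress it onto the CMap score. Below is a minimal PyTorch sketch of that objective; the layer sizes follow the defaults in `main.py`, while the 2048-bit fingerprint input and the random toy batch are illustrative assumptions, not the actual data pipeline (see `models/drug_model.py` for the real model).

```python
import torch
import torch.nn.functional as F

# Toy batch: one fingerprint vector per compound of each pair (illustrative).
fp1 = torch.randn(32, 2048)
fp2 = torch.randn(32, 2048)
cmap_score = torch.rand(32)             # target CMap similarity per pair

encoder = torch.nn.Sequential(          # shared (Siamese) encoder
    torch.nn.Linear(2048, 512),         # --hidden-dim default
    torch.nn.ReLU(),
    torch.nn.Linear(512, 300),          # --drug-embed-dim default
)

embed1, embed2 = encoder(fp1), encoder(fp2)
pred = F.cosine_similarity(embed1, embed2, dim=-1)
loss = F.mse_loss(pred, cmap_score)     # minimize (cos(e1, e2) - CMap)^2
loss.backward()
```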
11 | ## Pipeline
12 | 
13 |
14 | ## Requirements
15 | - Install [cuda-8.0](https://developer.nvidia.com/cuda-downloads)
16 | - Install [cudnn-v5.1](https://developer.nvidia.com/cudnn)
17 | - Install [Pytorch 0.3.0](https://pytorch.org/)
18 | - Install [Numpy 1.61.1](https://pypi.org/project/numpy/)
19 | - Python version >= 3.4.3 is required
20 |
21 | ## Git Clone & Initial Setting
22 | Clone our source code and create the folders where the data and models will be saved.
23 |
24 | ```bash
25 | # clone the source code on your directory
26 | $ git clone https://github.com/dmis-lab/ReSimNet
27 | $ cd ReSimNet
28 |
29 | # make folder to save and load your data
30 | $ cd tasks
31 | $ mkdir -p data
32 |
33 | # make folder to save and load your model
34 | $ cd ..
35 | $ mkdir -p results
36 | ```
37 |
38 | ## Download Files You Need to Run ReSimNet
39 |
40 | ### Dataset for Training
41 | - **[ReSimNet-Dataset.pkl](https://drive.google.com/file/d/1iju2oBxnRW9jAnoeyOGDe9_hTBwh-vlT/view?usp=drive_link) (43MB)**
42 | *Save this file to ./ReSimNet/tasks/data/**ReSimNet-Dataset.pkl***
43 |
44 | ### Pre-Trained Models
45 | - **[ReSimNet-model-best.zip](https://drive.google.com/file/d/1hgEFKgrB1BeKRMxFYXmy9mlmYDpwAX7r/view?usp=drive_link) (12MB)**
46 | *Save this file to ./ReSimNet/results/**ReSimNet-model-best.zip** and unzip.*
47 |
48 | ### All 10 Models for Ensemble
49 | - **[ReSimNet-models-ensemble.zip](https://drive.google.com/file/d/1CapiepxBByB8koXWXqtL7e2_c-IHC4ya/view?usp=drive_link) (117MB)**
50 | *Save this file to ./ReSimNet/results/**ReSimNet-models-ensemble.zip** and unzip.*
51 |
52 | ### Example Input Pairs
53 | - **[examples.csv](https://drive.google.com/file/d/16Vdvt8LrGfuo7RJhaEJVzKo-PcnB5SNt/view?usp=drive_link) (244 bytes)**
54 | *Save this file to ./ReSimNet/tasks/data/pairs/**examples.csv***
55 |
56 | ### Fingerprint Representation
57 | - **[pertid2fingerprint.pkl](https://drive.google.com/file/d/1zK3693qPDxUZL7uRIutbA7KdmfkJtn29/view?usp=drive_link) (10MB)**
58 | *Save this file to ./ReSimNet/tasks/data/**pertid2fingerprint.pkl***
59 |
60 |
61 | ## Training ReSimNet
62 |
63 | ```bash
64 | # Train a new model.
65 | $ bash train.sh
66 |
67 | # Train new ensemble models.
68 | $ bash train_ensemble.sh
69 | ```
70 |
71 | ## CMap Score Prediction using ReSimNet
72 | For your own fingerprint pairs, ReSimNet provides a predicted CMap score for each pair. Running predict_example.sh scores the sample pairs with the pretrained ReSimNet models downloaded above and saves the predicted CMap scores to a result file.
73 | ```bash
74 | # Save scores of sample pair data
75 | $ bash predict_example.sh
76 | ```
77 | The input fingerprint pair file must be a .csv file in which every row has two columns, one for each compound of the pair. Place input files under './tasks/data/pairs/'.
78 | ```bash
79 | # Sample Fingerprints (./tasks/data/pairs/examples.csv)
80 | id1,id2
81 | BRD-K43164539,BRD-A45333398
82 | BRD-K83289131,BRD-K82484965
83 | BRD-K06817181,BRD-A41112154
84 | BRD-K06817181,BRD-K67977190
85 | BRD-K06817181,BRD-A87125127
86 | BRD-K68095457,BRD-K38903228
87 | BRD-K68095457,BRD-K01902415
88 | BRD-K68095457,BRD-K06817181
89 | ```
90 | Predicted CMap scores will be saved, one per row, to './results/input-pair-file.model-name.csv'.
91 | ```bash
92 | # Sample results (./results/examples.csv.ReSimNet7.csv)
93 | prediction
94 | 0.9146181344985962
95 | 0.9301251173019409
96 | 0.8519644737243652
97 | 0.9631381034851074
98 | 0.7272981405258179
99 | ```
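To sanity-check the output, you can line the scores back up with their input pairs, for example with pandas (pandas is an assumption here; any CSV reader works):

```python
import pandas as pd

pairs = pd.read_csv('./tasks/data/pairs/examples.csv')        # columns: id1, id2
preds = pd.read_csv('./results/examples.csv.ReSimNet7.csv')   # column: prediction
print(pd.concat([pairs, preds], axis=1).head())
```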
100 | ## CMap Score Prediction on ZINC using ReSimNet
101 | ```bash
102 | # Save scores of ZINC pair data
103 | $ bash predict_zinc.sh
104 | ```
105 | ### ZINC Files
106 | - **[zinc-test.zip](https://drive.google.com/file/d/1RT7oSvJtjlOsoFaA_ZpFuQHoPSA26CrL/view?usp=drive_link) (8KB)**
107 | *Save this file to ./ReSimNet/tasks/data/pairs_zinc/**zinc-test.zip** and unzip.*
108 |
109 | ```bash
110 | # Sample Zinc files (./tasks/data/pairs_zinc/zinc-test/AACA.csv)
111 | ,smiles,zinc_id,inchikey,mwt,logp,reactive,purchasable,tranche_name,features,fingerprint
112 | 17,CC1NNC(=S)NN1,ZINC000018204142,BYIXAEICDPEBOP-UHFFFAOYSA-N,132.192,-1.181,10,50,AACA,,00000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
113 | ```
114 |
115 | ### Example Drugs
116 | - **[example_drugs.csv](https://drive.google.com/file/d/1Rok_oU8mwZbFMgYFJTx4i4NRAdz8oIg7/view?usp=sharing) (7KB)**
117 | *Save this file to ./ReSimNet/tasks/data/pairs_zinc/**example_drugs.csv***
118 |
119 | ```bash
120 | # Sample example files (./tasks/data/pairs_zinc/example_drugs.csv)
121 | pair,fp
122 | ZINC18279871,00000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000100000000000000000000000000000000000000000000000000100000000000010010000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000010000000000000000000000000000000000000000000000000000000000000000000000100000000000100000000000000000000000000000000000010000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000001000000000000000000000000000000000000010000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000010000000001000000000000000000000000000000000000000000000001000000000000000000000010000000000000000000000000000000000000000000000000000000000000000000001000000000000000001000000000000000010000000000000000000000000000000000000000000000000000000000000000000000000100000000000001000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100001000000000000000000000000000001001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
123 | ZINC3938668,00000100000000000000000000000100000000000000000000000000000000000000000000100000100000001000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000100010000010000000000000000000000000000000000000000000000000000000000000100001001000000000000000000000000000101000010000000010000000000000000000000000000000001000000000000000000000000000000001000000000000100000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000010000000000000000000000001000100000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000101000000000100000000001000000000000000000000000000000000000010000010000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000100000000000000000100000000100000000000000010000100000000000000000100000000000000000000000000000100000000000000100000000100000000001000000000000000001001000000000000000000000000000100000001000000000000000001010000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000001000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000010000000000000000000000100000000000000000010100000000000000000000000000000000000000000000000000010001000000100000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000000000001000000001000000010000000010000000000000000000000000000000000000010000000000000000000000100001000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000000000000000000011000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000100000000000000010000000000000000000000000000000000010000000000000
124 | ```
125 |
126 | Predicted CMap scores will be saved, one per row, to './results/input-pair-file.model-name.csv'.
127 | ```bash
128 | # Sample results (./results/AACA.csv.ReSimNet7.csv)
129 | pair1,pair2,prediction
130 | ZINC000018204142,ZINC18279871,0.90729403
131 | ZINC000018204142,ZINC3938668,0.91043824
132 | ```
133 |
134 | ## License
135 | Apache License 2.0
136 |
--------------------------------------------------------------------------------
/images/pipeline_updated_kang2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dmis-lab/ReSimNet/bdb8786d514ac23221e0179a38bf9d6a999a354f/images/pipeline_updated_kang2.png
--------------------------------------------------------------------------------
/load_embed.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 |
4 |
5 | argparser = argparse.ArgumentParser()
6 |
7 | argparser.add_argument('--file', type=str, default='')
8 | args = argparser.parse_args()
9 | print(args)
10 |
11 |
12 | def main():
13 | dataset = pickle.load(open('./results/' + args.file, 'rb'))
14 | for key, value in dataset.items():
15 | print(key, value)
16 | break
17 |
18 |
19 | if __name__ == '__main__':
20 | main()
21 |
22 |
23 |
--------------------------------------------------------------------------------
/load_file.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | from tasks.drug_task import DrugDataset
4 |
5 | argparser = argparse.ArgumentParser()
6 |
7 | argparser.add_argument('--file-path', type=str,
8 | default='tasks/data/drug/drug(v0.5).pkl')
9 | argparser.add_argument('--save-path', type=str,
10 | default='results')
11 | args = argparser.parse_args()
12 |
13 | def main():
14 | pair = {}
15 |
16 | dataset_l = pickle.load(open(args.file_path, 'rb'))
17 | dataset = dataset_l.dataset
18 | k_set = dataset_l.known
19 | test_data = dataset['te']
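    # Tag each test pair by drug novelty w.r.t. the training (known) set:
    # KK = both drugs known, UU = both unknown, KU = exactly one known.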
20 | for idx, item in enumerate(test_data):
21 | d1 = item[0]
22 | d2 = item[1]
23 | ds = (d1, d2)
24 | if d1 in k_set and d2 in k_set:
25 | label = 'KK'
26 | elif d1 not in k_set and d2 not in k_set:
27 | label = 'UU'
28 | else:
29 | label = 'KU'
30 | pair[ds] = label
31 |
32 | pickle.dump(pair, open('{}/testset.pkl'.format(
33 | args.save_path), 'wb'))
34 |
35 | if __name__ == '__main__':
36 | main()
37 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | import pickle
5 | import random
6 | import argparse
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 |
13 | from datetime import datetime
14 | from functools import partial
15 | from torch.autograd import Variable
16 |
17 | from tasks.drug_task import DrugDataset
18 | from tasks.drug_run import *
19 | from models.drug_model import DrugModel
20 | from models.root.utils import *
21 |
22 |
23 | LOGGER = logging.getLogger()
24 |
25 | # DATA_PATH = './tasks/data/cell_lines(v0.6).pkl' # Cell line pairs
26 | DATA_PATH = './tasks/data/drug(v0.6).pkl' # For training (Pair scores)
27 | # DATA_PATH = './tasks/data/drug/drug(v0.1_graph).pkl'
28 | DRUG_DIR = './tasks/data/drug/validation/' # For validation (ex: tox21)
29 | #DRUG_FILES = ['BBBP_fingerprint_3.pkl',
30 | # 'clintox_fingerprint_3.pkl',
31 | # 'sider_fingerprint_3.pkl',
32 | # 'tox21_fingerprint_3.pkl',
33 | # 'toxcast_fingerprint_3.pkl',]
34 | DRUG_FILES = ['drug(v0.5).pkl']
35 | PAIR_DIR = './tasks/data/pairs/zinc/KKEB.csv' # New pair data for scoring
36 | FP_DIR = './tasks/data/fingerprint_v0.6_py2.pkl'
37 | EXAMPLE_DIR = "./tasks/data/pairs_zinc/example_drugs.csv"
38 | CKPT_DIR = './results/'
39 | MODEL_NAME = 'model.mdl'
40 |
41 |
42 | def str2bool(v):
43 |     return v.lower() in ('yes', 'true', 't', '1', 'y')
44 |
45 |
46 |
47 | # Run settings
48 | argparser = argparse.ArgumentParser()
49 | argparser.register('type', 'bool', str2bool)
50 |
51 | argparser.add_argument('--data-path', type=str, default=DATA_PATH,
52 | help='Dataset path')
53 | argparser.add_argument('--drug-dir', type=str, default=DRUG_DIR,
54 | help='Input drug dictionary')
55 | argparser.add_argument('--drug-files', type=str, default=DRUG_FILES,
56 | help='Input drug file')
57 | argparser.add_argument('--pair-dir', type=str, default=PAIR_DIR,
58 | help='Input new pairs')
59 | argparser.add_argument('--fp-dir', type=str, default=FP_DIR,
60 |                        help='Fingerprint dictionary path')
61 | argparser.add_argument('--example-dir', type=str, default=EXAMPLE_DIR,
62 |                        help='Example drug file for ZINC scoring')
63 | argparser.add_argument('--checkpoint-dir', type=str, default=CKPT_DIR,
64 | help='Directory for model checkpoint')
65 | argparser.add_argument('--model-name', type=str, default=MODEL_NAME,
66 | help='Model name for saving/loading')
67 | argparser.add_argument('--print-step', type=int, default=100,
68 |                        help='Display steps')
69 | argparser.add_argument('--validation-step', type=int, default=1,
70 |                        help='Number of random search validation runs')
71 | argparser.add_argument('--ensemble-step', type=int, default=10,
72 |                        help='Number of models in the ensemble')
73 | argparser.add_argument('--train', type='bool', default=True,
74 | help='Enable training')
75 | argparser.add_argument('--pretrain', type='bool', default=False,
76 |                        help='Enable pretraining')
77 | argparser.add_argument('--valid', type='bool', default=True,
78 | help='Enable validation')
79 | argparser.add_argument('--test', type='bool', default=True,
80 | help='Enable testing')
81 | argparser.add_argument('--resume', type='bool', default=False,
82 | help='Resume saved model')
83 | argparser.add_argument('--debug', type='bool', default=False,
84 | help='Run as debug mode')
85 | argparser.add_argument('--save-embed', type='bool', default=False,
86 | help='Save embeddings with loaded model')
87 | argparser.add_argument('--save-prediction', type='bool', default=False,
88 | help='Save predictions with loaded model')
89 | argparser.add_argument('--perform-ensemble', type='bool', default=False,
90 | help='perform-ensemble and save predictions with loaded model')
91 | argparser.add_argument('--save-pair-score', type='bool', default=False,
92 |                        help='Save pair scores with loaded model')
93 | argparser.add_argument('--save-pair-score-zinc', type='bool', default=False,
94 |                        help='Save pair scores for ZINC pairs')
95 | argparser.add_argument('--save-pair-score-ensemble', type='bool', default=False,
96 |                        help='Average pair scores over all 10 models')
97 | argparser.add_argument('--top-only', type='bool', default=False,
98 | help='Return top/bottom 10% results only')
99 | argparser.add_argument('--embed-d', type = int, default=1,
100 | help='0:val task data, 1:v0.n data')
101 |
102 | # Train config
103 | argparser.add_argument('--batch-size', type=int, default=32)
104 | argparser.add_argument('--epoch', type=int, default=40)
105 | argparser.add_argument('--learning-rate', type=float, default=0.005)
106 | argparser.add_argument('--weight-decay', type=float, default=0)
107 | argparser.add_argument('--grad-max-norm', type=int, default=10)
108 | argparser.add_argument('--grad-clip', type=int, default=10)
109 |
110 | # Model config
111 | argparser.add_argument('--binary', type='bool', default=False)
112 | argparser.add_argument('--hidden-dim', type=int, default=512)
113 | argparser.add_argument('--drug-embed-dim', type=int, default=300)
114 | argparser.add_argument('--lstm-layer', type=int, default=1)
115 | argparser.add_argument('--lstm-dr', type=float, default=0.0)
116 | argparser.add_argument('--char-dr', type=float, default=0.0)
117 | argparser.add_argument('--bi-lstm', type='bool', default=True)
118 | argparser.add_argument('--linear-dr', type=float, default=0.0)
119 | argparser.add_argument('--char-embed-dim', type=int, default=20)
120 | argparser.add_argument('--s-idx', type=int, default=0)
121 | argparser.add_argument('--rep-idx', type=int, default=2)
122 | argparser.add_argument('--dist-fn', type=str, default='cos')
123 | argparser.add_argument('--seed', type=int, default=None)
124 |
125 | #graph
126 | argparser.add_argument('--g_layer', type=int, default = 3)
127 | argparser.add_argument('--g_hidden_dim', type=int, default=512)
128 | argparser.add_argument('--g_out_dim', type=int, default=300)
129 | argparser.add_argument('--g_dropout', type=float, default=0.0)
130 |
131 | args = argparser.parse_args()
132 |
133 |
134 | def run_experiment(model, dataset, run_fn, args, cell_line):
135 | print("Current Model: ", args.model_name)
136 | # Get dataloaders
137 | if cell_line is None:
138 | train_loader, valid_loader, test_loader = dataset.get_dataloader(
139 | batch_size=args.batch_size, s_idx=args.s_idx)
140 | else:
141 | LOGGER.info('Training on {} cell line'.format(cell_line))
142 | train_loader, valid_loader, test_loader = dataset.get_cellloader(
143 | batch_size=args.batch_size, s_idx=args.s_idx, cell_line=cell_line)
144 |
145 | # Set metrics
146 | if args.binary:
147 | from sklearn.metrics import precision_recall_fscore_support
148 | metric = partial(precision_recall_fscore_support, average='binary')
149 | assert args.s_idx == 1
150 | else:
151 | metric = np.corrcoef
152 | assert args.s_idx == 0
153 |
154 | # Save embeddings and exit
155 | if args.save_embed:
156 | model.load_checkpoint(args.checkpoint_dir, args.model_name)
157 | # run_fn(model, test_loader, dataset, args, metric, train=False)
158 | if args.embed_d == 1:
159 | for drug_file in args.drug_files:
160 | drugs = pickle.load(open(args.drug_dir + drug_file, 'rb'))
161 | drugs = drugs.drugs
162 | save_embed(model, drugs, dataset, args, drug_file)
163 | else:
164 | for drug_file in args.drug_files:
165 | drugs = pickle.load(open(args.drug_dir + drug_file, 'rb'))
166 | save_embed(model, drugs, dataset, args, drug_file)
167 | sys.exit()
168 |
169 | # Save predictions on test dataset and exit
170 | if args.save_prediction:
171 | model.load_checkpoint(args.checkpoint_dir, args.model_name)
172 | # run_fn(model, test_loader, dataset, args, metric, train=False)
173 | save_prediction(model, test_loader, dataset, args)
174 | sys.exit()
175 |
176 | if args.perform_ensemble:
177 | model.load_checkpoint(args.checkpoint_dir, args.model_name)
178 | # run_fn(model, test_loader, dataset, args, metric, train=False)
179 | return perform_ensemble(model, test_loader, dataset, args)
180 |
181 |
182 | # Save pair predictions on pretrained model
183 | if args.save_pair_score:
184 | if args.save_pair_score_ensemble:
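            # Score the pairs with each of the 10 saved checkpoints
            # (<model-name>0.mdl ... <model-name>9.mdl).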
185 | models = [0,1,2,3,4,5,6,7,8,9]
186 | model_name = args.model_name.split(".")[0]
187 | for _model in models:
188 | print(model_name, _model)
189 | args.model_name = model_name+str(_model)+".mdl"
190 | print(args.model_name)
191 | model.load_checkpoint(args.checkpoint_dir, args.model_name)
192 | # run_fn(model, test_loader, dataset, args, metric, train=False)
193 | if args.save_pair_score_zinc:
194 | save_pair_score_for_zinc(model, args.pair_dir, args.example_dir, dataset, args)
195 | else:
196 | save_pair_score(model, args.pair_dir, args.fp_dir, dataset, args)
197 | sys.exit()
198 |
199 | else:
200 | model.load_checkpoint(args.checkpoint_dir, args.model_name)
201 | # run_fn(model, test_loader, dataset, args, metric, train=False)
202 | if args.save_pair_score_zinc:
203 | save_pair_score_for_zinc(model, args.pair_dir, args.example_dir, dataset, args)
204 | else:
205 | save_pair_score(model, args.pair_dir, args.fp_dir, dataset, args)
206 | sys.exit()
207 |
208 |
209 |
210 | # Save and load model during experiments
211 | if args.train:
212 | if args.resume:
213 | model.load_checkpoint(args.checkpoint_dir, args.model_name)
214 |
215 | best = 0.0
216 | converge_cnt = 0
217 | adaptive_cnt = 0
218 | #lr_decay = 0
219 |
220 | for ep in range(args.epoch):
221 | LOGGER.info('Training Epoch %d' % (ep+1))
222 | run_fn(model, train_loader, dataset, args, metric, train=True)
223 |
224 | if args.valid:
225 | LOGGER.info('Validation')
226 | curr = run_fn(model, valid_loader, dataset, args,
227 | metric, train=False)
228 | if not args.resume and curr > best:
229 | best = curr
230 | model.save_checkpoint({
231 | 'state_dict': model.state_dict(),
232 | 'optimizer': model.optimizer.state_dict()},
233 | args.checkpoint_dir, args.model_name)
234 | converge_cnt = 0
235 | #lr_dacay = 0
236 | else:
237 | converge_cnt += 1
238 | # lr_decay += 1
239 | '''
240 | if lr_decay >= 2:
241 | old_lr = args.learning_rate
242 | args.learning_rate = 1/2 * args.learning_rate
243 | print("lr_decay from %.5f to %.5f" % (old_lr, args.learning_rate))
244 | lr_decay = 0
245 | '''
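            # Halve the learning rate after 3 consecutive epochs without
            # validation improvement; stop early after more than 3 halvings.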
246 | if converge_cnt >= 3:
247 | for param_group in model.optimizer.param_groups:
248 | param_group['lr'] *= 0.5
249 | tmp_lr = param_group['lr']
250 | converge_cnt = 0
251 | adaptive_cnt += 1
252 | LOGGER.info('Adaptive {}: learning rate {:.4f}'.format(
253 | adaptive_cnt, model.optimizer.param_groups[0]['lr']))
254 |
255 | if adaptive_cnt > 3:
256 | LOGGER.info('Early stopping applied')
257 | break
258 |
259 | if args.test:
260 | LOGGER.info('Performance Test on Valid & Test Set')
261 | if args.train or args.resume:
262 | model.load_checkpoint(args.checkpoint_dir, args.model_name)
263 | run_fn(model, valid_loader, dataset, args, metric, train=False)
264 | run_fn(model, test_loader, dataset, args, metric, train=False)
265 |
266 |
267 | def get_dataset(path):
268 | return pickle.load(open(path, 'rb'))
269 |
270 |
271 | def get_run_fn(args):
272 | if args.binary:
273 | return run_bi
274 | else:
275 | return run_reg
276 |
277 |
278 | def get_model(args, dataset):
279 | dataset.set_rep(args.rep_idx)
280 | if args.rep_idx == 4:
281 | model = DrugModel(input_dim=dataset.input_dim,
282 | output_dim=1,
283 | hidden_dim=args.hidden_dim,
284 | drug_embed_dim=args.drug_embed_dim,
285 | lstm_layer=args.lstm_layer,
286 | lstm_dropout=args.lstm_dr,
287 | bi_lstm=args.bi_lstm,
288 | linear_dropout=args.linear_dr,
289 | char_vocab_size=len(dataset.char2idx),
290 | char_embed_dim=args.char_embed_dim,
291 | char_dropout=args.char_dr,
292 | dist_fn=args.dist_fn,
293 | learning_rate=args.learning_rate,
294 | binary=args.binary,
295 | is_mlp=False,
296 | weight_decay=args.weight_decay,
297 | is_graph=True,
298 | g_layer=args.g_layer,
299 | g_hidden_dim=args.g_hidden_dim,
300 | g_out_dim=args.g_out_dim,
301 | g_dropout=args.g_dropout).cuda()
302 |
303 | else:
304 | model = DrugModel(input_dim=dataset.input_dim,
305 | output_dim=1,
306 | hidden_dim=args.hidden_dim,
307 | drug_embed_dim=args.drug_embed_dim,
308 | lstm_layer=args.lstm_layer,
309 | lstm_dropout=args.lstm_dr,
310 | bi_lstm=args.bi_lstm,
311 | linear_dropout=args.linear_dr,
312 | char_vocab_size=len(dataset.char2idx),
313 | char_embed_dim=args.char_embed_dim,
314 | char_dropout=args.char_dr,
315 | dist_fn=args.dist_fn,
316 | learning_rate=args.learning_rate,
317 | binary=args.binary,
318 | is_mlp=args.rep_idx > 1,
319 | weight_decay=args.weight_decay,
320 | is_graph=False,
321 | g_layer=None,
322 | g_hidden_dim=None,
323 | g_out_dim=None,
324 | g_dropout=None).cuda()
325 | return model
326 |
327 |
328 | def init_logging(args):
329 | LOGGER.setLevel(logging.INFO)
330 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]',
331 | '%m/%d/%Y %I:%M:%S %p')
332 | console = logging.StreamHandler()
333 | console.setFormatter(fmt)
334 | LOGGER.addHandler(console)
335 |
336 | # For logfile writing
337 | logfile = logging.FileHandler(
338 | args.checkpoint_dir + 'logs/' + args.model_name + '.txt', 'w')
339 | logfile.setFormatter(fmt)
340 | LOGGER.addHandler(logfile)
341 |
342 |
343 | def init_seed(seed=None):
344 |     if seed is None:
345 |         seed = int(round(datetime.now().timestamp() * 1000)) % 10000
346 |
347 | LOGGER.info("Using seed={}, pid={}".format(seed, os.getpid()))
348 | np.random.seed(seed)
349 | torch.manual_seed(seed)
350 | random.seed(seed)
351 |
352 |
353 | def init_parameters(args, model_name, model_idx, cell_line='Total'):
354 | args.model_name = '{}-{}-{}'.format(cell_line, model_name, model_idx)
355 | # args.learning_rate = np.random.uniform(1e-4, 2e-3)
356 | # args.batch_size = 2 ** np.random.randint(4, 7)
357 | # args.grad_max_norm = 5 * np.random.randint(1, 5)
358 | # args.hidden_dim = 64 * np.random.randint(1, 10)
359 | # args.drug_embed_dim = 50 * np.random.randint(1, 10)
360 |
361 |
362 | def main():
363 |
364 | # Initialize logging and prepare seed
365 | init_logging(args)
366 | LOGGER.info('COMMAND: {}'.format(' '.join(sys.argv)))
367 |
368 |     # Get dataset, run function, model
369 | dataset = get_dataset(args.data_path)
370 | run_fn = get_run_fn(args)
371 | cell_line = None
372 |
373 | if args.save_pair_score:
374 | LOGGER.info('save_pair_score step')
375 | init_seed(args.seed)
376 | # init_parameters(args, model_name, model_idx)
377 | # LOGGER.info(args)
378 |
379 | # Get model
380 | model = get_model(args, dataset)
381 |
382 | # Run experiment
383 | run_experiment(model, dataset, run_fn, args, cell_line)
384 |
385 | elif args.perform_ensemble:
386 | print("LET'S PERFORM ENSEMBLE!")
387 | ensemble_preds = []
388 | kk_ensemble_preds = []
389 | ku_ensemble_preds = []
390 | uu_ensemble_preds = []
391 |
392 | model_name = args.model_name.split(".")[0]
393 | for model_idx in range(args.ensemble_step):
394 | LOGGER.info('Ensemble step {}'.format(model_idx+1))
395 | init_seed(args.seed)
396 |
397 | model = get_model(args, dataset)
398 |             print(model_name, model_idx)
399 | args.model_name = model_name+str(model_idx)+".mdl"
400 | print(args.model_name)
401 | pred_set, tar_set, kk_pred_set, kk_tar_set, ku_pred_set, ku_tar_set, uu_pred_set, uu_tar_set = run_experiment(model, dataset, run_fn, args, cell_line)
402 |
403 | ensemble_preds.append(pred_set)
404 | kk_ensemble_preds.append(kk_pred_set)
405 | ku_ensemble_preds.append(ku_pred_set)
406 | uu_ensemble_preds.append(uu_pred_set)
407 |
408 | print(pred_set[:10])
409 | print(tar_set[:10])
410 |
411 |
412 | #ensemble average
413 | ensemble_pred = np.array(ensemble_preds).mean(axis=0)
414 | kk_ensemble_pred = np.array(kk_ensemble_preds).mean(axis=0)
415 | ku_ensemble_pred = np.array(ku_ensemble_preds).mean(axis=0)
416 | uu_ensemble_pred = np.array(uu_ensemble_preds).mean(axis=0)
417 |
418 | print(ensemble_pred[:10])
419 | print(tar_set[:10])
420 |
421 | print("\n\nEnsemble Results")
422 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(ensemble_pred, tar_set)
423 | print('[TOTAL\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
424 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
425 |
426 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(kk_ensemble_pred, kk_tar_set)
427 | print('[KK\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
428 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
429 |
430 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(ku_ensemble_pred, ku_tar_set)
431 | print('[KU\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
432 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
433 |
434 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(uu_ensemble_pred, uu_tar_set)
435 | print('[UU\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
436 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
437 |
438 | else:
439 | print("LET'S PERFORM VALIDATION!")
440 | # Random search validation
441 | for model_idx in range(args.validation_step):
442 | LOGGER.info('Validation step {}'.format(model_idx+1))
443 | init_seed(args.seed)
444 | # init_parameters(args, model_name, model_idx)
445 | # LOGGER.info(args)
446 |
447 | # Get model
448 | model = get_model(args, dataset)
449 |
450 | # Run experiment
451 | run_experiment(model, dataset, run_fn, args, cell_line)
452 |
453 | if __name__ == '__main__':
454 | main()
455 |
--------------------------------------------------------------------------------
/models/drug_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import collections
7 | import math
8 | import sys
9 | import logging
10 |
11 | from torch.autograd import Variable
12 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
13 | from torch.nn.parameter import Parameter
14 |
15 | LOGGER = logging.getLogger(__name__)
16 |
17 |
18 | class DrugModel(nn.Module):
19 | def __init__(self, input_dim, output_dim, hidden_dim, drug_embed_dim,
20 | lstm_layer, lstm_dropout, bi_lstm, linear_dropout, char_vocab_size,
21 | char_embed_dim, char_dropout, dist_fn, learning_rate,
22 | binary, is_mlp, weight_decay, is_graph, g_layer,
23 | g_hidden_dim, g_out_dim, g_dropout):
24 |
25 | super(DrugModel, self).__init__()
26 |
27 | # Save model configs
28 | self.drug_embed_dim = drug_embed_dim
29 | self.lstm_layer = lstm_layer
30 | self.char_dropout = char_dropout
31 | self.dist_fn = dist_fn
32 | self.binary = binary
33 | self.is_mlp = is_mlp
34 | self.is_graph = is_graph
35 | self.g_layer = g_layer
36 | self.g_dropout = g_dropout
37 |
38 | #For rep_idx 4
39 | if is_graph:
40 | self.feature_dim = 75
41 | self.g_hidden_dim = g_hidden_dim
42 | self.g_out_dim = g_out_dim
43 | self.weight1 = Parameter(torch.FloatTensor(
44 | self.feature_dim, self.g_hidden_dim))
45 | self.weight2 = Parameter(torch.FloatTensor(
46 | self.g_hidden_dim, self.g_hidden_dim))
47 | self.weight3 = Parameter(torch.FloatTensor(
48 | self.g_hidden_dim, self.g_hidden_dim))
49 | self.weight4 = Parameter(torch.FloatTensor(
50 | self.g_hidden_dim, self.g_out_dim))
51 | #bias : option
52 | self.bias1 = Parameter(torch.FloatTensor(self.g_hidden_dim))
53 | self.bias2 = Parameter(torch.FloatTensor(self.g_hidden_dim))
54 | self.bias3 = Parameter(torch.FloatTensor(self.g_hidden_dim))
55 | self.bias4 = Parameter(torch.FloatTensor(self.g_out_dim))
56 | self.init_graph()
57 |
58 | # For rep_idx 0, 1
59 | elif not is_mlp:
60 | self.char_embed = nn.Embedding(char_vocab_size, char_embed_dim,
61 | padding_idx=0)
62 | self.lstm = nn.LSTM(char_embed_dim, drug_embed_dim, lstm_layer,
63 | bidirectional=False,
64 | batch_first=True, dropout=lstm_dropout)
65 |         # For rep_idx 2, 3
66 | else:
67 | self.encoder = nn.Sequential(
68 | nn.Linear(input_dim, hidden_dim),
69 | #nn.Dropout(0.5),
70 | nn.ReLU(),
71 | # nn.Linear(hidden_dim, hidden_dim),
72 | # nn.ReLU(),
73 | nn.Linear(hidden_dim, drug_embed_dim),
74 | #nn.Dropout(0.2),
75 | )
76 | #self.init_layers()
77 |
78 | # Distance function
79 | self.dist_fc = nn.Linear(drug_embed_dim, 1)
80 |
81 | # Get params and register optimizer
82 | info, params = self.get_model_params()
83 | self.optimizer = optim.Adam(params, lr=learning_rate,
84 | weight_decay=weight_decay)
85 | # self.optimizer = optim.SGD(params, lr=learning_rate,
86 | # momentum=0.5)
87 | if binary:
88 |             # self.criterion = nn.BCELoss()
89 |             self.criterion = lambda x, y: torch.mean(-(y*torch.log(x) + (1-y)*torch.log(1-x)))  # negated, mean-reduced log-likelihood (equivalent to BCE)
90 | else:
91 | # self.criterion = nn.MSELoss(reduce=False)
92 | self.criterion = nn.MSELoss()
93 | LOGGER.info(info)
94 |
95 | def init_graph(self):
96 | stdv1 = 1. / math.sqrt(self.weight1.size(1))
97 | stdv2 = 1. / math.sqrt(self.weight2.size(1))
98 | stdv3 = 1. / math.sqrt(self.weight4.size(1))
99 |
100 | self.weight1.data.uniform_(-stdv1, stdv1)
101 | self.bias1.data.uniform_(-stdv1, stdv1)
102 | self.weight2.data.uniform_(-stdv2, stdv2)
103 | self.bias2.data.uniform_(-stdv2, stdv2)
104 | self.weight3.data.uniform_(-stdv2, stdv2)
105 | self.bias3.data.uniform_(-stdv2, stdv2)
106 | self.weight4.data.uniform_(-stdv3, stdv3)
107 | self.bias4.data.uniform_(-stdv3, stdv3)
108 |
109 | def init_lstm_h(self, batch_size):
110 | return (Variable(torch.zeros(
111 | self.lstm_layer*1, batch_size, self.drug_embed_dim)).cuda(),
112 | Variable(torch.zeros(
113 | self.lstm_layer*1, batch_size, self.drug_embed_dim)).cuda())
114 |
115 | def init_layers(self):
116 | nn.init.xavier_normal(self.encoder[0].weight.data)
117 | nn.init.xavier_normal(self.encoder[2].weight.data)
118 | # nn.init.xavier_normal(self.encoder[4].weight.data)
119 |
120 | # Set Siamese network as basic LSTM
121 | def siamese_sequence(self, inputs, length):
122 | # Character embedding
123 | c_embed = self.char_embed(inputs)
124 | # c_embed = F.dropout(c_embed, self.char_dropout)
125 | maxlen = inputs.size(1)
126 |
127 | if not self.training:
128 | # Sort c_embed
129 | _, sort_idx = torch.sort(length, dim=0, descending=True)
130 | _, unsort_idx = torch.sort(sort_idx, dim=0)
131 | maxlen = torch.max(length)
132 |
133 | # Pack padded sequence
134 | c_embed = c_embed.index_select(0, Variable(sort_idx).cuda())
135 | sorted_len = length.index_select(0, sort_idx).tolist()
136 | c_packed = pack_padded_sequence(c_embed, sorted_len, batch_first=True)
137 |
138 | else:
139 | c_packed = c_embed
140 |
141 | # Run LSTM
142 | init_lstm_h = self.init_lstm_h(inputs.size(0))
143 | lstm_out, states = self.lstm(c_packed, init_lstm_h)
144 |
145 | hidden = torch.transpose(states[0], 0, 1).contiguous().view(
146 | -1, 1 * self.drug_embed_dim)
147 | if not self.training:
148 | # Unsort hidden states
149 | outputs = hidden.index_select(0, Variable(unsort_idx).cuda())
150 | else:
151 | outputs = hidden
152 |
153 | return outputs
154 |
155 | def graph_conv(self, features, adjs):
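        # Dense GCN over molecular graphs, batched via bmm: each layer computes
        # ReLU(adjs @ (H @ W) + b) with dropout; after the fourth layer,
        # log-softmax and max-pooling over nodes yield a fixed-size graph embedding.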
156 | weight1 = self.weight1.unsqueeze(0).expand(
157 | features.size(0), self.weight1.size(0), self.weight1.size(1))
158 | support1 = torch.bmm(features, weight1)
159 | layer1 = torch.bmm(adjs, support1)
160 | layer1_out = F.dropout(F.relu(layer1 + self.bias1),
161 | self.g_dropout)
162 |
163 | weight2 = self.weight2.unsqueeze(0).expand(
164 | layer1_out.size(0), self.weight2.size(0), self.weight2.size(1))
165 | support2 = torch.bmm(layer1_out, weight2)
166 | layer2 = torch.bmm(adjs, support2)
167 | layer2_out = F.dropout(F.relu(layer2 + self.bias2),
168 | self.g_dropout)
169 |
170 | weight3 = self.weight3.unsqueeze(0).expand(
171 | layer2_out.size(0), self.weight3.size(0), self.weight3.size(1))
172 | support3 = torch.bmm(layer2_out, weight3)
173 | layer3 = torch.bmm(adjs, support3)
174 | layer3_out = F.dropout(F.relu(layer3 + self.bias3),
175 | self.g_dropout)
176 | weight4 = self.weight4.unsqueeze(0).expand(
177 | layer3_out.size(0), self.weight4.size(0), self.weight4.size(1))
178 | support4 = torch.bmm(layer3_out, weight4)
179 | layer4 = torch.bmm(adjs, support4)
180 | layer4_out = layer4 + self.bias4
181 |
182 |         graph_conv = F.log_softmax(layer4_out, dim=-1)
183 |
184 | #Choose pooling operation
185 | pool = nn.MaxPool1d(graph_conv.size(1))
186 | #pool = nn.AvgPool1d(graph_conv.size(1))
187 | graph_conv_embed = torch.squeeze(pool(torch.transpose(graph_conv,1,2)))
188 | return graph_conv_embed
189 |
190 |
191 | def siamese_basic(self, inputs):
192 | return self.encoder(inputs.float())
193 |
194 | def distance_layer(self, vec1, vec2, distance='cos'):
195 | if distance == 'cos':
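            # The 1e-16 offset guards against NaNs when an embedding is all zeros.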
196 | similarity = F.cosine_similarity(
197 | vec1 + 1e-16, vec2 + 1e-16, dim=-1)
198 | elif distance == 'l1':
199 | similarity = self.dist_fc(torch.abs(vec1 - vec2))
200 | similarity = similarity.squeeze(1)
201 | elif distance == 'l2':
202 | similarity = self.dist_fc(torch.abs(vec1 - vec2) ** 2)
203 | similarity = similarity.squeeze(1)
204 |
205 | if self.binary:
206 | similarity = F.sigmoid(similarity)
207 |
208 | return similarity
209 |
210 | def forward(self, key1, key1_len, key2, key2_len, key1_adj, key2_adj):
211 | if key1_adj is not None and key2_adj is not None:
212 | embed1 = self.graph_conv(key1, key1_adj)
213 | embed2 = self.graph_conv(key2, key2_adj)
214 |
215 | elif not self.is_mlp and not self.is_graph:
216 | embed1 = self.siamese_sequence(key1, key1_len)
217 | embed2 = self.siamese_sequence(key2, key2_len)
218 |
219 | else:
220 | embed1 = self.siamese_basic(key1)
221 | embed2 = self.siamese_basic(key2)
222 |
223 | similarity = self.distance_layer(embed1, embed2, self.dist_fn)
224 | return similarity, embed1, embed2
225 |
226 | def get_loss(self, outputs, targets):
227 | if not self.binary:
228 | loss = self.criterion(outputs, targets)
229 | # loss = torch.sum(loss * torch.abs(targets)) / loss.size(0)
230 | else:
231 | # loss = -1 * self.criterion(outputs, targets)
232 | # p_t = targets * outputs + (1 - targets) * (1 - outputs)
233 | # gamma = 2.
234 | # loss = torch.sum(((1 - p_t) ** gamma) * loss) / loss.size(0)
235 | loss = self.criterion(outputs, targets)
236 | return loss
237 |
238 | def get_model_params(self):
239 | params = []
240 | total_size = 0
241 |
242 | def multiply_iter(p_list):
243 | out = 1
244 | for p in p_list:
245 | out *= p
246 | return out
247 |
248 | for p in self.parameters():
249 | if p.requires_grad:
250 | params.append(p)
251 | total_size += multiply_iter(p.size())
252 |
253 | return '{}\nparam size: {:,}\n'.format(self, total_size), params
254 |
255 | def save_checkpoint(self, state, checkpoint_dir, filename):
256 | filename = checkpoint_dir + filename
257 | LOGGER.info('Save checkpoint %s' % filename)
258 | torch.save(state, filename)
259 |
260 | def load_checkpoint(self, checkpoint_dir, filename):
261 | filename = checkpoint_dir + filename
262 | LOGGER.info('Load checkpoint %s' % filename)
263 | checkpoint = torch.load(filename)
264 |
265 | self.load_state_dict(checkpoint['state_dict'])
266 | self.optimizer.load_state_dict(checkpoint['optimizer'])
267 |
--------------------------------------------------------------------------------
/models/root:
--------------------------------------------------------------------------------
1 | ../
--------------------------------------------------------------------------------
/predict.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Predict pair scores given two input drug_ids.
3 | # This script uses a single model (ReSimNet7); to average predictions
4 | # over all 10 models, set --save-pair-score-ensemble to true.
5 | python main.py --save-pair-score true --pair-dir './tasks/data/pairs/' --fp-dir './tasks/data/pertid2fingerprint.pkl' --data-path './tasks/data/ReSimNet-Dataset.pkl' --model-name 'ReSimNet7.mdl' --rep-idx 2
6 |
--------------------------------------------------------------------------------
/predict_example.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Predict pair scores given two input drug_ids.
3 | # Prediction scores are averaged over all 10 models.
4 | # If you do not want this, set --save-pair-score-ensemble to false.
5 | python main.py --save-pair-score true --save-pair-score-ensemble true --pair-dir './tasks/data/pairs/' --fp-dir './tasks/data/pertid2fingerprint.pkl' --data-path './tasks/data/ReSimNet-Dataset.pkl' --model-name 'ReSimNet.mdl' --rep-idx 2
6 |
--------------------------------------------------------------------------------
/predict_zinc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Predict pair scores for ZINC compound pairs.
3 | # This script uses a single model (ReSimNet7); to average predictions
4 | # over all 10 models, set --save-pair-score-ensemble to true.
5 | CUDA_VISIBLE_DEVICES=1 python main.py --save-pair-score true --save-pair-score-zinc true --pair-dir './tasks/data/pairs_zinc/zinc-test/' --example-dir './tasks/data/pairs_zinc/example_drugs.csv' --data-path './tasks/data/ReSimNet-Dataset.pkl' --model-name 'ReSimNet7.mdl' --rep-idx 2
6 |
--------------------------------------------------------------------------------
/tasks/drug_run.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pickle
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | import logging
9 | import csv
10 | import os
11 | import pandas as pd
12 |
13 | from scipy.stats import pearsonr
14 | from sklearn.metrics import precision_score, roc_auc_score
15 |
16 | from datetime import datetime
17 | from torch.autograd import Variable
18 | from models.root.utils import *
19 |
20 |
21 | LOGGER = logging.getLogger(__name__)
22 |
23 |
24 | def prob_to_class(prob):
25 | return np.array([float(p >= 0.5) for p in prob])
26 |
27 |
28 | def run_bi(model, loader, dataset, args, metric, train=False):
29 | total_step = 0.0
30 | stats = {'loss':[]}
31 | tar_set = []
32 | pred_set = []
33 | kk_tar_set = []
34 | kk_pred_set = []
35 | ku_tar_set = []
36 | ku_pred_set = []
37 | uu_tar_set = []
38 | uu_pred_set = []
39 | start_time = datetime.now()
40 |
41 | for d_idx, (d1, d1_r, d1_l, d2, d2_r, d2_l, score) in enumerate(loader):
42 |
43 | # Split for KK/KU/UU sets
44 | kk_idx = np.argwhere([a in dataset.known and b in dataset.known
45 | for a, b in zip(d1, d2)]).flatten()
46 | ku_idx = np.argwhere([(a in dataset.known) != (b in dataset.known)
47 | for a, b in zip(d1, d2)]).flatten()
48 | uu_idx = np.argwhere([a not in dataset.known and b not in dataset.known
49 | for a, b in zip(d1, d2)]).flatten()
50 | assert len(kk_idx) + len(ku_idx) + len(uu_idx) == len(d1)
51 |
52 | # Grad zero + mode change
53 | model.optimizer.zero_grad()
54 | if train: model.train(train)
55 | else: model.eval()
56 |
57 | # Get outputs
58 | outputs, embed1, embed2 = model(d1_r.cuda(), d1_l, d2_r.cuda(), d2_l,
59 | None, None)
60 | loss = model.get_loss(outputs, score.cuda())
61 | stats['loss'] += [loss.data[0]]
62 | total_step += 1.0
63 |
64 | # Metrics for binary classification
65 | tmp_tar = score.data.cpu().numpy()
66 | tmp_pred = outputs.data.cpu().numpy()
67 | # tmp_pred = np.array([float(p >= 0.5) for p in tmp_pred[:]])
68 | # print(tmp_tar[:5], tmp_pred[:5])
69 |
70 | # Accumulate for final evaluation
71 | tar_set += list(tmp_tar[:])
72 | pred_set += list(tmp_pred[:])
73 | kk_tar_set += list(tmp_tar[kk_idx])
74 | kk_pred_set += list(tmp_pred[kk_idx])
75 | ku_tar_set += list(tmp_tar[ku_idx])
76 | ku_pred_set += list(tmp_pred[ku_idx])
77 | uu_tar_set += list(tmp_tar[uu_idx])
78 | uu_pred_set += list(tmp_pred[uu_idx])
79 |
80 | # Calculate current f1 scores
81 | f1 = metric(list(tmp_tar[:]), list(prob_to_class(tmp_pred[:])))
82 | f1_kk = metric(list(tmp_tar[kk_idx]), list(prob_to_class(tmp_pred[kk_idx])))
83 | f1_ku = metric(list(tmp_tar[ku_idx]), list(prob_to_class(tmp_pred[ku_idx])))
84 | f1_uu = metric(list(tmp_tar[uu_idx]), list(prob_to_class(tmp_pred[uu_idx])))
85 |
86 | # For binary classification, report f1
87 | _, _, f1, _ = f1
88 | _, _, f1_kk, _ = f1_kk
89 | _, _, f1_ku, _ = f1_ku
90 | _, _, f1_uu, _ = f1_uu
91 |
92 | # Optimize model
93 | if train and not args.save_embed:
94 | loss.backward()
95 | nn.utils.clip_grad_norm(model.get_model_params()[1],
96 | args.grad_max_norm)
97 | model.optimizer.step()
98 |
99 | # Print for print step or at last
100 | if d_idx % args.print_step == 0 or d_idx == (len(loader) - 1):
101 | et = int((datetime.now() - start_time).total_seconds())
102 | _progress = (
103 | '{}/{} | Loss: {:.3f} | Total F1: {:.3f} | '.format(
104 | d_idx + 1, len(loader), loss.data[0], f1) +
105 | 'KK: {:.3f} KU: {:.3f} UU: {:.3f} | '.format(
106 | f1_kk, f1_ku, f1_uu) +
107 | '{:2d}:{:2d}:{:2d}'.format(
108 | et//3600, et%3600//60, et%60))
109 | LOGGER.debug(_progress)
110 |
111 | if args.top_only:
112 | # if False:
113 | tar_sets = [tar_set, kk_tar_set, ku_tar_set, uu_tar_set]
114 | pred_sets = [pred_set, kk_pred_set, ku_pred_set, uu_pred_set]
115 | messages = ['Total', 'KK', 'KU', 'UU']
116 | top_criterion = 0.10
117 | top_k = 100
118 |
119 | for tar, pred, msg in zip(tar_sets, pred_sets, messages):
120 | sorted_target = sorted(tar[:], reverse=True)
121 | # top_cut = sorted_target[int(len(sorted_target) * top_criterion)]
122 | top_cut = 0.9
123 |
124 | sorted_pred, my_target = (list(t) for t in zip(*sorted(
125 | zip(pred[:], tar[:]), reverse=True)))
126 | precision = sum(k >= top_cut for k in my_target[:top_k]) / top_k
127 | LOGGER.info('{} cut: {:.3f}, P@{}: {:.2f}, '.format(
128 | msg, top_cut, top_k, precision) +
129 | 'Pred Mean@100: {:.3f}, Tar Mean@100: {:.3f}'.format(
130 | sum(sorted_pred[:top_k])/top_k,
131 | sum(my_target[:top_k])/top_k))
132 |
133 | def sort_and_slice(list1, list2):
134 | list2, list1 = (list(t) for t in zip(*sorted(
135 | zip(list2, list1), reverse=True)))
136 | list1 = list1[:len(list1)//100] + list1[-len(list1)//100:]
137 | # list1 = list1[-len(list1)//100:]
138 | list2 = list2[:len(list2)//100] + list2[-len(list2)//100:]
139 | # list2 = list2[-len(list2)//100:]
140 | assert len(list1) == len(list2)
141 | return list1, list2
142 |
143 | if args.top_only:
144 | # if False:
145 | tar_set, pred_set = sort_and_slice(tar_set, pred_set)
146 | kk_tar_set, kk_pred_set = sort_and_slice(kk_tar_set, kk_pred_set)
147 | ku_tar_set, ku_pred_set = sort_and_slice(ku_tar_set, ku_pred_set)
148 | uu_tar_set, uu_pred_set = sort_and_slice(uu_tar_set, uu_pred_set)
149 |
150 |     # Calculate accumulated f1 scores
151 | f1 = metric(tar_set, prob_to_class(pred_set))
152 | f1_kk = metric(kk_tar_set, prob_to_class(kk_pred_set))
153 | f1_ku = metric(ku_tar_set, prob_to_class(ku_pred_set))
154 | f1_uu = metric(uu_tar_set, prob_to_class(uu_pred_set))
155 | pr, rc, f1, _ = f1
156 | pr_kk, rc_kk, f1_kk, _ = f1_kk
157 | pr_ku, rc_ku, f1_ku, _ = f1_ku
158 | pr_uu, rc_uu, f1_uu, _ = f1_uu
159 |
160 | # TODO add spearman correlation
161 |
162 | # End of an epoch
163 | et = (datetime.now() - start_time).total_seconds()
164 | LOGGER.info('Results (Loss/F1/KK/KU/UU): {:.3f}\t'.format(
165 | sum(stats['loss'])/len(stats['loss'])) +
166 | '[{:.3f}\t{:.3f}\t{:.3f}]\t[{:.3f}\t{:.3f}\t{:.3f}]\t'.format(
167 | pr, rc, f1, pr_kk, rc_kk, f1_kk) +
168 | '[{:.3f}\t{:.3f}\t{:.3f}]\t[{:.3f}\t{:.3f}\t{:.3f}]\t'.format(
169 | pr_ku, rc_ku, f1_ku, pr_uu, rc_uu, f1_uu) +
170 | 'count: {}/{}/{}/{}'.format(
171 | len(pred_set), len(kk_pred_set), len(ku_pred_set), len(uu_pred_set)))
172 |
173 | return f1_ku
174 |
175 |
176 | def element(d):
177 | return [d[k] for k in range(0,len(d))]
178 |
179 |
180 | def run_reg(model, loader, dataset, args, metric, train=False):
181 | total_step = 0.0
182 | stats = {'loss':[]}
183 | tar_set = []
184 | pred_set = []
185 | kk_tar_set = []
186 | kk_pred_set = []
187 | ku_tar_set = []
188 | ku_pred_set = []
189 | uu_tar_set = []
190 | uu_pred_set = []
191 | start_time = datetime.now()
192 |
193 | for d_idx, d in enumerate(loader):
194 | if args.rep_idx == 4:
195 | d1, d1_r, d1_a, d1_l, d2, d2_r, d2_a, d2_l, score = element(d)
196 | else:
197 | d1, d1_r, d1_l, d2, d2_r, d2_l, score = element(d)
198 |
199 | # Split for KK/KU/UU sets
200 | kk_idx = np.argwhere([a in dataset.known and b in dataset.known
201 | for a, b in zip(d1, d2)]).flatten()
202 | ku_idx = np.argwhere([(a in dataset.known) != (b in dataset.known)
203 | for a, b in zip(d1, d2)]).flatten()
204 | uu_idx = np.argwhere([a not in dataset.known and b not in dataset.known
205 | for a, b in zip(d1, d2)]).flatten()
206 | assert len(kk_idx) + len(ku_idx) + len(uu_idx) == len(d1)
207 |
208 | # Grad zero + mode change
209 | model.optimizer.zero_grad()
210 | if train: model.train(train)
211 | else: model.eval()
212 |
213 | # Get outputs
214 | if args.rep_idx == 4:
215 | outputs, embed1, embed2 = model(d1_r.cuda(), d1_l,
216 |                                         d2_r.cuda(), d2_l,
217 | d1_a.cuda(), d2_a.cuda())
218 | else:
219 | outputs, embed1, embed2 = model(d1_r.cuda(), d1_l,
220 | d2_r.cuda(), d2_l,
221 | None, None)
222 | loss = model.get_loss(outputs, score.cuda())
223 | stats['loss'] += [loss.data[0]]
224 | total_step += 1.0
225 |
226 | # Metrics for regression
227 | tmp_tar = score.data.cpu().numpy()
228 | tmp_pred = outputs.data.cpu().numpy()
229 | # print(tmp_tar[:10])
230 |
231 | # Accumulate for final evaluation
232 | tar_set += list(tmp_tar[:])
233 | pred_set += list(tmp_pred[:])
234 | kk_tar_set += list(tmp_tar[kk_idx])
235 | kk_pred_set += list(tmp_pred[kk_idx])
236 | ku_tar_set += list(tmp_tar[ku_idx])
237 | ku_pred_set += list(tmp_pred[ku_idx])
238 | uu_tar_set += list(tmp_tar[uu_idx])
239 | uu_pred_set += list(tmp_pred[uu_idx])
240 |
241 | # Calculate current f1 scores
242 | f1 = metric(list(tmp_tar[:]), list(tmp_pred[:]))
243 | f1_kk = metric(list(tmp_tar[kk_idx]), list(tmp_pred[kk_idx]))
244 | f1_ku = metric(list(tmp_tar[ku_idx]), list(tmp_pred[ku_idx]))
245 | f1_uu = metric(list(tmp_tar[uu_idx]), list(tmp_pred[uu_idx]))
246 | f1 = f1[0][1]
247 | f1_kk = f1_kk[0][1]
248 | f1_ku = f1_ku[0][1]
249 | f1_uu = f1_uu[0][1]
250 |
251 | # Optimize model
252 | if train and not args.save_embed:
253 | loss.backward()
254 | nn.utils.clip_grad_norm(model.get_model_params()[1],
255 | args.grad_max_norm)
256 | model.optimizer.step()
257 |
258 | # Print for print step or at last
259 | if d_idx % args.print_step == 0 or d_idx == (len(loader) - 1):
260 | et = int((datetime.now() - start_time).total_seconds())
261 | _progress = (
262 | '{}/{} | Loss: {:.3f} | Total Corr: {:.3f} | '.format(
263 | d_idx + 1, len(loader), loss.data[0], f1) +
264 | 'KK: {:.3f} KU: {:.3f} UU: {:.3f} | '.format(
265 | f1_kk, f1_ku, f1_uu) +
266 | '{:2d}:{:2d}:{:2d}'.format(
267 | et//3600, et%3600//60, et%60))
268 | LOGGER.debug(_progress)
269 |
270 | # if args.top_only:
271 | # # if False:
272 | # tar_sets = [tar_set, kk_tar_set, ku_tar_set, uu_tar_set]
273 | # pred_sets = [pred_set, kk_pred_set, ku_pred_set, uu_pred_set]
274 | # messages = ['Total', 'KK', 'KU', 'UU']
275 | # top_criterion = 0.10
276 | # top_k = 100
277 | #
278 | # for tar, pred, msg in zip(tar_sets, pred_sets, messages):
279 | # sorted_target = sorted(tar[:], reverse=True)
280 | # # top_cut = sorted_target[int(len(sorted_target) * top_criterion)]
281 | # top_cut = 0.9
282 | #
283 | # sorted_pred, my_target = (list(t) for t in zip(*sorted(
284 | # zip(pred[:], tar[:]), reverse=True)))
285 | # precision = sum(k >= top_cut for k in my_target[:top_k]) / top_k
286 | # LOGGER.info('{} cut: {:.3f}, P@{}: {:.2f}, '.format(
287 | # msg, top_cut, top_k, precision) +
288 | # 'Pred Mean@100: {:.3f}, Tar Mean@100: {:.3f}'.format(
289 | # sum(sorted_pred[:top_k])/top_k,
290 | # sum(my_target[:top_k])/top_k))
291 | #
292 | # def sort_and_slice(list1, list2):
293 | # list2, list1 = (list(t) for t in zip(*sorted(
294 | # zip(list2, list1), reverse=True)))
295 | # list1 = list1[:len(list1)//100] + list1[-len(list1)//100:]
296 | # # list1 = list1[-len(list1)//100:]
297 | # list2 = list2[:len(list2)//100] + list2[-len(list2)//100:]
298 | # # list2 = list2[-len(list2)//100:]
299 | # assert len(list1) == len(list2)
300 | # return list1, list2
301 | #
302 | # if args.top_only:
303 | # # if False:
304 | # tar_set, pred_set = sort_and_slice(tar_set, pred_set)
305 | # kk_tar_set, kk_pred_set = sort_and_slice(kk_tar_set, kk_pred_set)
306 | # ku_tar_set, ku_pred_set = sort_and_slice(ku_tar_set, ku_pred_set)
307 | # uu_tar_set, uu_pred_set = sort_and_slice(uu_tar_set, uu_pred_set)
308 |
309 |     # Calculate accumulated f1 scores
310 | f1 = metric(tar_set, pred_set)
311 | f1_kk = metric(kk_tar_set, kk_pred_set)
312 | f1_ku = metric(ku_tar_set, ku_pred_set)
313 | f1_uu = metric(uu_tar_set, uu_pred_set)
314 |
315 |     # Turn into correlation
316 | f1 = f1[0][1]
317 | f1_kk = f1_kk[0][1]
318 | f1_ku = f1_ku[0][1]
319 | f1_uu = f1_uu[0][1]
320 |
321 | # End of an epoch
322 | et = (datetime.now() - start_time).total_seconds()
323 | LOGGER.info('Results (Loss/F1/KK/KU/UU): {:.4f}\t'.format(
324 | sum(stats['loss'])/len(stats['loss'])) +
325 | '[{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}] '.format(
326 | f1, f1_kk, f1_ku, f1_uu) +
327 | 'count: {}/{}/{}/{}'.format(
328 | len(pred_set), len(kk_pred_set), len(ku_pred_set), len(uu_pred_set)))
329 |
330 |
331 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(pred_set, tar_set)
332 | LOGGER.info('[TOTAL\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
333 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
334 |
335 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(kk_pred_set, kk_tar_set)
336 | LOGGER.info('[KK\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
337 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
338 |
339 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(ku_pred_set, ku_tar_set)
340 | LOGGER.info('[KU\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
341 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
342 |
343 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(uu_pred_set, uu_tar_set)
344 | LOGGER.info('[UU\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
345 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
346 |
347 | return f1_ku
348 |
349 | def precision_at_k(y_pred, y_true, k):
350 | list_of_tuple = [(x, y) for x, y in zip(y_pred, y_true)]
351 | sorted_list_of_tuple = sorted(list_of_tuple, key=lambda tup: tup[0], reverse=True)
352 | topk = sorted_list_of_tuple[:int(len(sorted_list_of_tuple) * k)]
353 | topk_true = [x[1] for x in topk]
354 | topk_pred = [x[0] for x in topk]
355 | #print(topk)
356 | #print(topk_true)
357 | #print(topk_pred)
358 | precisionk = precision_score([1 if x > 0.9 else 0 for x in topk_true],
359 | [1 if x > -1.0 else 0 for x in topk_pred], labels=[0,1], pos_label=1)
360 | # print([1 if x > 90.0 else 0 for x in topk_true])
361 | # print([1 if x > 90.0 else 0 for x in topk_pred])
362 | # print(precisionk)
363 | return precisionk
364 |
365 | def mse_at_k(y_pred, y_true, k):
366 | list_of_tuple = [(x, y) for x, y in zip(y_pred, y_true)]
367 | sorted_list_of_tuple = sorted(list_of_tuple, key=lambda tup: tup[0], reverse=True)
368 | topk = sorted_list_of_tuple[:int(len(sorted_list_of_tuple) * k)]
369 | topk_true = [x[1] for x in topk]
370 | topk_pred = [x[0] for x in topk]
371 |
372 | msek = np.square(np.subtract(topk_pred, topk_true)).mean()
373 | return msek
374 |
375 | def evaluation(y_pred, y_true):
376 | # print(y_pred)
377 | # print(y_true)
378 | # print(pearsonr(np.ravel(y_pred), y_true))
379 | corr = pearsonr(np.ravel(y_pred), y_true)[0]
380 | # mse = np.square(np.subtract(y_pred, y_true)).mean()
381 | msetotal = mse_at_k(y_pred, y_true, 1.0)
382 | mse1 = mse_at_k(y_pred, y_true, 0.01)
383 | mse2 = mse_at_k(y_pred, y_true, 0.02)
384 | mse5 = mse_at_k(y_pred, y_true, 0.05)
385 |
386 | auroc = float('nan')
387 | if len([x for x in y_true if x > 0.9]) > 0:
388 | auroc = roc_auc_score([1 if x > 0.9 else 0 for x in y_true], y_pred)
389 | precision1 = precision_at_k(y_pred, y_true, 0.01)
390 | precision2 = precision_at_k(y_pred, y_true, 0.02)
391 | precision5 = precision_at_k(y_pred, y_true, 0.05)
392 | #print(auroc, precision1, precision2, precision5)
393 | return (corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5)
394 |
395 |
396 | # Outputs response embeddings for a given dictionary
397 | def save_embed(model, dictionary, dataset, args, drug_file):
398 | model.eval()
399 | key2vec = {}
400 | known_cnt = 0
401 |
402 | # Iterate drug dictionary
403 | for idx, item in enumerate(dictionary.items()):
404 | drug, rep = [item[k] for k in range(0,len(item))]
405 | if args.embed_d == 1:
406 | d1_r = rep[args.rep_idx]
407 | d1_k = drug in dataset.known
408 | d1_l = len(d1_r)
409 | else:
410 | d1_r = rep[0]
411 | d1_k = rep[1]
412 | d1_l = len(d1_r)
413 |
414 | # For string data (smiles/inchikey)
415 | if args.rep_idx == 0 or args.rep_idx == 1:
416 | d1_r = list(map(lambda x: dataset.char2idx[x]
417 | if x in dataset.char2idx
418 | else dataset.char2idx[dataset.UNK], d1_r))
419 | d1_l = len(d1_r)
420 |
421 | # Real valued for mol2vec
422 | if args.rep_idx != 3:
423 | d1_r = Variable(torch.LongTensor(d1_r)).cuda()
424 | else:
425 | d1_r = Variable(torch.FloatTensor(d1_r)).cuda()
426 | d1_l = torch.LongTensor(np.array([d1_l]))
427 | d1_r = d1_r.unsqueeze(0)
428 | d1_l = d1_l.unsqueeze(0)
429 |
430 |         # Run model and save embed
431 | _, embed1, embed2 = model(d1_r, d1_l, d1_r, d1_l, None, None)
432 | assert embed1.data.tolist() == embed2.data.tolist()
433 | """
434 | known = False
435 | for pert_id, _ in dataset.drugs.items():
436 | if drug == pert_id:
437 | known = True
438 | known_cnt += 1
439 | break
440 | """
441 |         key2vec[drug] = [embed1.squeeze().data.tolist(), d1_k]
442 |         known_cnt += int(d1_k)  # count known drugs for the final log
443 | # Print progress
444 | if idx % args.print_step == 0 or idx == len(dictionary) - 1:
445 | _progress = '{}/{} saving drug embeddings..'.format(
446 | idx + 1, len(dictionary))
447 | LOGGER.info(_progress)
448 |
449 | # Save embed as pickle
450 | pickle.dump(key2vec, open('{}/embed/{}.{}.pkl'.format(
451 | args.checkpoint_dir, drug_file, args.model_name), 'wb'),
452 | protocol=2)
453 | LOGGER.info('{}/{} number of known drugs.'.format(known_cnt, len(key2vec)))
454 |
455 |
456 | # Outputs pred vs label scores given a dataloader
457 | def save_prediction(model, loader, dataset, args):
458 | model.eval()
459 | csv_writer = csv.writer(open(args.checkpoint_dir + 'pred_' +
460 | args.model_name + '.csv', 'w'))
461 | csv_writer.writerow(['pert1', 'pert1_known', 'pert2', 'pert2_known',
462 | 'prediction', 'target'])
463 |
464 | for d_idx, (d1, d1_r, d1_l, d2, d2_r, d2_l, score) in enumerate(loader):
465 |
466 | # Run model for getting predictions
467 | outputs, _, _ = model(d1_r.cuda(), d1_l, d2_r.cuda(), d2_l, None, None)
468 | predictions = outputs.data.cpu().numpy()
469 | targets = score.data.tolist()
470 |
471 | for a1, a2, a3, a4 in zip(d1, d2, predictions, targets):
472 | csv_writer.writerow([a1, a1 in dataset.known,
473 | a2, a2 in dataset.known, a3, a4])
474 |
475 | # Print progress
476 | if d_idx % args.print_step == 0 or d_idx == len(loader) - 1:
477 | _progress = '{}/{} saving drug predictions..'.format(
478 | d_idx + 1, len(loader))
479 | LOGGER.info(_progress)
480 |
481 | # Evaluates accumulated predictions vs labels given a dataloader (ensemble use)
482 | def perform_ensemble(model, loader, dataset, args):
483 | model.eval()
484 | tar_set = []
485 | pred_set = []
486 | kk_tar_set = []
487 | kk_pred_set = []
488 | ku_tar_set = []
489 | ku_pred_set = []
490 | uu_tar_set = []
491 | uu_pred_set = []
492 |
493 | for d_idx, (d1, d1_r, d1_l, d2, d2_r, d2_l, score) in enumerate(loader):
494 | # Run model for getting predictions
495 | outputs, _, _ = model(d1_r.cuda(), d1_l, d2_r.cuda(), d2_l, None, None)
496 |
497 | # Split for KK/KU/UU sets
498 | kk_idx = np.argwhere([a in dataset.known and b in dataset.known
499 | for a, b in zip(d1, d2)]).flatten()
500 | ku_idx = np.argwhere([(a in dataset.known) != (b in dataset.known)
501 | for a, b in zip(d1, d2)]).flatten()
502 | uu_idx = np.argwhere([a not in dataset.known and b not in dataset.known
503 | for a, b in zip(d1, d2)]).flatten()
504 | assert len(kk_idx) + len(ku_idx) + len(uu_idx) == len(d1)
505 |
506 | # Metrics for regression
507 | tmp_tar = score.data.cpu().numpy()
508 | tmp_pred = outputs.data.cpu().numpy()
509 |
510 | # Accumulate for final evaluation
511 | tar_set += list(tmp_tar[:])
512 | pred_set += list(tmp_pred[:])
513 | kk_tar_set += list(tmp_tar[kk_idx])
514 | kk_pred_set += list(tmp_pred[kk_idx])
515 | ku_tar_set += list(tmp_tar[ku_idx])
516 | ku_pred_set += list(tmp_pred[ku_idx])
517 | uu_tar_set += list(tmp_tar[uu_idx])
518 | uu_pred_set += list(tmp_pred[uu_idx])
519 |
520 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(pred_set, tar_set)
521 | print('[TOTAL\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
522 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
523 |
524 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(kk_pred_set, kk_tar_set)
525 | print('[KK\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
526 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
527 |
528 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(ku_pred_set, ku_tar_set)
529 | print('[KU\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
530 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
531 |
532 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5 = evaluation(uu_pred_set, uu_tar_set)
533 | print('[UU\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}] '.format(
534 | corr, msetotal, mse1, mse2, mse5, auroc, precision1, precision2, precision5))
535 |
536 | return pred_set, tar_set, kk_pred_set, kk_tar_set, ku_pred_set, ku_tar_set, uu_pred_set, uu_tar_set
537 |
538 | # Outputs pred scores for new pair dataset
539 | def save_pair_score(model, pair_dir, fp_dir, dataset, args):
540 | model.eval()
541 | drug2rep = pickle.load(open(fp_dir, 'rb'))
542 |
543 | folder_name = args.checkpoint_dir + 'save_pair_score/'
544 | if not os.path.exists(folder_name):
545 | os.makedirs(folder_name)
546 |
547 | for subdir, _, files in os.walk(pair_dir):
548 | for file_ in sorted(files):
549 |
550 | df = pd.read_csv(os.path.join(subdir, file_), sep=",")
551 | #print(df)
552 | LOGGER.info('save_pair_score processing {}...'.format(file_))
553 |
554 | csv_writer = csv.writer(open(folder_name + file_ + '_' +
555 | args.model_name + '.csv', 'w'))
556 | csv_writer.writerow(['drug1', 'drug2', 'prediction', 'jaccard'])
557 |
558 | batch = []
559 | for row_idx, row in df.iterrows():
560 | drug1 = row['id1']
561 | drug1_r = drug2rep[drug1][0]
562 | drug1_r = [float(value) for value in list(drug1_r)]
563 |
564 | drug2 = row['id2']
565 | drug2_r = drug2rep[drug2][0]
566 | drug2_r = [float(value) for value in list(drug2_r)]
567 |
568 | example = [drug1, drug1_r, len(drug1_r),
569 | drug2, drug2_r, len(drug2_r), 0]
570 | batch.append(example)
571 |
572 | if len(batch) == 1024:
573 | inputs = dataset.collate_fn(batch)
574 | outputs, _, _ = model(inputs[1].cuda(), inputs[2], inputs[4].cuda(), inputs[5], None, None)
575 | predictions = outputs.data.cpu().numpy()
576 |
577 |                     # hoisted out of the loop: import and helper defined once per flush
578 |                     from scipy.spatial import distance
579 |                     def jaccard(a, b):
580 |                         return 1 - distance.jaccard(a, b)
581 |                     for example, pred in zip(batch, predictions):
582 |                         jac = jaccard(example[1], example[4])
583 |                         csv_writer.writerow([example[0], example[3], pred, jac])
584 |                         print(example[0], example[3], pred, jac)
585 |
586 | batch = []
587 |
588 | # Print progress
589 | if row_idx % 5000 == 0 or row_idx == len(df) - 1:
590 |                 _progress = '{}/{} saving unknown predictions..'.format(
591 | row_idx + 1, len(df))
592 | LOGGER.info(_progress)
593 |
594 | if len(batch) > 0:
595 | inputs = dataset.collate_fn(batch)
596 | outputs, _, _ = model(inputs[1].cuda(), inputs[2], inputs[4].cuda(), inputs[5], None, None)
597 | predictions = outputs.data.cpu().numpy()
598 |
599 |                 from scipy.spatial import distance
600 |                 def jaccard(a, b):
601 |                     return 1 - distance.jaccard(a, b)
602 |                 for example, pred in zip(batch, predictions):
603 |                     jac = jaccard(example[1], example[4])
604 |                     csv_writer.writerow([example[0], example[3], pred, jac])
605 |
606 |
607 | def save_pair_score_for_zinc(model, pair_dir, example_dir, dataset, args):
608 | print("\n=============================================================")
609 | print("SAVE PAIR SCORE FOR ZINC")
610 | print("=============================================================")
611 |
612 | model.eval()
613 | df_example = pd.read_csv(example_dir, sep=",")
614 | print(df_example)
615 |
616 | folder_name = args.checkpoint_dir + 'save_pair_score_for_zinc/'
617 | if not os.path.exists(folder_name):
618 | os.makedirs(folder_name)
619 |
620 | for subdir, _, files in os.walk(pair_dir):
621 | for file_ in sorted(files):
622 |
623 | df_zinc = pd.read_csv(os.path.join(subdir, file_), sep=",")
624 | LOGGER.info('save_pair_score processing {}...'.format(file_))
625 | csv_writer = csv.writer(open(folder_name + file_ + '_' +
626 | args.model_name + '.csv', 'w'))
627 | csv_writer.writerow(['pair1', 'pair2', 'prediction'])
628 |
629 | batch = []
630 | for row_idx, row in df_zinc.iterrows():
631 | drug1 = row['zinc_id']
632 | drug1_r = row['fingerprint']
633 | drug1_r = [float(value) for value in list(drug1_r)]
634 |
635 |                 for _, example_row in df_example.iterrows():
636 |                     try:
637 |                         drug2 = example_row['pair']
638 |                         drug2_r = example_row['fp']
639 |                         drug2_r = [float(value) for value in list(drug2_r)]
640 | #print(drug1, drug1_r, len(drug1_r), drug2, drug2_r, len(drug2_r))
641 |
642 | example = [drug1, drug1_r, len(drug1_r),
643 | drug2, drug2_r, len(drug2_r), 0]
644 | batch.append(example)
645 | except KeyError:
646 | continue
647 |
648 | if len(batch) == 4096:
649 | inputs = dataset.collate_fn(batch)
650 | outputs, _, _ = model(inputs[1].cuda(), inputs[2], inputs[4].cuda(), inputs[5], None, None)
651 | predictions = outputs.data.cpu().numpy()
652 |
653 | for example, pred in zip(batch, predictions):
654 | if pred > 0.9:
655 | csv_writer.writerow([example[0], example[3], pred])
656 |
657 | batch = []
658 |
659 | # Print progress
660 | if row_idx % 1000 == 0 or row_idx == len(df_zinc) - 1:
661 | _progress = '{}/{} saving zinc predictions..'.format(
662 | row_idx + 1, len(df_zinc))
663 | LOGGER.info(_progress)
664 |
665 | if len(batch) > 0:
666 | inputs = dataset.collate_fn(batch)
667 | outputs, _, _ = model(inputs[1].cuda(), inputs[2], inputs[4].cuda(), inputs[5], None, None)
668 | predictions = outputs.data.cpu().numpy()
669 |
670 | for example, pred in zip(batch, predictions):
671 | if pred > 0.9:
672 | csv_writer.writerow([example[0], example[3], pred])
673 |
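674 |
675 | # A minimal sketch (illustrative, not used by this module) of how the
676 | # per-model prediction arrays collected by perform_ensemble() could be
677 | # averaged into the 10-model ensemble score mentioned in predict.sh;
678 | # `model_preds` is a hypothetical list of equally-sized arrays, one per
679 | # trained model:
680 | #
681 | #     def average_ensemble(model_preds):
682 | #         return np.mean(np.stack(model_preds, axis=0), axis=0)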
--------------------------------------------------------------------------------
/tasks/drug_task.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | import copy
4 | import pickle
5 | import string
6 | import os
7 | import random
8 | import csv
9 | import torch
10 | import scipy.sparse as sp
11 |
12 | from os.path import expanduser
13 | from torch.autograd import Variable
14 | from torch.utils.data import Dataset
15 | from torch.utils.data.sampler import Sampler
16 |
17 |
18 | class DrugDataset(object):
19 | def __init__(self, drug_id_path, drug_sub_path, drug_pair_path):
20 |
21 | self.initial_setting()
22 | # Build drug dictionary for id + sub
23 | self.drugs = self.process_drug_id(drug_id_path)
24 | self.append_drug_sub(drug_sub_path, self.drugs)
25 |
26 | # Save drug pair scores
27 | # self.pairs = self.process_drug_pair(drug_pair_path)
28 | self.cell_datasets = self.process_cell_lines(drug_pair_path)
29 | # self.dataset = self.split_dataset(self.pairs)
30 |
31 | def initial_setting(self):
32 | # Dataset split into train/valid/test
33 | self.drugs = {}
34 | self.pairs = []
35 | self.cell_lines = ['MCF7', 'PC3', 'HCC515', 'VCAP',
36 | 'A375', 'HA1E', 'A549', 'HEPG2',
37 | 'HT29', 'SUMMLY']
38 | self.dataset = {'tr': [], 'va': [], 'te': []}
39 | self.SR = [0.7, 0.1, 0.2] # split ratio
40 | self.UR = 0.1 # Unknown ratio
41 | self.input_maxlen = 0
42 |
43 | # Drug dictionaries
44 | self.known = {}
45 | self.unknown = {}
46 |
47 | # Character dictionaries (smiles/inchikey chars)
48 | self.schar2idx = {}
49 | self.idx2schar = {}
50 | self.ichar2idx = {}
51 | self.idx2ichar = {}
52 | self.schar_maxlen = 0
53 | self.ichar_maxlen = 0
54 | self.sub_lens = []
55 | self.PAD = 'PAD'
56 | self.UNK = 'UNK'
57 |
58 | def register_schar(self, char):
59 | if char not in self.schar2idx:
60 | self.schar2idx[char] = len(self.schar2idx)
61 | self.idx2schar[len(self.idx2schar)] = char
62 |
63 | def register_ichar(self, char):
64 | if char not in self.ichar2idx:
65 | self.ichar2idx[char] = len(self.ichar2idx)
66 | self.idx2ichar[len(self.idx2ichar)] = char
67 |
68 | def process_drug_id(self, path):
69 | print('### Drug ID processing {}'.format(path))
70 | PERT_IDX = 1
71 | SMILES_IDX = 4
72 | INCHIKEY_IDX = 5
73 | drugs = {}
74 | self.register_ichar(self.PAD)
75 | self.register_ichar(self.UNK)
76 | self.register_schar(self.PAD)
77 | self.register_schar(self.UNK)
78 |
79 | with open(path) as f:
80 | csv_reader = csv.reader(f)
81 | for row_idx, row in enumerate(csv_reader):
82 | if row_idx == 0:
83 | continue
84 |
85 | # Add to drug dictionary
86 | drug = row[PERT_IDX]
87 | smiles = row[SMILES_IDX]
88 | inchikey = row[INCHIKEY_IDX]
89 | drugs[drug] = [smiles, inchikey]
90 |
91 | # Update drug characters
92 | list(map(lambda x: self.register_schar(x), smiles))
93 | list(map(lambda x: self.register_ichar(x), inchikey))
94 |
95 | # Update max length
96 | self.schar_maxlen = self.schar_maxlen \
97 | if self.schar_maxlen > len(smiles) else len(smiles)
98 | self.ichar_maxlen = self.ichar_maxlen \
99 | if self.ichar_maxlen > len(inchikey) else len(inchikey)
100 |
101 | print('Drug dictionary size {}'.format(len(drugs)))
102 | print('Smiles char size {}'.format(len(self.schar2idx)))
103 | print('Smiles maxlen {}'.format(self.schar_maxlen))
104 | print('Inchikey char size {}'.format(len(self.ichar2idx)))
105 | print('Inchikey maxlen {}\n'.format(self.ichar_maxlen))
106 | return drugs
107 |
108 | def process_cell_lines(self, path):
109 | cell_pairs = pickle.load(open(path, 'rb'))
110 | new_datasets = {}
111 | for cell_line in self.cell_lines:
112 | '''
113 | print('stats of {}'.format(cell_line))
114 | print(len(cell_pairs[cell_line + '_tr']))
115 | print(len(cell_pairs[cell_line + '_va']))
116 | print(len(cell_pairs[cell_line + '_te']))
117 | print()
118 | '''
119 | cell_train = cell_pairs[cell_line + '_tr']
120 | cell_train = [[k[0][0], k[0][1], [k[1]]] for k in cell_train]
121 | cell_valid = cell_pairs[cell_line + '_va']
122 | cell_valid = [[k[0][0], k[0][1], [k[1]]] for k in cell_valid]
123 | cell_test = cell_pairs[cell_line + '_te']
124 | cell_test = [[k[0][0], k[0][1], [k[1]]] for k in cell_test]
125 | new_datasets[cell_line] = {'tr': cell_train,
126 | 'va': cell_valid,
127 | 'te': cell_test}
128 |
129 |
130 | for d1, d2, _ in cell_train:
131 | self.known[d1] = 0
132 | self.known[d2] = 0
133 |
134 | return new_datasets
135 |
136 | def append_drug_sub(self, paths, drugs):
137 | for path in paths:
138 | print('### Drug subID appending {}'.format(path))
139 | drug2rep = pickle.load(open(path, 'rb'))
140 | #Append drug sub id
141 | for drug, rep in drug2rep.items():
142 | if drug not in drugs:
143 | drugs[drug] = [rep]
144 | else:
145 | drugs[drug].append(rep)
146 | self.sub_lens.append(len(rep))
147 |
148 | print('Drug rep size {}\n'.format(self.sub_lens))
149 |
150 | def process_drug_pair(self, path):
151 | print('### Drug pair processing {}'.format(path))
152 | pair_scores = []
153 | REG_IDX = 5
154 | BI_IDX = 5
155 |
156 | with open(path) as f:
157 | csv_reader = csv.reader(f)
158 | for row_idx, row in enumerate(csv_reader):
159 | if row_idx == 0:
160 | print(row)
161 | print('REG: {}, BI: {}'.format(row[REG_IDX], row[BI_IDX]))
162 | continue
163 |
164 | # Save drugs, score (real-valued), target (binary)
165 | drug1 = row[1]
166 | drug2 = row[2]
167 | reg_score = float(row[REG_IDX])
168 | bi_score = float(row[BI_IDX])
169 | assert drug1 in self.drugs and drug2 in self.drugs
170 |
171 | # Save each drug and scores
172 | pair_scores.append([drug1, drug2, [reg_score, bi_score]])
173 |
174 | print('Dataset size {}\n'.format(len(pair_scores)))
175 | return pair_scores
176 |
177 | def split_dataset(self, pair_scores, unk_test=True):
178 | print('### Split dataset')
179 |
180 |         # Shuffle drugs dictionary and split
181 | items = list(self.drugs.items())
182 | random.shuffle(items)
183 | if unk_test:
184 | self.known = dict(items[:int(-len(items) * self.UR)])
185 | self.unknown = dict(items[int(-len(items) * self.UR):])
186 | else:
187 | self.known = dict(items[:])
188 | self.unknown = dict()
189 |
190 | # Unknown check
191 | for unk, _ in self.unknown.items():
192 | assert unk not in self.known
193 |
194 | # Shuffle dataset
195 | random.shuffle(pair_scores)
196 |
197 | # Ready for train/valid/test
198 | train = []
199 | valid = []
200 | test = []
201 | valid_kk = valid_ku = valid_uu = 0
202 | test_kk = test_ku = test_uu = 0
203 |
204 | # If either one is unknown, add to test or valid
205 | for drug1, drug2, scores in pair_scores:
206 | if drug1 in self.unknown or drug2 in self.unknown:
207 | is_test = np.random.binomial(1,
208 | self.SR[2]/(self.SR[1]+self.SR[2]))
209 |
210 | if is_test:
211 | test.append([drug1, drug2, scores])
212 | if drug1 in self.unknown and drug2 in self.unknown:
213 | test_uu += 1
214 | else:
215 | test_ku += 1
216 | else:
217 | valid.append([drug1, drug2, scores])
218 | if drug1 in self.unknown and drug2 in self.unknown:
219 | valid_uu += 1
220 | else:
221 | valid_ku += 1
222 |
223 | # Fill known/known set with limit of split ratio
224 | for drug1, drug2, scores in pair_scores:
225 | if drug1 not in self.unknown and drug2 not in self.unknown:
226 | assert drug1 in self.known and drug2 in self.known
227 |
228 | if len(train) < len(pair_scores) * self.SR[0]:
229 | train.append([drug1, drug2, scores])
230 | elif len(valid) < len(pair_scores) * self.SR[1]:
231 | valid.append([drug1, drug2, scores])
232 | valid_kk += 1
233 | else:
234 | test.append([drug1, drug2, scores])
235 | test_kk += 1
236 |
237 | print('Train/Valid/Test split: {}/{}/{}'.format(
238 | len(train), len(valid), len(test)))
239 | print('Valid/Test KK,KU,UU: ({},{},{})/({},{},{})\n'.format(
240 | valid_kk, valid_ku, valid_uu, test_kk, test_ku, test_uu))
241 |
242 | return {'tr': train, 'va': valid, 'te': test}
243 |
244 | def get_cellloader(self, batch_size=32, shuffle=True, num_workers=5, s_idx=0,
245 | cell_line='PC3'):
246 |
247 | train_dataset = Representation(self.cell_datasets[cell_line]['tr'],
248 | self.drugs,
249 | self._rep_idx, s_idx=s_idx)
250 |
251 | train_sampler = SortedBatchSampler(train_dataset.lengths(),
252 | batch_size,
253 | shuffle=True)
254 |
255 | train_loader = torch.utils.data.DataLoader(
256 | train_dataset,
257 | batch_size=batch_size,
258 | sampler=train_sampler,
259 | num_workers=num_workers,
260 | collate_fn=self.collate_fn,
261 | pin_memory=True,
262 | )
263 |
264 | valid_dataset = Representation(self.cell_datasets[cell_line]['va'],
265 | self.drugs,
266 | self._rep_idx, s_idx=s_idx)
267 | valid_sampler = SortedBatchSampler(valid_dataset.lengths(),
268 | batch_size,
269 | shuffle=False)
270 | valid_loader = torch.utils.data.DataLoader(
271 | valid_dataset,
272 | batch_size=batch_size,
273 | sampler=valid_sampler,
274 | num_workers=num_workers,
275 | collate_fn=self.collate_fn,
276 | pin_memory=True,
277 | shuffle=False,
278 | )
279 |
280 | test_dataset = Representation(self.cell_datasets[cell_line]['te'],
281 | self.drugs,
282 | self._rep_idx, s_idx=s_idx)
283 | test_sampler = SortedBatchSampler(test_dataset.lengths(),
284 | batch_size,
285 | shuffle=False)
286 | test_loader = torch.utils.data.DataLoader(
287 | test_dataset,
288 | batch_size=batch_size,
289 | sampler=test_sampler,
290 | num_workers=num_workers,
291 | collate_fn=self.collate_fn,
292 | pin_memory=True,
293 | shuffle=False,
294 | )
295 |
296 | return train_loader, valid_loader, test_loader
297 |
298 | def collate_fn(self, batch):
299 | drug1_raws = [ex[0] for ex in batch]
300 | drug1_lens = torch.LongTensor([ex[2] for ex in batch])
301 | drug2_raws = [ex[3] for ex in batch]
302 | drug2_lens = torch.LongTensor([ex[5] for ex in batch])
303 |
304 | drug1_maxlen = max([len(ex[1]) for ex in batch])
305 | drug1_reps = torch.FloatTensor(len(batch), drug1_maxlen).zero_()
306 | drug2_maxlen = max([len(ex[4]) for ex in batch])
307 | drug2_reps = torch.FloatTensor(len(batch), drug2_maxlen).zero_()
308 | scores = torch.FloatTensor(len(batch)).zero_()
309 |
310 | for idx, ex in enumerate(batch):
311 | drug1_rep = ex[1]
312 | if self._rep_idx < 2:
313 | drug1_rep = list(map(lambda x: self.char2idx[x]
314 | if x in self.char2idx
315 | else self.char2idx[self.UNK], ex[1]))
316 | drug1_rep = torch.FloatTensor(drug1_rep)
317 | drug1_reps[idx, :drug1_rep.size(0)].copy_(drug1_rep)
318 |
319 | drug2_rep = ex[4]
320 | if self._rep_idx < 2:
321 | drug2_rep = list(map(lambda x: self.char2idx[x]
322 | if x in self.char2idx
323 | else self.char2idx[self.UNK], ex[4]))
324 | drug2_rep = torch.FloatTensor(drug2_rep)
325 | drug2_reps[idx, :drug2_rep.size(0)].copy_(drug2_rep)
326 |
327 | scores[idx] = ex[6]
328 |
329 | # Set to LongTensor if not mol2vec
330 | if self._rep_idx != 3:
331 | drug1_reps = drug1_reps.long()
332 | drug2_reps = drug2_reps.long()
333 |
334 | # Set as Variables
335 | drug1_reps = Variable(drug1_reps)
336 | drug2_reps = Variable(drug2_reps)
337 | scores = Variable(scores)
338 |
339 | return (drug1_raws, drug1_reps, drug1_lens,
340 | drug2_raws, drug2_reps, drug2_lens, scores)
341 |
342 | def get_dataloader(self, batch_size=32, shuffle=True, num_workers=5, s_idx=0):
343 | if self._rep_idx == 4:
344 | train_dataset = Rep_graph(self.dataset['tr'], self.drugs,
345 | s_idx=s_idx)
346 |
347 | train_sampler = SortedBatchSampler(train_dataset.lengths(),
348 | batch_size, shuffle = True)
349 | train_loader = torch.utils.data.DataLoader(
350 | train_dataset,
351 | batch_size = batch_size,
352 | sampler = train_sampler,
353 | num_workers = num_workers,
354 | collate_fn = self.collate_fn_graph,
355 | pin_memory = True,
356 | )
357 |
358 | else:
359 | train_dataset = Representation(self.dataset['tr'], self.drugs,
360 | self._rep_idx, s_idx=s_idx)
361 |
362 | train_sampler = SortedBatchSampler(train_dataset.lengths(),
363 | batch_size,
364 | shuffle=True)
365 |
366 | train_loader = torch.utils.data.DataLoader(
367 | train_dataset,
368 | batch_size=batch_size,
369 | sampler=train_sampler,
370 | num_workers=num_workers,
371 | collate_fn=self.collate_fn,
372 | pin_memory=True,
373 | )
374 | if self._rep_idx == 4:
375 | valid_dataset = Rep_graph(self.dataset['va'], self.drugs,
376 | s_idx = s_idx)
377 |
378 | valid_sampler = SortedBatchSampler(valid_dataset.lengths(),
379 | batch_size,
380 | shuffle=False)
381 |
382 | valid_loader = torch.utils.data.DataLoader(
383 | valid_dataset,
384 | batch_size=batch_size,
385 | sampler=valid_sampler,
386 | num_workers=num_workers,
387 | collate_fn=self.collate_fn_graph,
388 | pin_memory=True,
389 | shuffle=False,
390 | )
391 |
392 |
393 |
394 | else:
395 | valid_dataset = Representation(self.dataset['va'], self.drugs,
396 | self._rep_idx, s_idx=s_idx)
397 | valid_sampler = SortedBatchSampler(valid_dataset.lengths(),
398 | batch_size,
399 | shuffle=False)
400 | valid_loader = torch.utils.data.DataLoader(
401 | valid_dataset,
402 | batch_size=batch_size,
403 | sampler=valid_sampler,
404 | num_workers=num_workers,
405 | collate_fn=self.collate_fn,
406 | pin_memory=True,
407 | shuffle=False,
408 | )
409 |
410 |         if self._rep_idx == 4:
411 | test_dataset = Rep_graph(self.dataset['te'], self.drugs,
412 | s_idx = s_idx)
413 |
414 | test_sampler = SortedBatchSampler(test_dataset.lengths(),
415 | batch_size,
416 | shuffle=False)
417 |
418 | test_loader = torch.utils.data.DataLoader(
419 | test_dataset,
420 | batch_size=batch_size,
421 | sampler=test_sampler,
422 | num_workers=num_workers,
423 | collate_fn=self.collate_fn_graph,
424 | pin_memory=True,
425 | shuffle=False,
426 | )
427 |
428 | else:
429 | test_dataset = Representation(self.dataset['te'], self.drugs,
430 | self._rep_idx, s_idx=s_idx)
431 | test_sampler = SortedBatchSampler(test_dataset.lengths(),
432 | batch_size,
433 | shuffle=False)
434 | test_loader = torch.utils.data.DataLoader(
435 | test_dataset,
436 | batch_size=batch_size,
437 |                 sampler=test_sampler,
438 | num_workers=num_workers,
439 | collate_fn=self.collate_fn,
440 | pin_memory=True,
441 | shuffle=False,
442 | )
443 |
444 |
445 | return train_loader, valid_loader, test_loader
446 |
447 | def collate_fn(self, batch):
448 | drug1_raws = [ex[0] for ex in batch]
449 | drug1_lens = torch.LongTensor([ex[2] for ex in batch])
450 | drug2_raws = [ex[3] for ex in batch]
451 | drug2_lens = torch.LongTensor([ex[5] for ex in batch])
452 |
453 | drug1_maxlen = max([len(ex[1]) for ex in batch])
454 | drug1_reps = torch.FloatTensor(len(batch), drug1_maxlen).zero_()
455 | drug2_maxlen = max([len(ex[4]) for ex in batch])
456 | drug2_reps = torch.FloatTensor(len(batch), drug2_maxlen).zero_()
457 | scores = torch.FloatTensor(len(batch)).zero_()
458 |
459 | for idx, ex in enumerate(batch):
460 | drug1_rep = ex[1]
461 | if self._rep_idx < 2:
462 | drug1_rep = list(map(lambda x: self.char2idx[x]
463 | if x in self.char2idx
464 | else self.char2idx[self.UNK], ex[1]))
465 | drug1_rep = torch.FloatTensor(drug1_rep)
466 | drug1_reps[idx, :drug1_rep.size(0)].copy_(drug1_rep)
467 |
468 | drug2_rep = ex[4]
469 | if self._rep_idx < 2:
470 | drug2_rep = list(map(lambda x: self.char2idx[x]
471 | if x in self.char2idx
472 | else self.char2idx[self.UNK], ex[4]))
473 | drug2_rep = torch.FloatTensor(drug2_rep)
474 | drug2_reps[idx, :drug2_rep.size(0)].copy_(drug2_rep)
475 |
476 | scores[idx] = ex[6]
477 |
478 | # Set to LongTensor if not mol2vec
479 | if self._rep_idx != 3:
480 | drug1_reps = drug1_reps.long()
481 | drug2_reps = drug2_reps.long()
482 |
483 | # Set as Variables
484 | drug1_reps = Variable(drug1_reps)
485 | drug2_reps = Variable(drug2_reps)
486 | scores = Variable(scores)
487 |
488 | return (drug1_raws, drug1_reps, drug1_lens,
489 | drug2_raws, drug2_reps, drug2_lens, scores)
490 |
491 | def normalize(self, mx):
492 | rowsum = np.sum(mx, axis=1).astype(float)
493 | rowinvs = []
494 | for idx, x in enumerate(rowsum):
495 | rowinv = 1/x if x != 0 else 0
496 | rowinvs.append(rowinv)
497 | r_mat_inv = np.diag(rowinvs)
498 | mx = r_mat_inv.dot(mx)
499 | return mx
500 |
501 | def collate_fn_graph(self, batch):
502 | drug1_raws = [ex[0] for ex in batch]
503 | drug1_lens = torch.LongTensor([ex[3] for ex in batch]) #num_node
504 | drug2_raws = [ex[4] for ex in batch]
505 | drug2_lens = torch.LongTensor([ex[7] for ex in batch])
506 |
507 | drug1_maxlen = max([len(ex[1]) for ex in batch])
508 | drug1_feature_len = max([len(ex[1][1]) for ex in batch])
509 | drug1_features = torch.FloatTensor(len(batch), drug1_maxlen, drug1_feature_len).zero_()
510 | drug1_adjs = torch.FloatTensor(len(batch), drug1_maxlen, drug1_maxlen).zero_()
511 |
512 | drug2_maxlen = max([len(ex[5]) for ex in batch])
513 | drug2_feature_len = max([len(ex[5][1]) for ex in batch])
514 | drug2_features = torch.FloatTensor(len(batch), drug2_maxlen, drug2_feature_len).zero_()
515 | drug2_adjs = torch.FloatTensor(len(batch), drug2_maxlen, drug2_maxlen).zero_()
516 | scores = torch.FloatTensor(len(batch)).zero_()
517 |
518 | for idx, ex in enumerate(batch):
519 | drug1_feature = np.array(ex[1])
520 | #drug1_feature = self.normalize(np.array(ex[1]))
521 | drug1_adj = ex[2]
522 | drug1_feature = torch.FloatTensor(drug1_feature)
523 | drug1_adj = np.array(drug1_adj)
524 | drug1_adj = drug1_adj + np.eye(len(drug1_adj))
525 | #drug1_adj = self.normalize(drug1_adj + np.eye(len(drug1_adj)))
526 | if len(drug1_adj) < drug1_maxlen:
527 | pad_length = drug1_maxlen - len(drug1_adj)
528 | pad = np.zeros((len(drug1_adj), pad_length))
529 | drug1_adj = np.concatenate((drug1_adj, pad), axis=1)
530 | drug1_adj = torch.FloatTensor(drug1_adj)
531 | drug1_features[idx, :drug1_feature.size(0)].copy_(drug1_feature)
532 | drug1_adjs[idx, :drug1_adj.size(0)].copy_(drug1_adj)
533 |
534 | #drug2_feature = self.normalize(np.array(ex[5]))
535 | drug2_feature = np.array(ex[5])
536 | drug2_adj = ex[6]
537 | drug2_feature = torch.FloatTensor(drug2_feature)
538 | drug2_adj = np.array(drug2_adj) + np.eye(len(drug2_adj))
539 | #drug2_adj = self.normalize(np.array(drug2_adj)+ np.eye(len(drug2_adj)))
540 |
541 | if len(drug2_adj) < drug2_maxlen:
542 | pad_length = drug2_maxlen - len(drug2_adj)
543 | pad = np.zeros((len(drug2_adj), pad_length))
544 | drug2_adj = np.concatenate((drug2_adj,pad), axis=1)
545 | drug2_adj = torch.FloatTensor(drug2_adj)
546 | drug2_features[idx, :drug2_feature.size(0)].copy_(drug2_feature)
547 | drug2_adjs[idx, :drug2_adj.size(0)].copy_(drug2_adj)
548 | scores[idx] = ex[8]
549 |
550 | drug1_features = Variable(drug1_features)
551 | drug1_adjs = Variable(drug1_adjs)
552 | drug2_features = Variable(drug2_features)
553 | drug2_adjs = Variable(drug2_adjs)
554 | scores = Variable(scores)
555 |
556 | return (drug1_raws, drug1_features, drug1_adjs, drug1_lens,
557 | drug2_raws, drug2_features, drug2_adjs, drug2_lens,
558 | scores)
559 |
560 |
561 | def decode_data(self, d1, d1_l, d2, d2_l, score):
562 | d1 = d1.data.tolist()
563 | d2 = d2.data.tolist()
564 | if self._rep_idx >= 2:
565 | print('Drug1: {}, length: {}'.format(d1, d1_l))
566 | print('Drug2: {}, length: {}'.format(d2, d2_l))
567 | else:
568 | print('Drug1: {}, length: {}'.format(''.join(list(map(
569 | lambda x: self.idx2char[x], d1[:d1_l]))), d1_l))
570 | print('Drug2: {}, length: {}'.format(''.join(list(map(
571 | lambda x: self.idx2char[x], d2[:d2_l]))), d2_l))
572 | # print('Drug1: {}'.format(d1))
573 | # print('Drug2: {}'.format(d2))
574 | print('Score: {}\n'.format(score.data[0]))
575 |
576 | def decode_data_graph(self, d1_f, d1_a, d1_l, d2_f, d2_a, d2_l, score):
577 | d1_a = d1_a[0:d1_l*d1_l]
578 | d2_a = d2_a[0:d2_l*d2_l]
579 |
580 | print('Drug1 : {} \n adj : {} \n num_node: {}'.format(d1_f, d1_a, d1_l))
581 | print('Drug2 : {} \n adj : {} \n num_node: {}'.format(d2_f, d2_a, d2_l))
582 | print('Score : {} \n'.format(score.data[0]))
583 |
584 | # rep_idx [0, 1, 2, 3]
585 | def set_rep(self, rep_idx):
586 | self._rep_idx = rep_idx
587 |
588 | @property
589 | def char2idx(self):
590 | if self._rep_idx == 0:
591 | return self.schar2idx
592 | elif self._rep_idx == 1:
593 | return self.ichar2idx
594 | else:
595 | return {}
596 |
597 | @property
598 | def idx2char(self):
599 | if self._rep_idx == 0:
600 | return self.idx2schar
601 | elif self._rep_idx == 1:
602 | return self.idx2ichar
603 | else:
604 | return {}
605 |
606 | @property
607 | def char_maxlen(self):
608 | if self._rep_idx == 0:
609 | return self.schar_maxlen
610 | elif self._rep_idx == 1:
611 | return self.ichar_maxlen
612 | else:
613 | return 0
614 |
615 | @property
616 | def input_dim(self):
617 | if self._rep_idx == 0:
618 | return len(self.idx2schar)
619 | elif self._rep_idx == 1:
620 | return len(self.idx2ichar)
621 | elif self._rep_idx == 2:
622 | return 2048
623 | elif self._rep_idx == 3:
624 | return 300
625 | elif self._rep_idx == 4:
626 | return 300
627 | else:
628 |             assert False, 'Wrong rep_idx {}'.format(self._rep_idx)
629 |
630 |
631 | class Representation(Dataset):
632 | def __init__(self, examples, drugs, rep_idx, s_idx):
633 | self.examples = examples
634 | self.drugs = drugs
635 | self.rep_idx = rep_idx
636 | self.s_idx = s_idx
637 |
638 | def __len__(self):
639 | return len(self.examples)
640 |
641 | def __getitem__(self, index):
642 | example = self.examples[index]
643 | next_idx = index
644 | while (self.drugs[example[0]][self.rep_idx] == 'None' or
645 | self.drugs[example[1]][self.rep_idx] == 'None'):
646 | next_idx = (next_idx + 1) % len(self.examples)
647 | example = self.examples[next_idx]
648 | drug1, drug2, scores = example
649 |
650 | # Choose drug representation
651 | drug1_rep = self.drugs[drug1][self.rep_idx]
652 | drug1_len = len(drug1_rep)
653 | drug2_rep = self.drugs[drug2][self.rep_idx]
654 | drug2_len = len(drug2_rep)
655 |
656 | # Inchi None check
657 | if self.rep_idx == 1:
658 | assert drug1_rep != 'None' and drug2_rep != 'None'
659 |
660 | # s_idx == 1 means binary classification
661 | score = scores[self.s_idx]
662 | if self.s_idx == 1:
663 | score = float(score >= 90)
664 | else:
665 | score = score / 100.
666 | return drug1, drug1_rep, drug1_len, drug2, drug2_rep, drug2_len, score
667 |
668 | def lengths(self):
669 | def get_longer_length(ex):
670 | drug1_len = len(self.drugs[ex[0]][self.rep_idx])
671 | drug2_len = len(self.drugs[ex[1]][self.rep_idx])
672 | length = drug1_len if drug1_len > drug2_len else drug2_len
673 | return [length, drug1_len, drug2_len]
674 | return [get_longer_length(ex) for ex in self.examples]
675 |
676 |
677 | class SortedBatchSampler(Sampler):
678 | def __init__(self, lengths, batch_size, shuffle=True):
679 | self.lengths = lengths
680 | self.batch_size = batch_size
681 | self.shuffle = shuffle
682 |
683 | def __iter__(self):
684 | lengths = np.array(
685 | [(l1, l2, l3, np.random.random()) for l1, l2, l3 in self.lengths],
686 | dtype=[('l1', np.int_), ('l2', np.int_), ('l3', np.int_),
687 | ('rand', np.float_)]
688 | )
689 | indices = np.argsort(lengths, order=('l1', 'rand'))
690 | batches = [indices[i:i + self.batch_size]
691 | for i in range(0, len(indices), self.batch_size)]
692 | if self.shuffle:
693 | np.random.shuffle(batches)
694 | return iter([i for batch in batches for i in batch])
695 |
696 | def __len__(self):
697 | return len(self.lengths)
698 |
699 | class Rep_graph(Dataset):
700 | def __init__(self, examples, drugs, s_idx):
701 | self.examples = examples
702 | self.drugs = drugs
703 | self.s_idx = s_idx
704 | self.rep_idx = 4
705 |
706 | def __len__(self):
707 | return len(self.examples)
708 |
709 | def __getitem__(self, index):
710 |         # data form: (feature matrix, adjacency matrix) = ([node x feature], [node x node])
711 | example = self.examples[index]
712 | next_idx = index
713 | while (self.drugs[example[0]][self.rep_idx] == 'None' or
714 | self.drugs[example[1]][self.rep_idx] == 'None'):
715 | next_idx = (next_idx + 1) % len(self.examples)
716 | example = self.examples[next_idx]
717 | drug1, drug2, scores = example
718 |
719 | drug1_feature = self.drugs[drug1][self.rep_idx][0]
720 | drug1_adj = self.drugs[drug1][self.rep_idx][1]
721 | drug1_node = len(drug1_feature)
722 | drug2_feature = self.drugs[drug2][self.rep_idx][0]
723 | drug2_adj = self.drugs[drug2][self.rep_idx][1]
724 | drug2_node = len(drug2_feature)
725 |
726 | score = scores[self.s_idx]
727 | if self.s_idx == 1:
728 | score = float(score > 0)
729 | else:
730 | score = score/100
731 |
732 | return (drug1, drug1_feature, drug1_adj, drug1_node,
733 | drug2, drug2_feature, drug2_adj, drug2_node, score)
734 |
735 | def lengths(self):
736 | def get_longer_length(ex):
737 | drug1_len = len(self.drugs[ex[0]][self.rep_idx][0])
738 |             drug2_len = len(self.drugs[ex[1]][self.rep_idx][0])
739 | length = drug1_len if drug1_len > drug2_len else drug2_len
740 | return [length, drug1_len, drug2_len]
741 | return [get_longer_length(ex) for ex in self.examples]
742 | """
743 | [Version Note]
744 | v0.1: basic implementation
745 | key A, key B, char: 9165/5677/27
746 | key set: 20337
747 | train:
748 | valid:
749 | test:
750 | v0.2: unknown / known split
751 | v0.3: append sub ids
752 |
753 |
754 | drug_info_1.0.csv
755 | - (drug_id, smiles, inchikey, target)
756 |
757 | drug_cscore_pair_top1%bottom1%.csv
758 | - (drug_id1, drug_id2, score, class)
759 |
760 | drug_fingerprint_1.0_p3.pkl
761 | - (drug_id, fingerprint)
762 |
763 | drug_mol2vec_1.0_p3.pkl
764 | - (drug_id, mol2vec)
765 |
766 | """
767 |
768 | def init_seed(seed=None):
769 |     if seed is None:
770 |         import time  # local import: `time` is not imported at module level
771 |         seed = int(round(time.time() * 1000)) % 10000
772 |     np.random.seed(seed)
773 |     torch.manual_seed(seed)
774 |     random.seed(seed)
775 |
776 |
777 | if __name__ == '__main__':
778 | init_seed(1004)
779 |
780 | # Dataset configuration
781 | drug_id_path = './data/drug/drug_info_2.0.csv'
782 | drug_sub_path = ['./data/drug/drug_fingerprint_2.0_p2.pkl',
783 | './data/drug/drug_mol2vec_2.0_p2.pkl', ]
784 | # './data/drug/drug_2.0_graph_features.pkl']
785 | # drug_pair_path = './data/drug/drug_cscore_pair_0.7.csv'
786 | drug_pair_path = './data/drug/cell_lines_pair_0.6.pkl'
787 | save_preprocess = True
788 | save_path = './data/drug/drug(tmp).pkl'
789 | load_path = './data/drug/drug(v0.1_graph).pkl'
790 |
791 | # Save or load dataset
792 | if save_preprocess:
793 | dataset = DrugDataset(drug_id_path, drug_sub_path, drug_pair_path)
794 | pickle.dump(dataset, open(save_path, 'wb'))
795 | print('## Save preprocess %s' % save_path)
796 | else:
797 | print('## Load preprocess %s' % load_path)
798 | dataset = pickle.load(open(load_path, 'rb'))
799 |
800 | # Loader testing
801 | dataset.set_rep(rep_idx=1)
802 | graph = False
803 | if graph:
804 | for idx,(d1, d1_f, d1_a, d1_l, d2, d2_f, d2_a, d2_l, score) in enumerate(
805 | dataset.get_dataloader(batch_size = 1600, s_idx = 0)[1]):
806 | dataset.decode_data_graph(d1_f[0], d1_a[0], d1_l[0], d2_f[0], d2_a[0], d2_l[0], score[0])
807 | pass
808 | else:
809 | '''
810 | for idx, (d1, d1_r, d1_l, d2, d2_r, d2_l, score) in enumerate(
811 | dataset.get_dataloader(batch_size=1600, s_idx=1)[1]):
812 | dataset.decode_data(d1_r[0], d1_l[0], d2_r[0], d2_l[0], score[0])
813 | pass
814 | '''
815 | for cell in dataset.cell_lines:
816 | print('cell line {}'.format(cell))
817 | for idx, (d1, d1_r, d1_l, d2, d2_r, d2_l, score) in enumerate(
818 | dataset.get_cellloader(batch_size=3600, s_idx=0, cell_line=cell)[1]):
819 | dataset.decode_data(d1_r[0], d1_l[0], d2_r[0], d2_l[0], score[0])
820 | pass
821 |
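822 | # Illustrative note on score handling (see Representation.__getitem__):
823 | # with s_idx == 1 the raw cscore is binarized at 90, otherwise it is
824 | # scaled into [0, 1], e.g.
825 | #   >>> float(92 >= 90), 92 / 100.
826 | #   (1.0, 0.92)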
--------------------------------------------------------------------------------
/tasks/plot.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from tqdm import tqdm_notebook as tqdm\n",
10 | "import pandas as pd\n",
11 | "import json\n",
12 | "import itertools\n",
13 | "import pickle\n",
14 | "from sklearn import preprocessing\n",
15 | "import matplotlib.pyplot as plt"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 38,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "df_fda = pd.read_csv(\"./data/clue_to_fda.csv\", sep =\",\", header=None)\n",
25 | "df_info = pd.read_csv(\"./data/drug_info_2.0.csv\", sep =\",\")\n",
26 | "\n",
27 | "with open(\"./data/99_our_v0.6_py2.pkl\", 'rb') as f:\n",
28 | " df_embed = pickle.load(f)"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 18,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "df_fda = df_fda[df_fda[3] == 'fda']\n",
38 | "fda_list = df_fda[1].unique()"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 21,
44 | "metadata": {
45 | "scrolled": true
46 | },
47 | "outputs": [
48 | {
49 | "data": {
50 | "text/html": [
51 | "[truncated cell output: HTML DataFrame preview of df_info with columns Unnamed: 0, pert_id, pert_iname, pubchem_id, canonical_smiles, inchikey, moa, gene_targets]"
21 | \n", 313 | "21 | \n", 314 | "BRD-A15914070 | \n", 315 | "4-hydroxy-2-nonenal | \n", 316 | "5280000.0 | \n", 317 | "CCCCCC(O)/C=C/C=O | \n", 318 | "JVJFIQYAHPMBBX-FNORWQNLSA-N | \n", 319 | "Cytotoxic lipid peroxidation product | \n", 320 | "NaN | \n", 321 | "
22 | \n", 324 | "22 | \n", 325 | "BRD-A96799240 | \n", 326 | "4-hydroxyretinoic-acid | \n", 327 | "6440000.0 | \n", 328 | "CC1=C(/C=C/C(C)=C/C=C/C(C)=C/C(=O)O)C(C)(C)CCC1O | \n", 329 | "KGUMXGDKXYTTEY-FRCNGJHJSA-N | \n", 330 | "Retinoid receptor binder | \n", 331 | "NaN | \n", 332 | "
23 | \n", 335 | "23 | \n", 336 | "BRD-K97118047 | \n", 337 | "4,5,6,7-tetrabromobenzotriazole | \n", 338 | "1694.0 | \n", 339 | "Brc1c(Br)c(Br)c2[nH]nnc2c1Br | \n", 340 | "OMZYUVOATZSGJY-UHFFFAOYSA-N | \n", 341 | "Casein kinase inhibitor | \n", 342 | "CSNK2A1 | \n", 343 | "
24 | \n", 346 | "24 | \n", 347 | "BRD-K57631554 | \n", 348 | "5-aminolevulinic-acid | \n", 349 | "NaN | \n", 350 | "NCC(=O)CCC(=O)O | \n", 351 | "ZGXJTSGNIOSYLO-UHFFFAOYSA-N | \n", 352 | "Oxidizing agent | \n", 353 | "ALAD | \n", 354 | "
25 | \n", 357 | "25 | \n", 358 | "BRD-K34437622 | \n", 359 | "5-FP | \n", 360 | "101498.0 | \n", 361 | "Oc1ncc(F)cn1 | \n", 362 | "HPABFFGQPLJKBP-UHFFFAOYSA-N | \n", 363 | "Thymidylate synthase inhibitor | \n", 364 | "TYMS | \n", 365 | "
26 | \n", 368 | "26 | \n", 369 | "BRD-A18497530 | \n", 370 | "5-iodotubercidin | \n", 371 | "NaN | \n", 372 | "Nc1ncnc2c1c(I)cn2C1OC(CO)C(O)C1O | \n", 373 | "WHSIXKUPQCKWBY-UHFFFAOYSA-N | \n", 374 | "Adenosine kinase inhibitor | \n", 375 | "MAPK3 | \n", 376 | "
27 | \n", 379 | "27 | \n", 380 | "BRD-K30197592 | \n", 381 | "5-methoxytryptamine | \n", 382 | "1833.0 | \n", 383 | "COc1ccc2[nH]cc(CCN)c2c1 | \n", 384 | "JTEJPPKMYBDEMY-UHFFFAOYSA-N | \n", 385 | "Serotonin receptor agonist | \n", 386 | "HTR2A|HTR6 | \n", 387 | "
28 | \n", 390 | "28 | \n", 391 | "BRD-K08219523 | \n", 392 | "5-nonyloxytryptamine | \n", 393 | "1797.0 | \n", 394 | "CCCCCCCCCOc1ccc2[nH]cc(CCN)c2c1 | \n", 395 | "YHSMSRREJYOGQJ-UHFFFAOYSA-N | \n", 396 | "Serotonin receptor agonist | \n", 397 | "NaN | \n", 398 | "
29 | \n", 401 | "29 | \n", 402 | "BRD-A33084410 | \n", 403 | "5'-guanidinonaltrindole | \n", 404 | "4400000.0 | \n", 405 | "N=C(N)Nc1ccc2[nH]c3c(c2c1)CC1(O)C2Cc4ccc(O)c5c... | \n", 406 | "VLNHDKDBGWXJEE-UHFFFAOYSA-N | \n", 407 | "Opioid receptor antagonist | \n", 408 | "OPRK1 | \n", 409 | "
... | \n", 412 | "... | \n", 413 | "... | \n", 414 | "... | \n", 415 | "... | \n", 416 | "... | \n", 417 | "... | \n", 418 | "... | \n", 419 | "... | \n", 420 | "
2398 | \n", 423 | "2398 | \n", 424 | "BRD-K92446736 | \n", 425 | "zatebradine | \n", 426 | "65637.0 | \n", 427 | "COc1ccc(CCN(C)CCCN2CCc3cc(OC)c(OC)cc3CC2=O)cc1OC | \n", 428 | "KEDQCFRVSHYKLR-UHFFFAOYSA-N | \n", 429 | "HCN channel blocker | \n", 430 | "HCN1|HCN2|HCN3|HCN4 | \n", 431 | "
2399 | \n", 434 | "2399 | \n", 435 | "BRD-K64157027 | \n", 436 | "ZD-2079 | \n", 437 | "158794.0 | \n", 438 | "O=C(O)Cc1ccc(OCCNC[C@H](O)c2ccccc2)cc1 | \n", 439 | "SRBPKVWITYPHQR-KRWDZBQOSA-N | \n", 440 | "Adrenergic receptor agonist | \n", 441 | "ADRB3 | \n", 442 | "
2400 | \n", 445 | "2400 | \n", 446 | "BRD-K45296539 | \n", 447 | "ZD-7114 | \n", 448 | "6600000.0 | \n", 449 | "COCCNC(=O)COc1ccc(OCCNC[C@@H](O)COc2ccccc2)cc1 | \n", 450 | "RVMBDLSFFNKKLG-GOSISDBHSA-N | \n", 451 | "Adrenergic receptor agonist | \n", 452 | "ADRB3 | \n", 453 | "
2401 | \n", 456 | "2401 | \n", 457 | "BRD-K11373525 | \n", 458 | "ZD-7155 | \n", 459 | "4670000.0 | \n", 460 | "CCc1cc2c(c(CC)n1)CCC(=O)N2Cc1ccc(-c2ccccc2-c2n... | \n", 461 | "BFVNEYDCFJNLGN-UHFFFAOYSA-N | \n", 462 | "Angiotensin receptor antagonist | \n", 463 | "AGTR1 | \n", 464 | "
2402 | \n", 467 | "2402 | \n", 468 | "BRD-K18678457 | \n", 469 | "ZD-7288 | \n", 470 | "NaN | \n", 471 | "CCN(c1ccccc1)c1cc(NC)[n+](C)c(C)n1 | \n", 472 | "JABSKGQQWUDVRU-UHFFFAOYSA-O | \n", 473 | "HCN channel blocker | \n", 474 | "HCN1|HCN2|HCN3|HCN4 | \n", 475 | "
2403 | \n", 478 | "2403 | \n", 479 | "BRD-A01145011 | \n", 480 | "zebularine | \n", 481 | "46800000.0 | \n", 482 | "O=c1ncccn1[C@@H]1O[C@H](CO)[C@H](O)C1O | \n", 483 | "RPQZTTQVRYEKCR-JJFBUQMESA-N | \n", 484 | "DNA methyltransferase inhibitor | \n", 485 | "CDA|DNMT1 | \n", 486 | "
2404 | \n", 489 | "2404 | \n", 490 | "BRD-A24381660 | \n", 491 | "zeranol | \n", 492 | "216284.0 | \n", 493 | "CC1CCC[C@H](O)CCCCCc2cc(O)cc(O)c2C(=O)O1 | \n", 494 | "DWTTZBARDOXEAM-TYZXPVIJSA-N | \n", 495 | "Estrogen receptor agonist | \n", 496 | "NaN | \n", 497 | "
2405 | \n", 500 | "2405 | \n", 501 | "BRD-U51951544 | \n", 502 | "ZG-10 | \n", 503 | "NaN | \n", 504 | "CN(C)C/C=C/C(=O)Nc1ccc(C(=O)Nc2cccc(Nc3nccc(-c... | \n", 505 | "VZEONOGFXYTZGT-WEVVVXLNSA-N | \n", 506 | "JNK inhibitor | \n", 507 | "IRAK1 | \n", 508 | "
2406 | \n", 511 | "2406 | \n", 512 | "BRD-K31553034 | \n", 513 | "zibotentan | \n", 514 | "9910000.0 | \n", 515 | "COc1nc(C)cnc1NS(=O)(=O)c1cccnc1-c1ccc(-c2nnco2... | \n", 516 | "FJHHZXWJVIEFGJ-UHFFFAOYSA-N | \n", 517 | "Endothelin receptor antagonist | \n", 518 | "EDNRA | \n", 519 | "
2407 | \n", 522 | "2407 | \n", 523 | "BRD-K72903603 | \n", 524 | "zidovudine | \n", 525 | "35370.0 | \n", 526 | "Cc1cn([C@H]2C[C@H](N=[N+]=[N-])[C@@H](CO)O2)c(... | \n", 527 | "HBOMLICNUCNMMY-XLPZGREQSA-N | \n", 528 | "Reverse transcriptase inhibitor | \n", 529 | "NaN | \n", 530 | "
2408 | \n", 533 | "2408 | \n", 534 | "BRD-A56359832 | \n", 535 | "zileuton | \n", 536 | "NaN | \n", 537 | "CC(c1cc2ccccc2s1)N(O)C(N)=O | \n", 538 | "MWLSOWXNZPKENC-UHFFFAOYSA-N | \n", 539 | "Leukotriene inhibitor|Lipoxygenase inhibitor | \n", 540 | "ALOX5 | \n", 541 | "
2409 | \n", 544 | "2409 | \n", 545 | "BRD-K47207162 | \n", 546 | "zimelidine | \n", 547 | "5460000.0 | \n", 548 | "CN(C)C/C=C(\\c1ccc(Br)cc1)c1cccnc1 | \n", 549 | "OYPPVKRFBIWMSX-CXUHLZMHSA-N | \n", 550 | "Serotonin reuptake inhibitor | \n", 551 | "SLC6A4 | \n", 552 | "
2410 | \n", 555 | "2410 | \n", 556 | "BRD-K29582115 | \n", 557 | "ziprasidone | \n", 558 | "NaN | \n", 559 | "O=C1Cc2cc(CCN3CCN(c4nsc5ccccc45)CC3)c(Cl)cc2N1 | \n", 560 | "MVWVFYHBGMAFLY-UHFFFAOYSA-N | \n", 561 | "Dopamine receptor antagonist|Serotonin recepto... | \n", 562 | "DRD2|HTR2A|HTR1A|HTR1D|HRH1|HTR1B|HTR1E|HTR2C|... | \n", 563 | "
2411 | \n", 566 | "2411 | \n", 567 | "BRD-K05151076 | \n", 568 | "ZK-164015 | \n", 569 | "9810000.0 | \n", 570 | "CCCCCS(=O)(=O)CCCCCCCCCCn1c(-c2ccc(O)cc2)c(C)c... | \n", 571 | "LYJSJVYJLZOMCD-UHFFFAOYSA-N | \n", 572 | "Estrogen receptor antagonist | \n", 573 | "NaN | \n", 574 | "
2412 | \n", 577 | "2412 | \n", 578 | "BRD-K56403959 | \n", 579 | "ZK-756326 | \n", 580 | "11700000.0 | \n", 581 | "OCCOCCN1CCN(Cc2cccc(Oc3ccccc3)c2)CC1 | \n", 582 | "SHDFUNGIHDOLQM-UHFFFAOYSA-N | \n", 583 | "CC chemokine receptor ligand | \n", 584 | "CCR8 | \n", 585 | "
2413 | \n", 588 | "2413 | \n", 589 | "BRD-K33882852 | \n", 590 | "ZK-93423 | \n", 591 | "NaN | \n", 592 | "CCOC(=O)c1ncc2[nH]c3ccc(OCc4ccccc4)cc3c2c1COC | \n", 593 | "ALBKMJDFBZVHAK-UHFFFAOYSA-N | \n", 594 | "Benzodiazepine receptor agonist | \n", 595 | "GABRA1|GABRA2|GABRA3|GABRA5 | \n", 596 | "
2414 | \n", 599 | "2414 | \n", 600 | "BRD-K68392338 | \n", 601 | "ZK-93426 | \n", 602 | "NaN | \n", 603 | "CCOC(=O)c1ncc2[nH]c3cccc(OC(C)C)c3c2c1C | \n", 604 | "VMDUABMKBUKKPG-UHFFFAOYSA-N | \n", 605 | "Benzodiazepine receptor antagonist | \n", 606 | "GABRA1 | \n", 607 | "
2415 | \n", 610 | "2415 | \n", 611 | "BRD-K19605405 | \n", 612 | "ZM-241385 | \n", 613 | "176407.0 | \n", 614 | "Nc1nc(NCCc2ccc(O)cc2)nc2nc(-c3ccco3)nn12 | \n", 615 | "PWTBZOIUWZOPFT-UHFFFAOYSA-N | \n", 616 | "Adenosine receptor antagonist | \n", 617 | "ADORA2A|ADORA2B|ADORA1 | \n", 618 | "
2416 | \n", 621 | "2416 | \n", 622 | "BRD-K41337261 | \n", 623 | "ZM-306416 | \n", 624 | "5330000.0 | \n", 625 | "COc1cc2ncnc(Nc3ccc(Cl)cc3F)c2cc1OC | \n", 626 | "YHUIUSRCUKUUQA-UHFFFAOYSA-N | \n", 627 | "Abl kinase inhibitor|Src inhibitor|VEGFR inhib... | \n", 628 | "FLT1|KDR | \n", 629 | "
2417 | \n", 632 | "2417 | \n", 633 | "BRD-K67831364 | \n", 634 | "ZM-323881 | \n", 635 | "NaN | \n", 636 | "Cc1cc(F)c(Nc2ncnc3cc(OCc4ccccc4)ccc23)cc1O | \n", 637 | "NVBNDZZLJRYRPD-UHFFFAOYSA-N | \n", 638 | "VEGFR inhibitor | \n", 639 | "KDR | \n", 640 | "
2418 | \n", 643 | "2418 | \n", 644 | "BRD-K40624912 | \n", 645 | "ZM-39923 | \n", 646 | "176406.0 | \n", 647 | "CC(C)N(CCC(=O)c1ccc2ccccc2c1)Cc1ccccc1 | \n", 648 | "JSASWRWALCMOQP-UHFFFAOYSA-N | \n", 649 | "JAK inhibitor | \n", 650 | "NaN | \n", 651 | "
2419 | \n", 654 | "2419 | \n", 655 | "BRD-K72703948 | \n", 656 | "ZM-447439 | \n", 657 | "NaN | \n", 658 | "COc1cc2c(Nc3ccc(NC(=O)c4ccccc4)cc3)ncnc2cc1OCC... | \n", 659 | "OGNYUTNQZVRGMN-UHFFFAOYSA-N | \n", 660 | "Aurora kinase inhibitor | \n", 661 | "AURKA|AURKB | \n", 662 | "
2420 | \n", 665 | "2420 | \n", 666 | "BRD-K08996725 | \n", 667 | "zolantidine | \n", 668 | "91769.0 | \n", 669 | "c1cc(CN2CCCCC2)cc(OCCCNc2nc3ccccc3s2)c1 | \n", 670 | "KUBONGDXTUOOLM-UHFFFAOYSA-N | \n", 671 | "Histamine receptor antagonist | \n", 672 | "NaN | \n", 673 | "
2421 | \n", 676 | "2421 | \n", 677 | "BRD-K54314721 | \n", 678 | "zolmitriptan | \n", 679 | "NaN | \n", 680 | "CN(C)CCc1c[nH]c2ccc(C[C@H]3COC(=O)N3)cc12 | \n", 681 | "ULSDMUVEXKOYBU-ZDUSSCGKSA-N | \n", 682 | "Serotonin receptor agonist | \n", 683 | "HTR1B|HTR1D|HTR1A|HTR1F | \n", 684 | "
2422 | \n", 687 | "2422 | \n", 688 | "BRD-K44876623 | \n", 689 | "zolpidem | \n", 690 | "NaN | \n", 691 | "Cc1ccc(-c2nc3ccc(C)cn3c2CC(=O)N(C)C)cc1 | \n", 692 | "ZAFYATHCZYHLPB-UHFFFAOYSA-N | \n", 693 | "Benzodiazepine receptor agonist | \n", 694 | "GABRA1|GABRA2|GABRA3 | \n", 695 | "
2423 | \n", 698 | "2423 | \n", 699 | "BRD-K48300629 | \n", 700 | "zonisamide | \n", 701 | "NaN | \n", 702 | "NS(=O)(=O)Cc1noc2ccccc12 | \n", 703 | "UBQNRHZMVUUOMG-UHFFFAOYSA-N | \n", 704 | "Sodium channel blocker|T-type calcium channel ... | \n", 705 | "SCN1A|CA1|CA12|CA7|SCN11A|SCN2A|SCN3A|SCN4A|SC... | \n", 706 | "
2424 | \n", 709 | "2424 | \n", 710 | "BRD-K70557564 | \n", 711 | "zosuquidar | \n", 712 | "3040000.0 | \n", 713 | "O[C@@H](COc1cccc2ncccc12)CN1CCN([C@@H]2c3ccccc... | \n", 714 | "IHOVFYSQUDPMCN-DBEBIPAYSA-N | \n", 715 | "P-glycoprotein inhibitor | \n", 716 | "ABCB1 | \n", 717 | "
2425 | \n", 720 | "2425 | \n", 721 | "BRD-K66353228 | \n", 722 | "zoxazolamine | \n", 723 | "6103.0 | \n", 724 | "Nc1nc2cc(Cl)ccc2o1 | \n", 725 | "YGCODSQDUUUKIV-UHFFFAOYSA-N | \n", 726 | "Myorelaxant | \n", 727 | "NaN | \n", 728 | "
2426 | \n", 731 | "2426 | \n", 732 | "BRD-K63068307 | \n", 733 | "ZSTK-474 | \n", 734 | "11600000.0 | \n", 735 | "FC(F)c1nc2ccccc2n1-c1nc(N2CCOCC2)nc(N2CCOCC2)n1 | \n", 736 | "HGVNLRPZOWWDKD-UHFFFAOYSA-N | \n", 737 | "PI3K inhibitor | \n", 738 | "PIK3CG | \n", 739 | "
2427 | \n", 742 | "2427 | \n", 743 | "BRD-K28761384 | \n", 744 | "zuclopenthixol | \n", 745 | "5310000.0 | \n", 746 | "OCCN1CCN(CC/C=C2/c3ccccc3Sc3ccc(Cl)cc32)CC1 | \n", 747 | "WFPIAZLQTJBIFN-DVZOWYKESA-N | \n", 748 | "Dopamine receptor antagonist | \n", 749 | "DRD2|DRD1 | \n", 750 | "
2428 rows × 8 columns
\n", 754 | "