├── LICENSE ├── README.md ├── chembl_dataset.zip ├── config_examples └── build_vocab_config.yml ├── scripts ├── clear_stereochemistry.py ├── counting_smirks.py ├── dataset_split.py ├── filtering_data.py ├── generating_plots.py ├── mmpdb_prep.py ├── molecular_properties.py ├── new_smirks.py ├── parsing_pairs.py ├── renaming_models.py ├── scaffolding.py ├── scoring.py ├── selfies_to_smiles.py ├── target_specific.py └── training_curves.py ├── slurm_scripts ├── dataset_split.sh ├── model_scoring.sh ├── molecular_properties_histograms.sh ├── new_smirks.sh ├── predictions.sh ├── preparing_data.sh ├── training_curves.sh └── training_model.sh └── target_data_split.zip /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransformMolecules 2 | Machine learning transformer model for generative chemistry. Code written by Emma P. Tysinger and Anton V. Sinitskiy. Paper to be published. 3 | 4 | ## Dependencies 5 | - Install [mmpdb](https://pypi.org/project/mmpdb/) before running preparing_data.sh. For quick installation, run `pip install mmpdb`. 6 | - RDKit is also required for preparing_data.sh and predictions.sh. 
RDKit is available at http://rdkit.org/.
7 | - Install [OpenNMT-py](https://opennmt.net/OpenNMT-py/main.html#installation) before running training_model.sh and predictions.sh. For installation run:
8 | ```
9 | git clone https://github.com/OpenNMT/OpenNMT-py.git
10 | cd OpenNMT-py
11 | python setup.py install
12 | ```
13 | - Install [selfies](https://pypi.org/project/selfies/) before running preparing_data.sh and predictions.sh. For quick installation, run `pip install selfies`.
14 | 
15 | ## Usage
16 | **TransformMolecules** contains 5 modules:
17 | - Preparing Data: Generates pairs of molecules, converts molecule representations to SELFIES
18 | - Dataset Split: Splits data into test/train/validation sets
19 | - Training Model: Trains transformer model on data
20 | - Predictions: Generates molecule predictions, converts molecule representations back to SMILES and evaluates predictions
21 | - Generating Plots
22 | 
23 | ### Preparing Data
24 | Run the following commands.
25 | Create a directory `$DATA_DIR`, which will store all files related to the input dataset.
26 | ```
27 | mkdir -p $DATA_DIR
28 | ```
29 | First prepare the MMPDB input file and remove all stereochemistry and salts. The input to `./scripts/mmpdb_prep.py` (e.g., with ChEMBL data) must be a CSV file with a header (referred to below as /PATH/data.csv or `$DATA`) and at least three columns: SMILES representations of compounds (`$SMI_COL`), compound ids (`$ID_COL`) and the year of experimentation for each compound (`year`). `$DATA_ID` refers to the name of the dataset.
30 | ```
31 | python ./scripts/mmpdb_prep.py --in $DATA --out $DATA_DIR/${DATA_ID}_mmpdb_input.csv --smiles $SMI_COL --ids $ID_COL
32 | python ./scripts/clear_stereochemistry.py --in $DATA_DIR/${DATA_ID}_mmpdb_input.csv --out $DATA_DIR/${DATA_ID}_mmpdb_input_nostereo.csv
33 | ```
34 | Pair structurally similar molecules using MMPDB. Note that `mmpdb fragment` can take hours to run.
35 | ```
36 | mmpdb fragment --delimiter comma --has-header $DATA_DIR/${DATA_ID}_mmpdb_input_nostereo.csv -o $DATA_DIR/${DATA_ID}.fragments
37 | mmpdb index $DATA_DIR/${DATA_ID}.fragments -o $DATA_DIR/${DATA_ID}_pairs.csv --out 'csv'
38 | python ./scripts/parsing_pairs.py --in $DATA_DIR/${DATA_ID}_pairs.csv --out $DATA_DIR/${DATA_ID}_pairs_parsed.csv
39 | ```
40 | Count the number of pairs representing each SMIRKS transformation; these counts will later be used to filter the data.
41 | ```
42 | python ./scripts/counting_smirks.py --in $DATA_DIR/${DATA_ID}_pairs_parsed.csv --out $DATA_DIR/${DATA_ID}_counted.csv
43 | ```
44 | Filter the paired dataset by excluding all SMIRKS with a count below a defined threshold (`$EXCLUDE`) and randomly sampling a constant number of pairs (`$SAMPLE_SIZE`) from each included SMIRKS.
45 | ```
46 | python ./scripts/filtering_data.py --in $DATA_DIR/${DATA_ID}_pairs_parsed.csv --all $DATA --smirks $DATA_DIR/${DATA_ID}_counted.csv --out $DATA_DIR/${DATA_ID}_filtered.csv --size $SAMPLE_SIZE --exclude $EXCLUDE
47 | ```
48 | It is strongly recommended that you submit this job to a queue using a slurm script, because it may take up to a day to complete. An example of a slurm script can be found in slurm_scripts/preparing_data.sh.
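The contents of `slurm_scripts/preparing_data.sh` are not reproduced in this README; as a rough sketch only, the wrapper essentially chains the commands above inside a batch job. The `#SBATCH` resource values and any environment-activation lines below are placeholders to adapt to your cluster, not the settings used in this repository:
```
#!/bin/bash
#SBATCH --job-name=preparing_data
#SBATCH --time=24:00:00       # the full pipeline may take up to a day
#SBATCH --mem=32G             # placeholder; size to your dataset
#SBATCH --cpus-per-task=4     # placeholder

# $DATA_ID, $DATA_DIR, $DATA, $SMI_COL, $ID_COL, $SAMPLE_SIZE and $EXCLUDE are
# expected to arrive via `sbatch --export` (see the submission command below).
# (Activate your RDKit/mmpdb/selfies environment here.)
set -e
mkdir -p $DATA_DIR
python ./scripts/mmpdb_prep.py --in $DATA --out $DATA_DIR/${DATA_ID}_mmpdb_input.csv --smiles $SMI_COL --ids $ID_COL
python ./scripts/clear_stereochemistry.py --in $DATA_DIR/${DATA_ID}_mmpdb_input.csv --out $DATA_DIR/${DATA_ID}_mmpdb_input_nostereo.csv
mmpdb fragment --delimiter comma --has-header $DATA_DIR/${DATA_ID}_mmpdb_input_nostereo.csv -o $DATA_DIR/${DATA_ID}.fragments
mmpdb index $DATA_DIR/${DATA_ID}.fragments -o $DATA_DIR/${DATA_ID}_pairs.csv --out 'csv'
python ./scripts/parsing_pairs.py --in $DATA_DIR/${DATA_ID}_pairs.csv --out $DATA_DIR/${DATA_ID}_pairs_parsed.csv
python ./scripts/counting_smirks.py --in $DATA_DIR/${DATA_ID}_pairs_parsed.csv --out $DATA_DIR/${DATA_ID}_counted.csv
python ./scripts/filtering_data.py --in $DATA_DIR/${DATA_ID}_pairs_parsed.csv --all $DATA --smirks $DATA_DIR/${DATA_ID}_counted.csv --out $DATA_DIR/${DATA_ID}_filtered.csv --size $SAMPLE_SIZE --exclude $EXCLUDE
```
The job itself is submitted with: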
49 | ```
50 | sbatch --export=DATA_ID=test_dataset,DATA_DIR=/PATH/test_dataset,DATA=/PATH/data.csv,SMI_COL='canonical_smiles',ID_COL='chembl_id',SAMPLE_SIZE=3,EXCLUDE=2 ./slurm_scripts/preparing_data.sh
51 | ```
52 | 
53 | ### Dataset Split
54 | Run the following commands, where `$DATA_DIR` is the directory with the paired molecules and `$RUN_DIR` is the directory to create to store all data related to a single dataset split. For year thresholds, all pairs with both molecules discovered before `$TRAIN_YEAR` will be in the training set, pairs with at least one molecule discovered later than `$TEST_YEAR` will be in the test set and all other pairs will be in the validation set. `$TIMESTAMPS` must be a CSV file with a compound id column (`chembl_id` by default) and a `year` column. If the `--augment` flag is used, for all pairs `(mol1, mol2)` added to the training set, the reciprocal `(mol2, mol1)` will also be added.
55 | ```
56 | mkdir -p $RUN_DIR
57 | 
58 | ## Dataset split
59 | python ./scripts/dataset_split.py --in $DATA_DIR/${DATA_ID}_filtered.csv --timestamps $TIMESTAMPS --out $RUN_DIR --year_train $TRAIN_YEAR --year_test $TEST_YEAR --augment
60 | ```
61 | It is strongly recommended that you submit this job to a queue using a slurm script, because it may take up to 5 hours to run. An example of a slurm script can be found in slurm_scripts/dataset_split.sh.
62 | ```
63 | sbatch --export=DATA_ID=test_dataset,DATA_DIR=/PATH/to/paired_dataset,TIMESTAMPS=/PATH/data.csv,RUN_DIR=/PATH/to/training_dataset,TRAIN_YEAR_CUTOFF=2009,VAL_YEAR_CUTOFF=2014 ./slurm_scripts/dataset_split.sh
64 | ```
65 | 
66 | ### Training Model
67 | Install [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py) as described above (section Dependencies).
68 | 
69 | Run the following commands.
70 | Create a directory `$MODEL_DIR`, which will store the models.
71 | ```
72 | mkdir -p $MODEL_DIR
73 | ```
74 | Next, calculate the size of the training set, the number of steps per epoch, the total number of training steps and how often to save models, based on `$BATCH_SIZE`, `$TRAIN_EPOCHS` and `$SAVE_EPOCHS`.
75 | ```
76 | export TRAIN_SIZE=$(cat $RUN_DIR/src-train.txt | wc -l)
77 | export EPOCH_STEPS=$(($TRAIN_SIZE/$BATCH_SIZE))
78 | export TRAIN_STEPS=$(($EPOCH_STEPS*$TRAIN_EPOCHS))
79 | export SAVE_STEPS=$(($EPOCH_STEPS*$SAVE_EPOCHS))
80 | export VALID_STEPS=$(($TRAIN_STEPS+1))
81 | ```
82 | Build the config yaml file with the training parameters, where `$DATA_DIR` is the directory with the paired molecules and `$RUN_DIR` is the directory storing all data related to a single dataset split.
83 | ```
84 | cat << EOF > $RUN_DIR/config.yaml
85 | ## Where the vocab(s) will be written
86 | src_vocab: $RUN_DIR/vocab.src
87 | tgt_vocab: $RUN_DIR/vocab.tgt
88 | # Corpus opts:
89 | data:
90 |     corpus_1:
91 |         path_src: $RUN_DIR/src-train.txt
92 |         path_tgt: $RUN_DIR/tgt-train.txt
93 |     valid:
94 |         path_src: $RUN_DIR/src-val.txt
95 |         path_tgt: $RUN_DIR/tgt-val.txt
96 | EOF
97 | ```
98 | Build vocab and start training the transformer model.
99 | ```
100 | onmt_build_vocab -config $RUN_DIR/config.yaml -save_data $RUN_DIR/data -n_samples -1
101 | onmt_train -config $RUN_DIR/config.yaml -save_model $MODEL_DIR/$MODEL_ID -train_steps $TRAIN_STEPS -valid_steps $VALID_STEPS -save_checkpoint_steps $SAVE_STEPS -batch_size $BATCH_SIZE -world_size 1 -gpu_ranks 0
102 | ```
103 | Finally, rename the models to be more intuitive.
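As a purely illustrative example (numbers chosen for clarity, not taken from the paper): with `TRAIN_SIZE=100000`, `BATCH_SIZE=100`, `TRAIN_EPOCHS=50` and `SAVE_EPOCHS=5`, the arithmetic above gives `EPOCH_STEPS=1000`, `TRAIN_STEPS=50000` and `SAVE_STEPS=5000`, i.e. one checkpoint every 5 epochs. OpenNMT names checkpoints by step count (`${MODEL_ID}_step_<N>.pt`); the renaming script converts that count back into an epoch number (epoch ≈ steps × `$BATCH_SIZE` / `$TRAIN_SIZE`), so with these numbers `${MODEL_ID}_step_5000.pt` would become `${MODEL_ID}_epoch_5.pt`: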
104 | ```
105 | python ./scripts/renaming_models.py --models $MODEL_DIR --batch_size $BATCH_SIZE --train_size $TRAIN_SIZE
106 | ```
107 | It is strongly recommended that you submit this job to a queue using a slurm script, because it may take multiple days based on the dataset size. An example of a slurm script can be found in slurm_scripts/training_model.sh.
108 | ```
109 | sbatch --export=RUN_DIR=/PATH/to/training_dataset,MODEL_ID=test_dataset,MODEL_DIR=/PATH/model/name_of_run,TRAIN_EPOCHS=50,SAVE_EPOCHS=5,BATCH_SIZE=100
110 | ./slurm_scripts/training_model.sh
111 | ```
112 | 
113 | To resume training from a checkpoint model, run the following command, where `$MODEL_PATH` is the path to the checkpoint model.
114 | ```
115 | onmt_train -config $RUN_DIR/config.yaml -save_model $MODEL_DIR/$MODEL_ID -train_steps $TRAIN_STEPS -valid_steps $TRAIN_STEPS -save_checkpoint_steps $SAVE_STEPS -batch_size $BATCH_SIZE -world_size 1 -gpu_ranks 0 -train_from $MODEL_PATH -reset_optim all
116 | ```
117 | 
118 | To get perplexity scores for data other than validation data, run the following command and look at the `GOLD ppl` score in the error file:
119 | ```
120 | sbatch --export=MODEL_ID=test_dataset,MODEL_DIR=/PATH/model/name_of_run,SRC_DATA=/PATH/to/input_dataset,TGT_DATA=/PATH/to/true_dataset,EPOCH_NUM=10,OUTPUT_DIR=/PATH/to/predictions ./slurm_scripts/model_scoring.sh
121 | ```
122 | 
123 | ### Predictions
124 | Run the following commands.
125 | Create a txt file with all unique validation molecules, then generate new structure predictions with the validation molecules as input to the trained model identified by `$MODEL_ID` at epoch `$EPOCH_NUM`. `$RUN_DIR` is the directory storing all data related to a single dataset split, `$MODEL_DIR` is the directory storing the trained model and `$OUTPUT_DIR` is the directory where predictions will be saved.
126 | ```
127 | mkdir -p $OUTPUT_DIR
128 | cat $RUN_DIR/src-val.txt $RUN_DIR/tgt-val.txt | sort | uniq > $RUN_DIR/val-unique.txt
129 | if [ ! -s $OUTPUT_DIR/pred_selfies_epoch_${EPOCH_NUM}.txt ]; then
130 | onmt_translate --model $MODEL_DIR/${MODEL_ID}_epoch_${EPOCH_NUM}.pt --src $RUN_DIR/val-unique.txt --output $OUTPUT_DIR/pred_selfies_epoch_${EPOCH_NUM}.txt --replace_unk --seed 1 --gpu 0
131 | fi
132 | ```
133 | Convert the SELFIES of the generated molecules to SMILES and get scaffolds of all input validation and generated molecules.
134 | ```
135 | python ./scripts/selfies_to_smiles.py --in1 $OUTPUT_DIR/pred_selfies_epoch_${EPOCH_NUM}.txt --in2 $RUN_DIR/val-unique.txt --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv
136 | python ./scripts/scaffolding.py --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv --column1 'structure' --column2 'id'
137 | ```
138 | Score the generated molecules based on the number of scaffold changes, r-group changes, unique scaffolds and new scaffolds. `$METRICS_TABLE` is the csv file where scores will be added. If the csv doesn't exist yet, one will be created.
139 | ```
140 | ## Score model predictions
141 | python ./scripts/scoring.py --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv --metrics_table $METRICS_TABLE --training_data $RUN_DIR/train.csv --model ${MODEL_ID}_epoch_${EPOCH_NUM} --change_count --scaffolds
142 | ```
143 | It is strongly recommended that you submit this job to a queue using a slurm script, because it may take up to 1 day based on the dataset size. An example of a slurm script can be found in slurm_scripts/predictions.sh.
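As with data preparation, the actual `slurm_scripts/predictions.sh` is not shown here; a minimal sketch of such a wrapper (with placeholder `#SBATCH` values — note that a GPU is requested because `onmt_translate` is called with `--gpu 0`) would simply chain the prediction and scoring commands above:
```
#!/bin/bash
#SBATCH --job-name=predictions
#SBATCH --time=24:00:00   # may take up to a day on large validation sets
#SBATCH --gres=gpu:1      # placeholder GPU request
#SBATCH --mem=16G         # placeholder

# $RUN_DIR, $MODEL_ID, $MODEL_DIR, $OUTPUT_DIR, $EPOCH_NUM and $METRICS_TABLE
# are expected to arrive via `sbatch --export` (see the submission command below).
set -e
mkdir -p $OUTPUT_DIR
cat $RUN_DIR/src-val.txt $RUN_DIR/tgt-val.txt | sort | uniq > $RUN_DIR/val-unique.txt
onmt_translate --model $MODEL_DIR/${MODEL_ID}_epoch_${EPOCH_NUM}.pt --src $RUN_DIR/val-unique.txt --output $OUTPUT_DIR/pred_selfies_epoch_${EPOCH_NUM}.txt --replace_unk --seed 1 --gpu 0
python ./scripts/selfies_to_smiles.py --in1 $OUTPUT_DIR/pred_selfies_epoch_${EPOCH_NUM}.txt --in2 $RUN_DIR/val-unique.txt --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv
python ./scripts/scaffolding.py --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv --column1 'structure' --column2 'id'
python ./scripts/scoring.py --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv --metrics_table $METRICS_TABLE --training_data $RUN_DIR/train.csv --model ${MODEL_ID}_epoch_${EPOCH_NUM} --change_count --scaffolds
```
The job is submitted with: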
144 | ```
145 | sbatch --export=RUN_DIR=/PATH/to/training_dataset,MODEL_ID=test_dataset,MODEL_DIR=/PATH/model/name_of_run,OUTPUT_DIR=/PATH/to/predictions,EPOCH_NUM=10,METRICS_TABLE=/PATH/model_scores.csv ./slurm_scripts/predictions.sh
146 | ```
147 | 
148 | An additional slurm script can be run to determine the new SMIRKS predicted by the model, using the following command, where `$THRESHOLD` is the SMIRKS-count threshold at or above which PNGs of SMIRKS will be created and `$PNG_DEST` is an existing directory, or one that will be created, to save the SMIRKS PNGs to.
149 | ```
150 | sbatch --export=EPOCH_NUM=10,OUTPUT_DIR=/PATH/to/predictions,DATA_DIR=/PATH/to/paired_dataset,DATA_ID=test_dataset,THRESHOLD=2,PNG_DEST=/DIRECTORY/for/pngs ./slurm_scripts/new_smirks.sh
151 | ```
152 | 
153 | ### Generating Plots
154 | #### **Scaffolding scores**
155 | To generate line plots of the scaffold scores over multiple epochs, run the following command. Specify which runs in the metrics_table.csv to plot with `$SUBSET`, a string identifier contained in the model names.
156 | ```
157 | mkdir -p $PLOT_DIR
158 | python ./scripts/generating_plots.py --metrics_table $METRICS_TABLE --out $PLOT_DIR --subset $SUBSET --type scores
159 | ```
160 | #### **Molecular Property Histograms**
161 | To generate histogram plots comparing the molecular properties of the generated molecules with those of the model's input and training molecules, first get the molecular properties of the generated and input molecules with the following commands:
162 | ```
163 | python ./scripts/molecular_properties.py --in $DATA_DIR/${DATA_ID}_mmpdb_input_nostereo.csv --out $DATA_DIR/${DATA_ID}_molecular_properties.csv --smi_col $SMI_COL
164 | python ./scripts/molecular_properties.py --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv --smi_col 'structure'
165 | ```
166 | Next, to generate the histograms run the following command:
167 | ```
168 | mkdir -p $PLOT_DIR
169 | python ./scripts/generating_plots.py --in1 $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv --in2 $DATA_DIR/${DATA_ID}_molecular_properties.csv --out $PLOT_DIR --type molecular_properties
170 | ```
171 | #### **Training Curves and Scaffolding Scores**
172 | To generate stacked plots with perplexity and scaffolding scores, first parse the training log and info files to generate csv files with perplexity and accuracy scores. `$IN_FILE` is a txt file with information about the model including the split name and filter used, and `$OUTDEST` is the name of the directory to save all files generated for this script.
173 | ```
174 | mkdir -p $OUTDEST
175 | python ./scripts/training_curves.py --in $IN_FILE --out $OUTDEST/training_model_info_${MODEL_NUM}.csv --parse_type info
176 | 
177 | export NAME=validation_accuracy
178 | grep 'Validation accuracy' $ERR_FILE > $OUTDEST/${NAME}_${MODEL_NUM}.err
179 | python ./scripts/training_curves.py --in $OUTDEST/${NAME}_${MODEL_NUM}.err --out $OUTDEST/${NAME}_${MODEL_NUM}.csv --name $NAME --parse_type val
180 | 
181 | export NAME=validation_perplexity
182 | grep 'Validation perplexity' $ERR_FILE > $OUTDEST/${NAME}_${MODEL_NUM}.err
183 | python ./scripts/training_curves.py --in $OUTDEST/${NAME}_${MODEL_NUM}.err --out $OUTDEST/${NAME}_${MODEL_NUM}.csv --name $NAME --parse_type val
184 | 
185 | export NAME=training
186 | grep 'Start training loop and validate' $ERR_FILE > $OUTDEST/${NAME}_${MODEL_NUM}.err
187 | grep 'acc:' $ERR_FILE >> $OUTDEST/${NAME}_${MODEL_NUM}.err
188 | python ./scripts/training_curves.py --in $OUTDEST/${NAME}_${MODEL_NUM}.err --out $OUTDEST/${NAME}_${MODEL_NUM}.csv --name $NAME --outpng $OUTDEST --parse_type train
189 | ```
190 | Finally, run the following command to generate the stacked plots. Make sure the `$METRICS_TABLE` csv file contains the scores for the split and filter being used.
191 | ```
192 | mkdir -p $PLOTDEST
193 | python ./scripts/training_curves.py --val_acc $OUTDEST/validation_accuracy_${MODEL_NUM}.csv --val_ppl $OUTDEST/validation_perplexity_${MODEL_NUM}.csv --train $OUTDEST/training_${MODEL_NUM}.csv --info $OUTDEST/training_model_info_${MODEL_NUM}.csv --metrics $METRICS_TABLE --outpng $PLOTDEST --epoch_cutoff $EPOCH_CUTOFF --parse_type plot
194 | ```
195 | The following slurm script can be run to automate all these steps, where `$MODEL_NUM` is the job id from training the model:
196 | ```
197 | sbatch --export=IN_FILE=/PATH/to/model_info.txt,ERR_FILE=/PATH/to/training_log.txt,OUTDEST=/PATH/to/output_directory,MODEL_NUM=model_num,METRICS_TABLE=/PATH/to/metrics.csv,PLOTDEST=directory_to_save_plots,EPOCH_CUTOFF=32 ./slurm_scripts/training_curves.sh
198 | ```
199 | 
200 | ### Held-Out Target-Specific Data
201 | Scripts related to filtering datasets for target-specific data require a dataset of molecules with a column `tid` referring to the target id of each molecule's target, as well as another dataset (referred to as `/PATH/to/target_id_dataset`) mapping target ids (`tid`) to their chembl_id. This target_id_dataset can be downloaded from `chembl_dataset.zip` and is titled `target_information.csv`. To generate the datasets with no target-specific data, run the filtering scripts with additional inputs including the name of the target (`$TARGET`):
202 | ```
203 | python ./scripts/filtering_data.py --in $DATA_DIR/${DATA_ID}_pairs_parsed.csv --all $DATA --smirks $DATA_DIR/${DATA_ID}_counted.csv --out $DATA_DIR/${DATA_ID}_filtered_${TARGET}.csv --size $SAMPLE_SIZE --exclude $EXCLUDE --target 'target chembl_id' --tid /PATH/to/target_id_dataset
204 | ```
205 | Run the **Dataset Split** and **Training Model** modules the same as described earlier, using `$DATA_DIR/${DATA_ID}_filtered_${TARGET}.csv`. Next, split the target-specific data temporally and by activity.
206 | 
207 | To generate images of the experimental molecules for a specific target, of the most similar generated molecule (calculated with Tanimoto similarity) and of the input molecule for each such generated molecule, run the following command, where `$SMI_COL` is the name of the column with SMILES representations in `--in1` and `$PNG_DEST` is an existing directory, or one that will be created, to save the molecule PNGs to.
208 | ```
209 | python ./scripts/target_specific.py --in1 /PATH/to/target_molecules --in2 /PATH/to/generated_molecules --smi_col $SMI_COL --out $OUTPUT_DIR/top_generated_per_experimental.csv --png_dest $PNG_DEST --generate_png
210 | ```
211 | 
212 | 
-------------------------------------------------------------------------------- /chembl_dataset.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfizer-opensource/transform-molecules/de569e232698f17eca27d4f802792c90300509d5/chembl_dataset.zip -------------------------------------------------------------------------------- /config_examples/build_vocab_config.yml: --------------------------------------------------------------------------------
1 | ## Where the samples will be written
2 | save_data: .../run/example
3 | ## Where the vocab(s) will be written
4 | src_vocab: .../run/example.vocab.src
5 | tgt_vocab: .../run/example.vocab.tgt
6 | # Prevent overwriting existing files in the folder
7 | overwrite: False
8 | 
9 | # Corpus opts:
10 | data:
11 |     corpus_1:
12 |         path_src: $DATA_DIR/src-train.txt
13 |         path_tgt: $DATA_DIR/tgt-train.txt
14 |     valid:
15 |         path_src: $DATA_DIR/src-val.txt
16 |         path_tgt: $DATA_DIR/tgt-val.txt
17 | 
-------------------------------------------------------------------------------- /scripts/clear_stereochemistry.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from openeye import oechem
4 | 
5 | 
6 | def clear_stereochemistry(mol):
7 |     clear_atom_stereochemistry(mol)
8 |     clear_bond_sterochemistry(mol)
9 |     oechem.OESuppressHydrogens(mol, False, False, False)
10 | 
11 | def clear_atom_stereochemistry(mol):
12 |     for atom in mol.GetAtoms():
13 |         chiral = atom.IsChiral()
14 |         stereo = oechem.OEAtomStereo_Undefined
15 |         v = []
16 |         for nbr in atom.GetAtoms():
17 |             v.append(nbr)
18 |         if atom.HasStereoSpecified(oechem.OEAtomStereo_Tetrahedral):
19 |             stereo = atom.GetStereo(v, oechem.OEAtomStereo_Tetrahedral)
20 | 
21 |         if chiral or stereo != oechem.OEAtomStereo_Undefined:
22 |             atom.SetStereo(v, oechem.OEAtomStereo_Tetrahedral, oechem.OEAtomStereo_Undefined)
23 | 
24 | 
25 | def clear_bond_sterochemistry(mol):
26 |     for bond in mol.GetBonds():
27 |         if bond.HasStereoSpecified(oechem.OEBondStereo_CisTrans):
28 |             for atomB in bond.GetBgn().GetAtoms():
29 |                 if atomB == bond.GetEnd():
30 |                     continue
31 |                 for atomE in bond.GetEnd().GetAtoms():
32 |                     if atomE == bond.GetBgn():
33 |                         continue
34 |                     v = []
35 |                     v.append(atomB)
36 |                     v.append(atomE)
37 |                     stereo = bond.SetStereo(v, oechem.OEBondStereo_CisTrans, oechem.OEBondStereo_Undefined)
38 | 
39 | def abs_smi(x):
40 |     mol = oechem.OEGraphMol()
41 |     if oechem.OESmilesToMol(mol, x):
42 |         clear_stereochemistry(mol)
43 |         return oechem.OEMolToSmiles(mol)
44 |     else:
45 |         return np.nan
46 | 
47 | 
48 | if __name__ == '__main__':
49 |     parser = argparse.ArgumentParser(description="Remove stereochemistry from the input data set.")
50 |     parser.add_argument("--in",dest="infile",help="whitespace-delimited input file",metavar="in.csv")
51 | 
parser.add_argument("--out", dest="outfile", help="output file", metavar="out.csv") 52 | 53 | args = parser.parse_args() 54 | n=0 55 | with open(args.infile, 'r') as ifs: 56 | with open(args.outfile, 'w') as ofs: 57 | for line in ifs: 58 | if n==0: 59 | ofs.write(line) 60 | n=1 61 | else: 62 | parsed = line.strip().split(',') 63 | if ('.' not in parsed[0]): 64 | ofs.write(f"{abs_smi(parsed[0])},{parsed[1]}\n") 65 | 66 | 67 | -------------------------------------------------------------------------------- /scripts/counting_smirks.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import argparse 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser(description="Count how many times each smirks occurs in the mmpdb pairs.") 7 | parser.add_argument("--in", dest="infile", help="comma-delimited input file", metavar="in.csv") 8 | parser.add_argument("--out", dest="outfile", help="output file", metavar="out.csv") 9 | args = parser.parse_args() 10 | 11 | smirks = dict() 12 | 13 | with open(args.infile, 'r') as ifs: 14 | for line in ifs: 15 | parsed = line.strip().split(',') 16 | if parsed[4] in smirks: 17 | smirks[parsed[4]] += 1 18 | else: 19 | smirks[parsed[4]] = 1 20 | df = pd.DataFrame({'smirks':list(smirks.keys()), 'counted':list(smirks.values())}) 21 | df.sort_values(by='counted', ascending=False, inplace=True) 22 | df.to_csv(args.outfile, index=False) 23 | -------------------------------------------------------------------------------- /scripts/dataset_split.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import math 4 | import argparse 5 | import selfies 6 | import re 7 | from scaffolding import get_murcko_scaffold 8 | 9 | 10 | def smiles_to_selfies(x_smiles, token_sep=True): 11 | x_selfies = [] 12 | for i in range(x_smiles.shape[0]): 13 | ax = selfies.encoder(x_smiles[i]) 14 | if ax != -1: 15 | try: 16 | if token_sep: 17 | sx = re.findall(r"\[[^\]]*\]", ax) 18 | ax = " ".join(sx) 19 | x_selfies.append(ax) 20 | except: 21 | x_selfies.append("NaN") 22 | else: 23 | x_selfies.append("NaN") 24 | return x_selfies 25 | 26 | def temporal_split(infile, args): 27 | with open(infile, 'r') as ifs: 28 | with open(f"{args.outdest}/test.csv", 'w') as test: 29 | with open(f"{args.outdest}/train.csv", 'w') as train: 30 | with open(f"{args.outdest}/val.csv", 'w') as val: 31 | train_pairs, val_pairs, test_pairs = set(), set(), set() 32 | next(ifs) # skipping the first line with the header in the input file 33 | for line in ifs: 34 | smiles1, smiles2, name1, name2, smirk, rgroup = line.strip().split(',') 35 | year_name1, year_name2 = years_dict[name1], years_dict[name2] 36 | year_name1 = int(year_name1) if not np.isnan(year_name1) else 9999 # all structures without dates have to go to test set to make sure there is no target leak 37 | year_name2 = int(year_name2) if not np.isnan(year_name2) else 9999 38 | if (name1 in years_dict and name2 in years_dict) and year_name1 < int(args.year1) and year_name2 < int(args.year1): 39 | if (name1, name2) not in train_pairs: 40 | train_pairs.add((name1, name2)) 41 | train.write(line) 42 | if args.augment: 43 | if (name2, name1) not in train_pairs: 44 | train_pairs.add((name2, name1)) 45 | train.write(f"{smiles2},{smiles1},{name2},{name1},{smirk.split('>>')[1]}>>{smirk.split('>>')[0]},{rgroup}\n") 46 | elif (name1 in years_dict and name2 in years_dict) and year_name1 <= int(args.year2) 
and year_name2 <= int(args.year2): 47 | if (name1, name2) not in val_pairs: 48 | val_pairs.add((name1, name2)) 49 | val.write(line) 50 | if args.augment: 51 | if (name2, name1) not in val_pairs: 52 | val_pairs.add((name2, name1)) 53 | val.write(f"{smiles2},{smiles1},{name2},{name1},{smirk.split('>>')[1]}>>{smirk.split('>>')[0]},{rgroup}\n") 54 | else: 55 | if (name1, name2) not in test_pairs: 56 | test_pairs.add((name1, name2)) 57 | test.write(line) 58 | if args.augment: 59 | if (name2, name1) not in test_pairs: 60 | test_pairs.add((name2, name1)) 61 | test.write(f"{smiles2},{smiles1},{name2},{name1},{smirk.split('>>')[1]}>>{smirk.split('>>')[0]},{rgroup}\n") 62 | 63 | def to_selfies(data, name, outdest): 64 | data.columns = [0,1] 65 | smiles1 = np.array(data[0]) 66 | smiles2 = np.array(data[1]) 67 | selfies1 = smiles_to_selfies(smiles1) 68 | selfies2 = smiles_to_selfies(smiles2) 69 | with open(f"{outdest}/src-{name}.txt", 'w') as ofs: 70 | for selfie in selfies1: 71 | ofs.write(f"{selfie}\n") 72 | with open(f"{outdest}/tgt-{name}.txt", 'w') as ofs: 73 | for selfie in selfies2: 74 | ofs.write(f"{selfie}\n") 75 | 76 | def get_train_scaffolds(train_file): 77 | train = pd.read_csv(train_file, header=None) 78 | train[6] = train[0].apply(get_murcko_scaffold) 79 | train[7] = train[1].apply(get_murcko_scaffold) 80 | train.to_csv(train_file, header=False, index=False) 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser(description="Split dataset into train, validation and test subsets. Augment if requested.") 84 | parser.add_argument("--in", dest="infile", help="comma-delimited input file", metavar="in.csv") 85 | parser.add_argument("--timestamps", dest="timestamps", help="comma-delimited file with years for an earliest publication for each ChEMBL molecule", metavar="all.csv") 86 | parser.add_argument("--out", dest="outdest", help="output destination") 87 | # parser.add_argument("--t", dest="percent_train", help="amount to split test-train", default=0.7) 88 | parser.add_argument("--ids", dest="ids_col", help="name of column with ids", default="chembl_id") 89 | parser.add_argument("--years", dest="years_col", help="name of column with years", default="year") 90 | parser.add_argument("--year_train", dest="year1", help="train set will include pairs of molecules, each of which was published before this year (exclusively)", default=2009) 91 | parser.add_argument("--year_test", dest="year2", help="test set will include pairs of molecules, either or both of which was published after this year (exclusively)", default=2013) 92 | parser.add_argument("--augment", dest="augment", help="for each (mol1,mol2), add also (mol2,mol1) if not there yet", action='store_true') 93 | args = parser.parse_args() 94 | 95 | years = pd.read_csv(args.timestamps, usecols=[args.ids_col, args.years_col]) 96 | years_dict = {k:v for k,v in zip(list(years[args.ids_col]),list(years[args.years_col]))} 97 | del years 98 | print('Starting split') 99 | temporal_split(args.infile, args) 100 | 101 | print('Converting SMILES to SELFIES') 102 | to_selfies(pd.read_csv(f"{args.outdest}/train.csv", usecols=[0,1]), 'train', args.outdest) 103 | to_selfies(pd.read_csv(f"{args.outdest}/val.csv", usecols=[0,1]), 'val', args.outdest) 104 | to_selfies(pd.read_csv(f"{args.outdest}/test.csv", usecols=[0,1]), 'test', args.outdest) 105 | 106 | print('Getting training scaffolds') 107 | get_train_scaffolds(f"{args.outdest}/train.csv") 108 | -------------------------------------------------------------------------------- 
/scripts/filtering_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import random 4 | import argparse 5 | 6 | def main(args): 7 | #generating dictionary of smirks and their counts for the top N widespread transformations 8 | count_df = pd.read_csv(args.smirks, header=None) 9 | count_df.drop(index=count_df.index[0],axis=0,inplace=True) 10 | count_df.columns = ['smirks','counted'] 11 | count_df['counted'] = np.int64(count_df['counted']) 12 | count_df.sort_values(by='counted',ascending=False, inplace=True) 13 | 14 | smirk_counts = dict(zip(list(count_df['smirks']), list(count_df['counted']))) 15 | 16 | exclude = count_df[count_df['counted'] < int(args.exclude)] 17 | smirk_exclude = set(exclude['smirks']) 18 | 19 | 20 | #getting dictionary of target tid to chembl_id if filtering by a target 21 | if args.target != None: 22 | tid_df = pd.read_csv(args.target_dict, usecols=['tid', 'chembl_id']) 23 | tid_dict = dict(zip(list(tid_df['chembl_id']), list(tid_df['tid']))) 24 | target_id = tid_dict[args.target] 25 | del tid_df 26 | all_df = pd.read_csv(args.alldata, usecols=['chembl_id', 'tid']) 27 | chembl_dict = dict(zip(list(all_df['chembl_id']), list(all_df['tid']))) 28 | del all_df 29 | 30 | #getting indices of all smirks and exclude those pairs with a molecule that targets the target 31 | with open(args.infile, 'r') as ifs: 32 | smirk_indices = {k:[] for k in list(smirk_counts.keys())} 33 | for index, line in enumerate(ifs): 34 | _,_,id1,id2,smirk,_ = line.strip().split(',') 35 | if smirk in smirk_indices: 36 | if args.target == None or (chembl_dict[id1] != target_id and chembl_dict[id2] != target_id): 37 | smirk_indices[smirk].append(index) 38 | else: 39 | print(f'exclude by target: {index}') 40 | 41 | #randomly selecting indices for downsampling 42 | smirk_downsampled = dict() 43 | for smr, indices in smirk_indices.items(): 44 | if not args.sample or len(indices) < int(args.sample): 45 | smirk_downsampled[smr] = indices 46 | else: 47 | smirk_downsampled[smr] = random.sample(indices, int(args.sample)) 48 | 49 | #write sampled pairs to outfile 50 | with open(args.infile, 'r') as ifs: 51 | with open(args.outfile, 'w') as ofs: 52 | for index, line in enumerate(ifs): 53 | _,_,_,_,smirk,_ = line.strip().split(',') 54 | if smirk not in smirk_exclude: 55 | if (smirk in smirk_downsampled and index in smirk_downsampled[smirk]) or (smirk not in smirk_downsampled): 56 | ofs.write(line) 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser(description="Filters and downsamples the data based on smirk counts.") 60 | parser.add_argument("--in", dest="infile", help="comma-delimited input file", metavar="in.csv") 61 | parser.add_argument("--all", dest="alldata", help="comma-delimited file with all chembl_data", metavar="all.csv") 62 | parser.add_argument("--smirks", dest="smirks", help="file with counts of smirks in infile", metavar="smirks.csv") 63 | parser.add_argument("--out", dest="outfile", help="output file", metavar="out.csv") 64 | parser.add_argument("--size", dest="sample", help="number to sample from each smirk", default=None) 65 | parser.add_argument("--exclude", dest="exclude", help="threshold to exclude smirks at", default=2) 66 | parser.add_argument("--target", dest="target", help="chembl_id of target to exclude", default=None) 67 | parser.add_argument("--tid", dest="target_dict", help="location of target_dictionary from chembl_id", default=None) 68 | args = parser.parse_args() 69 | 70 | 
main(args) 71 | -------------------------------------------------------------------------------- /scripts/generating_plots.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import argparse 5 | 6 | def molecular_properties_hist(args): 7 | labels = {'mol_wt':'Molecular Weights', 'h_donors':'Hydrogen Bond Donors', 'h_acceptors':'Hydrogen Bond Acceptors', 'mol_logp':'Octanol-Water Partition Coefficients'} 8 | generated_df = pd.read_csv(args.infile1) 9 | generated_df['molwt_diff'] = (generated_df['mol_wt_structure'] - generated_df['mol_wt_id']) 10 | generated_df['logp_diff'] = (generated_df['mol_logp_structure'] - generated_df['mol_logp_id']) 11 | all_df = pd.read_csv(args.infile2) 12 | 13 | fig, ((ax1, ax2,ax5), (ax3, ax4,ax6)) = plt.subplots(2, 3, figsize=(16, 8)) 14 | x='mol_wt' 15 | n, bins, patches = ax2.hist([generated_df[f"{x}_structure"],generated_df[f"{x}_id"]], bins = 30, histtype='step', density=True) 16 | ax2.set_xlabel(labels[x]) 17 | x='h_donors' 18 | n, bins, patches = ax1.hist([generated_df[f"{x}_structure"],generated_df[f"{x}_id"]], bins = 14, histtype='step', density=True) 19 | ax1.set_xlabel(labels[x]) 20 | x='h_acceptors' 21 | n, bins, patches = ax3.hist([generated_df[f"{x}_structure"],generated_df[f"{x}_id"]], bins = 20, histtype='step', density=True) 22 | ax3.set_xlabel(labels[x]) 23 | ax3.set_xticks([2,4,6,8,10,12,14,16]) 24 | x='mol_logp' 25 | n, bins, patches = ax4.hist([generated_df[f"{x}_structure"],generated_df[f"{x}_id"]], bins = 30, histtype='step', density=True) 26 | ax4.set_xlabel(labels[x]) 27 | 28 | n, bins, patches = ax5.hist(generated_df['molwt_diff'], bins = 30, histtype='step') 29 | ax5.set_xlabel('Molecular Weight Differences') 30 | 31 | n, bins, patches = ax6.hist(generated_df['logp_diff'], bins = 30, histtype='step') 32 | ax6.set_xlabel('Octanol-Water Partition Coefficient Differences') 33 | 34 | fig.suptitle(f'Molecular Properties Distributions Comparison Generated vs. Input',fontsize=15) 35 | ax1.legend(labels=['Generated Molecules', 'Input Molecules']) 36 | plt.savefig(f"{args.outdest}/generated_input_plot.png") 37 | 38 | plt.clf() 39 | 40 | fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8)) 41 | x='mol_wt' 42 | n, bins, patches = ax2.hist([generated_df[f"{x}_structure"],all_df[f"{x}_{args.smi_col}"]], bins = 30, histtype='step', density=True) 43 | ax2.set_xlabel(labels[x]) 44 | x='h_donors' 45 | n, bins, patches = ax1.hist([generated_df[f"{x}_structure"],all_df[f"{x}_{args.smi_col}"]], bins = 15, histtype='step', density=True) 46 | ax1.set_xlabel(labels[x]) 47 | ax1.set_xticks([2,4,6,8,10,12,14]) 48 | x='h_acceptors' 49 | n, bins, patches = ax3.hist([generated_df[f"{x}_structure"],all_df[f"{x}_{args.smi_col}"]], bins = 30, histtype='step', density=True) 50 | ax3.set_xlabel(labels[x]) 51 | x='mol_logp' 52 | n, bins, patches = ax3.hist([generated_df[f"{x}_structure"],all_df[f"{x}_{args.smi_col}"]], bins = 30, histtype='step', density=True) 53 | ax3.set_xlabel(labels[x]) 54 | 55 | fig.suptitle(f'Molecular Properties Distributions Comparison Generated vs. 
CHEMBL',fontsize=15) 56 | ax1.legend(labels=['Generated Molecules', 'CHEMBL Molecules']) 57 | plt.savefig(f"{args.outdest}/generated_chembl_plot.png") 58 | 59 | 60 | def scores_lineplot(args): 61 | models_df = pd.read_csv(args.metrics_table) 62 | models_df = models_df.set_index('model') 63 | models_df.sort_values(by=['model_name','epoch'],inplace=True) 64 | 65 | plot_save_dict = {'Count':'count', 'Scaffold Change Count':'scaffold_change_count', 'R-Group Change Count':'r_group_change_count', 'Unique Scaffolds':'unique_scaffolds', 'New Scaffolds':'new_scaffolds'} 66 | 67 | models = list(set(models_df[models_df.model_name.notna()]['model_name'])) 68 | models.sort() 69 | 70 | fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, sharex='col', figsize=(18, 8)) 71 | y = 'Count' 72 | for x in [n for n in models if (type(n)==str) and (args.subset in n)]: 73 | ax3.plot(models_df[(models_df.model_name == x)&(models_df[y] != 0)]['epoch'], models_df[(models_df.model_name == x)&(models_df[y] != 0)][y], label = 'pre-triage') 74 | ax3.set_ylabel(y) 75 | ax6.set_xlabel('Epoch Number') 76 | ax3.tick_params(axis='y', which='major', labelsize=8) 77 | 78 | y = 'Scaffold Change Count' 79 | for x in [n for n in models if (type(n)==str) and (args.subset in n)]: 80 | ax2.plot(models_df[(models_df.model_name == x)&(models_df[y] != 0)]['epoch'], models_df[(models_df.model_name == x)&(models_df[y] != 0)][y], label = 'pre-triage') 81 | ax2.set_ylabel(y) 82 | ax2.tick_params(axis='y', which='major', labelsize=8) 83 | 84 | y = 'R-Group Change Count' 85 | for x in [n for n in models if (type(n)==str) and (args.subset in n)]: 86 | ax5.plot(models_df[(models_df.model_name == x)&(models_df[y] != 0)]['epoch'], models_df[(models_df.model_name == x)&(models_df[y] != 0)][y], label = 'pre-triage') 87 | ax5.set_ylabel(y) 88 | ax5.set_xlabel('Epoch Number') 89 | ax5.tick_params(axis='y', which='major', labelsize=8) 90 | 91 | y = 'Unique Scaffolds' 92 | for x in [n for n in models if (type(n)==str) and (args.subset in n)]: 93 | ax4.plot(models_df[(models_df.model_name == x)&(models_df[y] != 0)]['epoch'], models_df[(models_df.model_name == x)&(models_df[y] != 0)][y], label = 'pre-triage') 94 | ax4.set_ylabel(y) 95 | ax4.set_xlabel('Epoch Number') 96 | ax4.tick_params(axis='y', which='major', labelsize=8) 97 | 98 | y = 'New Scaffolds' 99 | for x in [n for n in models if (type(n)==str) and (args.subset in n)]: 100 | ax1.plot(models_df[(models_df.model_name == x)&(models_df[y] != 0)]['epoch'], models_df[(models_df.model_name == x)&(models_df[y] != 0)][y], label = 'pre-triage') 101 | ax1.set_ylabel(y) 102 | ax1.tick_params(axis='y', which='major', labelsize=8) 103 | 104 | plt.setp(ax1.get_xticklabels(), visible=False) 105 | plt.subplots_adjust(hspace=.0) 106 | # remove last tick label for the second subplot 107 | #yticks = ax2.yaxis.get_major_ticks() 108 | #yticks[-1].label1.set_visible(False) 109 | fig.suptitle(f"Chemical scores for {args.subset}", fontsize=16) 110 | plt.savefig(f"{args.outdest}/{args.subset}_chemical_scores.png") 111 | 112 | def main(): 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument("--metrics_table", dest="metrics_table", help="File with metrics", metavar="in.csv") 115 | parser.add_argument("--in1", dest='infile1', help="csv file containing molecular properties of generated and input molecules", metavar="in1.csv") 116 | parser.add_argument("--in2", dest='infile2', help="csv file containing molecular properties of training molecules", metavar="in2.csv") 117 | parser.add_argument("--out", 
dest="outdest", help="output destination", metavar="input.txt") 118 | parser.add_argument("--subset", dest="subset", help="model name containing subset string will be used") 119 | parser.add_argument("--type", dest="plot_type", help="type of plot to generate") 120 | parser.add_argument("--smi_col", dest="smi_col", help="name of smiles column in infile2") 121 | args = parser.parse_args() 122 | 123 | if args.plot_type == 'scores': 124 | scores_lineplot(args) 125 | elif args.plot_type == 'molecular_properties': 126 | molecular_properties_hist(args) 127 | 128 | if __name__=='__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /scripts/mmpdb_prep.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import argparse 3 | 4 | def main(): 5 | parser = argparse.ArgumentParser(description="Prepare input for mmpdb.") 6 | parser.add_argument("--in", dest="infile", help="all molecules", metavar="in.csv") 7 | parser.add_argument("--out", dest="outfile", metavar="out.csv") 8 | parser.add_argument("--smiles", dest="smiles_col", help="name of column with smiles representations") 9 | parser.add_argument("--ids", dest="ids_col", help="name of column with ids") 10 | args = parser.parse_args() 11 | 12 | pd.read_csv(args.infile, usecols=[args.smiles_col, args.ids_col])[[args.smiles_col, args.ids_col]].to_csv(args.outfile, index=False) 13 | 14 | if __name__=='__main__': 15 | main() 16 | -------------------------------------------------------------------------------- /scripts/molecular_properties.py: -------------------------------------------------------------------------------- 1 | from rdkit.Chem import Crippen 2 | from rdkit.Chem import Lipinski 3 | from rdkit.Chem import Descriptors 4 | from rdkit import Chem 5 | import pandas as pd 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | 10 | def lipinski(smiles): 11 | mol_wt = [] 12 | hdonors = [] 13 | haccept = [] 14 | logp = [] 15 | for smi in smiles: 16 | mol = Chem.MolFromSmiles(smi) 17 | if mol is None: 18 | mol_wt.append(pd.NA) 19 | hdonors.append(pd.NA) 20 | haccept.append(pd.NA) 21 | logp.append(pd.NA) 22 | else: 23 | hdonors.append(Lipinski.NumHDonors(mol)) 24 | haccept.append(Lipinski.NumHAcceptors(mol)) 25 | mol_wt.append(Descriptors.MolWt(mol)) 26 | logp.append(Crippen.MolLogP(mol)) 27 | 28 | return mol_wt, hdonors,haccept,logp 29 | 30 | 31 | def main(): 32 | parser = argparse.ArgumentParser(description="Get molecular properties of input molecules") 33 | parser.add_argument("--in", dest="infile", help="csv inputfile with smiles column", metavar="input.csv") 34 | parser.add_argument("--out", dest="outfile", help="outfile containing molecular properties", metavar="output.csv") 35 | parser.add_argument("--smi_col", dest="smi_col", help="name of smiles column") 36 | args = parser.parse_args() 37 | 38 | df = pd.read_csv(args.infile) 39 | mol_wt, hdonors, haccept, logp = lipinski(list(df[args.smi_col])) 40 | 41 | df[f'mol_wt_{args.smi_col}'] = mol_wt 42 | df[f'h_donors_{args.smi_col}'] = hdonors 43 | df[f'h_acceptors_{args.smi_col}'] = haccept 44 | df[f'mol_logp_{args.smi_col}'] = logp 45 | 46 | df.to_csv(args.outfile, index=False) 47 | 48 | if __name__=='__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /scripts/new_smirks.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | 
import numpy as np 4 | import argparse 5 | import subprocess 6 | from openeye import oechem, oedepict 7 | 8 | def make_smi_df(infile, outfile): 9 | inf = pd.read_csv(infile) 10 | otf = inf['id'].append(inf['structure'], ignore_index=True) 11 | otf = pd.DataFrame({'smiles':otf, 'id':range(otf.shape[0])}) 12 | otf.to_csv(outfile, index=False, header=False) 13 | 14 | def get_smirks(infile1, infile2, outfile): 15 | all_smirks = pd.read_csv(infile1) 16 | all_smirks.columns = ['id', 'structure', 'ch1', 'ch2', 'smirk', 'other'] 17 | all_smirks.drop(columns=['ch1','ch2','other'], inplace=True) 18 | generated_pairs = pd.read_csv(infile2) 19 | generated_pairs = generated_pairs.merge(all_smirks, how='left', left_on=['id','structure'], right_on=['id','structure']) 20 | generated_pairs.to_csv(outfile, index=False) 21 | 22 | 23 | def counting_smirks(infile, training): 24 | training_smirks = set(n for n in list(pd.read_csv(training)['smirks'])) 25 | smirks_dict = dict() 26 | 27 | for n in list(pd.read_csv(infile)['smirk']): 28 | if n not in training_smirks and n != np.nan: 29 | if n not in smirks_dict: 30 | smirks_dict[n] = 1 31 | else: 32 | smirks_dict[n] += 1 33 | return len(smirks_dict), sum(list(smirks_dict.values())), smirks_dict 34 | 35 | def generate_png(args, df): 36 | for n in range(df.shape[0]): 37 | if df.loc[n,'counted']>=int(args.threshold) and type(df.loc[n,'smirks'])==str: 38 | smiles1, smiles2 = df.loc[n,'smirks'].split('>>') 39 | 40 | mol = oechem.OEGraphMol() 41 | oechem.OESmilesToMol(mol, smiles1) 42 | simple_png1 = f"{args.png_dest}/smirk1_{n}.png" 43 | oedepict.OEPrepareDepiction(mol) 44 | width, height = 600,600 45 | opts = oedepict.OE2DMolDisplayOptions(width,height,oedepict.OEScale_Default*10) 46 | opts.SetTitleLocation(oedepict.OETitleLocation_Hidden) 47 | disp = oedepict.OE2DMolDisplay(mol, opts) 48 | oedepict.OERenderMolecule(simple_png1, disp) 49 | 50 | mol = oechem.OEGraphMol() 51 | oechem.OESmilesToMol(mol, smiles2) 52 | simple_png2 = f"{args.png_dest}/smirk2_{n}.png" 53 | oedepict.OEPrepareDepiction(mol) 54 | width, height = 600,600 55 | if smiles2 == '[*:1][H]': 56 | opts = oedepict.OE2DMolDisplayOptions(width,height,oedepict.OEScale_Default*7) 57 | else: 58 | opts = oedepict.OE2DMolDisplayOptions(width,height,oedepict.OEScale_Default*10) 59 | opts.SetTitleLocation(oedepict.OETitleLocation_Hidden) 60 | disp = oedepict.OE2DMolDisplay(mol, opts) 61 | oedepict.OERenderMolecule(simple_png2, disp) 62 | 63 | def main(): 64 | parser = argparse.ArgumentParser(description="Get smirks of model predictions") 65 | parser.add_argument("--in", dest="infile", help="csv file with input dependent on which function {make_smi_df, get_smirks, new_smirks} is called", metavar="input.txt") 66 | parser.add_argument("--out", dest="outfile", help="output file", metavar="output.csv") 67 | parser.add_argument("--in2", dest="infile2", help="second csv infile dependent on which function {get_smirks, new_smirks} is called", metavar="infile2.csv") 68 | parser.add_argument("--training", dest="training", help="original training file", metavar="training.csv") 69 | parser.add_argument("--make_smi_df", dest="make_smi_df", help="run function to generate all smiles csv", action='store_true') 70 | parser.add_argument("--get_smirks", dest="get_smirks", help="run function to get smirks of generated pairs", action='store_true') 71 | parser.add_argument("--new_smirks", dest="new_smirks", help="run function to count new smirks", action='store_true') 72 | parser.add_argument("--threshold", dest="threshold", 
help="smirk count threshold to generate pngs", default=1000) 73 | parser.add_argument("--png_dest", dest="png_dest", help="directory to save smirk pngs") 74 | args = parser.parse_args() 75 | 76 | if args.make_smi_df: 77 | make_smi_df(args.infile, args.outfile) 78 | elif args.get_smirks: 79 | get_smirks(args.infile, args.infile2, args.outfile) 80 | elif args.new_smirks: 81 | n_unique_smirks, n_all_smirks, smirk_dict = counting_smirks(args.infile, args.infile2) 82 | count_df = pd.DataFrame({'smirks':list(smirk_dict.keys()), 'counted': list(smirk_dict.values())}) 83 | count_df.sort_values(by='counted', inplace=True, ascending=False) 84 | count_df.to_csv(args.outfile, index=False) 85 | generate_png(args, count_df) 86 | 87 | if __name__=='__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /scripts/parsing_pairs.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import sys 4 | import argparse 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser(description="Parse mmpdb index output csv.") 8 | parser.add_argument("--in", dest="infile", help="whitespace-delimited input file from mmpdb index", metavar="in.csv") 9 | parser.add_argument("--out", dest="outfile", help="output file", metavar="out.csv") 10 | args = parser.parse_args() 11 | pairs_found = set() 12 | 13 | with open(args.infile, 'r') as ifs: 14 | with open(args.outfile, 'w') as ofs: 15 | for line in ifs: 16 | parsed = line.strip().split('\t') 17 | if (parsed[2],parsed[3]) not in pairs_found and (parsed[3],parsed[2]) not in pairs_found: 18 | pairs_found.add((parsed[2],parsed[3])) 19 | ofs.write(f"{','.join(parsed)}\n") 20 | -------------------------------------------------------------------------------- /scripts/renaming_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | def main(): 5 | parser = argparse.ArgumentParser(description="Rename .pt files for trained models: replace numbers of steps by numbers of epochs.") 6 | parser.add_argument("--models", dest="models", help="model directory") 7 | parser.add_argument("--batch_size", dest="batch_size", help="batch size") 8 | parser.add_argument("--train_size", dest="train_size", help="training data size") 9 | parser.add_argument("--startswith", dest="startswith", help="only .pt files with names starting with this string will be renamed", default="") 10 | args = parser.parse_args() 11 | 12 | epoch_steps = round(int(args.train_size)/int(args.batch_size)) 13 | 14 | for model in os.listdir(args.models): 15 | if model.startswith(args.startswith): 16 | if '_step_' in model: 17 | prefix, suffix = model.split('_step_') 18 | step_num = int(suffix[:-3]) 19 | epoch_num = round(step_num/epoch_steps) 20 | os.rename(f"{args.models}/{model}", f"{args.models}/{prefix}_epoch_{epoch_num}.pt") 21 | 22 | if __name__=='__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /scripts/scaffolding.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import argparse 3 | from rdkit.Chem.Scaffolds import MurckoScaffold 4 | from rdkit.Chem import * 5 | 6 | def get_murcko_scaffold(smi): 7 | mol = MolFromSmiles(smi) 8 | if not mol: 9 | return "" 10 | return MolToSmiles(MurckoScaffold.GetScaffoldForMol(mol)) 11 | 12 | def get_canonical_smiles(x): 13 | mol = oechem.OEGraphMol() 
14 | oechem.OESmilesToMol(mol, x) 15 | return oechem.OECreateCanSmiString(mol) 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser(description="Make scaffold of input molecules") 19 | parser.add_argument("--in", dest="infile", help="comma-delimited input file of original and predicted smiles", metavar="input.txt") 20 | parser.add_argument("--out", dest="outfile", help="output file", metavar="output.csv") 21 | parser.add_argument("--column1", dest="column1", help="Name of column with smiles") 22 | parser.add_argument("--column2", dest="column2", help="Name of column with smiles", default=False) 23 | args = parser.parse_args() 24 | 25 | df = pd.read_csv(args.infile) 26 | df = df[df[args.column1].apply(pd.isna)==False] 27 | if args.column2: 28 | df = df[df[args.column2].apply(pd.isna)==False] 29 | df[args.column2+'_canonical'] = df[args.column2].apply(get_canonical_smiles) 30 | df[args.column2+'_scaffold'] = df[args.column2+'_canonical'].apply(get_murcko_scaffold) 31 | df[args.column1+'_canonical'] = df[args.column1].apply(get_canonical_smiles) 32 | df[args.column1+'_scaffold'] = df[args.column1+'_canonical'].apply(get_murcko_scaffold) 33 | df.to_csv(args.outfile, index=False) 34 | 35 | if __name__=='__main__': 36 | main() 37 | -------------------------------------------------------------------------------- /scripts/scoring.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import argparse 4 | from os import path 5 | 6 | def change_count(args): 7 | inf = pd.read_csv(args.infile) 8 | inf['change_type'] = ['same' if inf.loc[n]['id_canonical'] == inf.loc[n]['structure_canonical'] else 'r-group' if inf.loc[n]['id_scaffold'] == inf.loc[n]['structure_scaffold'] else 'scaffold' for n in range(inf.shape[0])] 9 | metrics = pd.read_csv(args.metrics) 10 | metrics = metrics.set_index('model') 11 | 12 | metrics.loc[args.model,'Count'] = inf.shape[0] 13 | metrics.loc[args.model,'Scaffold Change Count'] = inf[inf['change_type'] == 'scaffold'].shape[0] 14 | metrics.loc[args.model,'R-Group Change Count'] = inf[inf['change_type'] == 'r-group'].shape[0] 15 | metrics.to_csv(args.metrics) 16 | 17 | def scaffolds(args): 18 | inf = pd.read_csv(args.infile) 19 | train = pd.read_csv(args.training_data, header=None) 20 | metrics = pd.read_csv(args.metrics) 21 | metrics = metrics.set_index('model') 22 | 23 | metrics.loc[args.model,'Unique Scaffolds'] = len(set(inf['structure_scaffold'])) 24 | metrics.loc[args.model,'Validation Scaffolds'] = len(set(inf['id_scaffold'])) 25 | metrics.loc[args.model,'New Scaffolds'] = len(set(inf['structure_scaffold']).difference(*[set(inf['id_scaffold']), set(train[6]), set(train[7])])) 26 | metrics.loc[args.model,'Training Scaffolds'] = len(set(train[6]).union(set(train[7]))) 27 | metrics.to_csv(args.metrics) 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser(description="Calculate chemical scores of generated molecules from a model") 31 | parser.add_argument("--in", dest="infile", help="csv input file of original and predicted smiles, and scaffolds", metavar="input.txt") 32 | parser.add_argument("--metrics_table", dest="metrics", help="File with metrics so far", metavar="metrics.csv") 33 | parser.add_argument("--training_data", dest="training_data", help="Training dataset file", metavar="train.csv") 34 | parser.add_argument("--model", dest="model", help="name of model") 35 | parser.add_argument("--change_count", dest="change_count", action='store_true') 36 | 
parser.add_argument("--scaffolds", dest="scaffolds", action='store_true') 37 | args = parser.parse_args() 38 | 39 | if not path.isfile(args.metrics): 40 | metrics_df = pd.DataFrame(columns=["model","Count","Scaffold Change Count","R-Group Change Count","Unique Scaffolds","New Scaffolds","Validation Scaffolds","Training Scaffolds"]) 41 | metrics_df = metrics_df.set_index('model') 42 | metrics_df.to_csv(args.metrics) 43 | 44 | metrics_df = pd.read_csv(args.metrics) 45 | metrics_df = metrics_df.set_index('model') 46 | if args.model not in metrics_df.index.values.tolist(): 47 | metrics_df = metrics_df.reindex(metrics_df.index.values.tolist() + [args.model]) 48 | model_id, epoch = tuple(args.model.split('_epoch_')) 49 | metrics_df.loc[args.model,'model_name'] = model_id 50 | metrics_df.loc[args.model,'epoch'] = int(epoch) 51 | metrics_df.to_csv(args.metrics) 52 | del metrics_df 53 | 54 | if args.change_count: 55 | change_count(args) 56 | 57 | if args.scaffolds: 58 | scaffolds(args) 59 | 60 | if __name__=='__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /scripts/selfies_to_smiles.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import argparse 3 | import selfies 4 | 5 | def selfies_to_smiles(x_selfies): 6 | x_selfies = x_selfies.replace(" ", "") 7 | return selfies.decoder(x_selfies) 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description="Parse mmpdb index output csv.") 11 | parser.add_argument("--in1", dest="infile1", help="txt file with generated selfies", metavar="input1.txt") 12 | parser.add_argument("--in2", dest="infile2", help="txt file with original selfies", metavar="input2.txt") 13 | parser.add_argument("--out", dest="outfile", help="output file", metavar="output.csv") 14 | args = parser.parse_args() 15 | print('args') 16 | 17 | selfs = [line.strip() for line in open(args.infile1, 'r').readlines()] 18 | print('selfies') 19 | original_selfies = [line.strip() for line in open(args.infile2, 'r'). 
readlines()] 20 | print('original selfies read') 21 | smi_preds = list(map(selfies_to_smiles, selfs)) 22 | print('generated smiles decoded') 23 | orig_smiles = list(map(selfies_to_smiles, original_selfies)) 24 | print('original smiles decoded') 25 | 26 | # column names chosen so that the output can be fed directly into triage.py 27 | pd.DataFrame({'id':orig_smiles, 'structure':smi_preds}).to_csv(args.outfile, index=False) 28 | -------------------------------------------------------------------------------- /scripts/target_specific.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import argparse 4 | import sys 5 | from operator import itemgetter 6 | from rdkit import Chem, DataStructs 7 | from rdkit.Chem import AllChem, PandasTools 8 | from rdkit.Chem.AtomPairs import Pairs 9 | from openeye import oechem, oedepict  # OpenEye toolkits are used for the depictions in generate_png 10 | def temporal_split(df, outdest, target): 11 | df.sort_values(by='year', inplace=True) 12 | temporal_df = df[df['year'].notna()] 13 | top_50 = int(round(temporal_df.shape[0]/2))  # first half of the molecules by year 14 | temporal_df.head(top_50).to_csv(f"{outdest}/early_years_{target}.csv") 15 | temporal_df.tail(temporal_df.shape[0]-top_50).to_csv(f"{outdest}/later_years_{target}.csv") 16 | 17 | def activity_split(df, outdest, target): 18 | df.sort_values(by=args.activity_col, inplace=True) 19 | activity_df = df[df[args.activity_col].notna()] 20 | top_50 = int(round(activity_df.shape[0]*0.05))  # 5% split point 21 | activity_df.head(top_50).to_csv(f"{outdest}/most_active_{target}.csv") 22 | activity_df.tail(activity_df.shape[0]-top_50).to_csv(f"{outdest}/least_active_{target}.csv") 23 | 24 | def fp_2D(smi, fptype='ECFP4'): 25 | m1 = Chem.MolFromSmiles(smi) 26 | if not m1: 27 | print(f"RDKit was unable to parse {smi}.", file=sys.stderr) 28 | return None 29 | 30 | if fptype=='ECFP4': 31 | return AllChem.GetMorganFingerprint(m1, 2) 32 | elif fptype=='ECFP6': 33 | return AllChem.GetMorganFingerprint(m1, 3) 34 | elif fptype=='AP': 35 | return Pairs.GetAtomPairFingerprint(m1) 36 | else: 37 | raise NotImplementedError(f"Fingerprint type {fptype} is not implemented.") 38 | 39 | def tanimoto_similarity(fp1, fp2): 40 | if (fp1 is not None) and (fp2 is not None): 41 | result = DataStructs.TanimotoSimilarity(fp1, fp2) 42 | else: 43 | result = 0.0 44 | return result 45 | 46 | def tanimoto_similarity_from_smiles(smi1, smi2, fptype='ECFP4'): 47 | fp1 = fp_2D(smi1, fptype) 48 | fp2 = fp_2D(smi2, fptype) 49 | return tanimoto_similarity(fp1, fp2) 50 | 51 | 52 | def calculate_tanimoto_similarity(actual, pred, dic): 53 | top_scores = [] 54 | for sm1 in actual: 55 | top = list(max([(sm1, sm2, tanimoto_similarity_from_smiles(sm1, sm2)) for sm2 in pred if type(sm2)==str], key=itemgetter(2)))  # most similar prediction for each actual molecule 56 | top_scores.append(top+[dic[top[1]]]) 57 | return top_scores 58 | 59 | def generate_png(df): 60 | for n in range(df.shape[0]): 61 | actual, pred, tani, inpt = df.loc[n] 62 | 63 | mol = oechem.OEGraphMol() 64 | oechem.OESmilesToMol(mol, actual) 65 | simple_png1 = f"{args.png_dest}/experimental_{n}.png" 66 | oedepict.OEPrepareDepiction(mol) 67 | width, height = 600,600 68 | opts = oedepict.OE2DMolDisplayOptions(width,height,oedepict.OEScale_Default*10) 69 | opts.SetTitleLocation(oedepict.OETitleLocation_Hidden) 70 | disp = oedepict.OE2DMolDisplay(mol, opts) 71 | oedepict.OERenderMolecule(simple_png1, disp) 72 | 73 | mol = oechem.OEGraphMol() 74 | oechem.OESmilesToMol(mol, pred) 75 | simple_png2 = f"{args.png_dest}/generated_{n}.png" 76 | oedepict.OEPrepareDepiction(mol) 77 | width, height = 600,600 78 | opts = 
oedepict.OE2DMolDisplayOptions(width,height,oedepict.OEScale_Default*10) 79 | opts.SetTitleLocation(oedepict.OETitleLocation_Hidden) 80 | disp = oedepict.OE2DMolDisplay(mol, opts) 81 | oedepict.OERenderMolecule(simple_png2, disp) 82 | 83 | mol = oechem.OEGraphMol() 84 | oechem.OESmilesToMol(mol, inpt) 85 | simple_png3 = f"{args.png_dest}/input_{n}.png" 86 | oedepict.OEPrepareDepiction(mol) 87 | width, height = 600,600 88 | opts = oedepict.OE2DMolDisplayOptions(width,height,oedepict.OEScale_Default*10) 89 | opts.SetTitleLocation(oedepict.OETitleLocation_Hidden) 90 | disp = oedepict.OE2DMolDisplay(mol, opts) 91 | oedepict.OERenderMolecule(simple_png3, disp) 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser(description="Python scripts part of target-specific data pipeline") 95 | parser.add_argument("--in1", dest="infile1", help="actual target molecules", metavar="in1.csv") 96 | parser.add_argument("--in2", dest="infile2", help="pred target molecules", metavar="in2.csv") 97 | parser.add_argument("--smi_col", dest="smi_col", help="name of column with smiles in actual target molecules") 98 | parser.add_argument("--out", dest="outfile", metavar="out.csv") 99 | parser.add_argument("--outdest", dest="outdest") 100 | parser.add_argument("--png_dest", dest="png_dest") 101 | parser.add_argument("--target", dest="target", help="chembl_id of target to exclude", default=None) 102 | parser.add_argument("--target_name", dest="target_name", help="name of target", default=None) 103 | parser.add_argument("--tid", dest="target_dict", help="location of target_dictionary from chembl_id", default=None) 104 | parser.add_argument("--generate_png", dest="generate_png", action='store_true') 105 | parser.add_argument("--data_split", dest="data_split", action='store_true') 106 | parser.add_argument("--in", dest="infile", help="comma-delimited data file to filter by target (tid in the fifth column)", metavar="in.csv") 107 | parser.add_argument("--alldata", dest="alldata", help="csv with chembl_id and tid columns", default=None) 108 | parser.add_argument("--activity_col", dest="activity_col", help="name of column with activity values used for the activity split", default=None) 109 | args = parser.parse_args() 110 | 111 | if args.data_split: 112 | tid_df = pd.read_csv(args.target_dict, usecols=['tid', 'chembl_id']) 113 | tid_dict = dict(zip(list(tid_df['chembl_id']), list(tid_df['tid']))) 114 | target_id = tid_dict[args.target] 115 | del tid_df 116 | all_df = pd.read_csv(args.alldata, usecols=['chembl_id', 'tid']) 117 | chembl_dict = dict(zip(list(all_df['chembl_id']), list(all_df['tid'])))  # note: currently unused downstream 118 | del all_df 119 | 120 | 121 | with open(args.infile, 'r') as ifs: 122 | with open(args.outfile, 'w') as ofs: 123 | for line in ifs: 124 | parsed = line.strip().split(',') 125 | if parsed[4] == str(target_id):  # csv fields are read as strings 126 | ofs.write(line) 127 | 128 | data_df = pd.read_csv(args.outfile) 129 | temporal_split(data_df, args.outdest, args.target_name) 130 | activity_split(data_df, args.outdest, args.target_name) 131 | 132 | if args.generate_png: 133 | pred = pd.read_csv(args.infile2) 134 | in_out_dict = dict(zip(list(pred['structure']),list(pred['id']))) 135 | actual = pd.read_csv(args.infile1, usecols=[args.smi_col]) 136 | tnm_list = calculate_tanimoto_similarity(list(actual[args.smi_col]), list(pred['structure']), in_out_dict) 137 | tnm_df = pd.DataFrame(tnm_list, columns=['Actual', 'Prediction', 'Tanimoto', 'Input']) 138 | tnm_df.to_csv(args.outfile, index=False) 139 | generate_png(tnm_df) 140 | -------------------------------------------------------------------------------- /scripts/training_curves.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import argparse 5 | import math 6 | 7 | def parsed_info(args): 8 | with open(args.infile, 'r') as ifs: 9 | with open(args.outfile, 'w') as ofs: 10 | 
ofs.write(f'Job ID,Train Year,Test Year,Batch Size,Filter\n') 11 | for line in ifs: 12 | parsed = line.strip().split('ID ')[1] 13 | job_id = int(parsed.split(' ')[0]) 14 | parsed = parsed.split(' ') 15 | train_year = int(parsed[1].split('=')[1]) 16 | test_year = int(parsed[2].split('=')[1]) 17 | batch = int(parsed[3].split('=')[1]) 18 | filtered = int(parsed[4].split('=')[1]) 19 | ofs.write(f"{job_id},{train_year},{test_year},{batch},{filtered}\n") 20 | 21 | def parsed_val(args): 22 | col_dict = {'validation_accuracy':'Validation accuracy','validation_perplexity':'Validation perplexity'} 23 | 24 | with open(args.infile, 'r') as ifs: 25 | with open(args.outfile, 'w') as ofs: 26 | ofs.write(f'Time,Epoch,{col_dict[args.name]}\n') 27 | n=1 28 | for line in ifs: 29 | parsed = line.strip().split(': ') 30 | ofs.write(f"{parsed[0]},{n},{parsed[1]}\n") 31 | n+=1 32 | 33 | def parsed_training(args): 34 | with open(args.infile, 'r') as ifs: 35 | with open(args.outfile,'w') as ofs: 36 | ofs.write(f'Time,Step,Accuracy,Perplexity\n') 37 | n=0 38 | for line in ifs: 39 | if n==0: 40 | epoch_steps = int(line.strip().split(' every ')[1].split(' steps')[0]) 41 | n=1 42 | elif n==1: 43 | parsed = line.strip().split('; ') 44 | time = parsed[0].split('] ')[0]+']' 45 | try: 46 | step = parsed[0].split(' Step ')[1].split('/')[0] 47 | except IndexError: 48 | print(line) 49 | #print(parsed[0].split(' Step ')[1]) 50 | acc = float(parsed[1][5:]) 51 | ppl = float(parsed[2][5:]) 52 | ofs.write(f"{time},{step},{acc},{ppl}\n") 53 | scores_df = pd.read_csv(args.outfile) 54 | scores_df['Epoch'] = scores_df['Step'].apply(lambda x: math.ceil(x/epoch_steps)) 55 | epoch_scores = scores_df.groupby('Epoch')['Accuracy','Perplexity'].mean() 56 | epoch_scores.to_csv(args.outfile) 57 | return epoch_steps 58 | 59 | def plots(args, triaged=False): 60 | info = pd.read_csv(args.info) 61 | val_acc_df = pd.read_csv(args.val_acc) 62 | val_ppl_df = pd.read_csv(args.val_ppl) 63 | val_acc_df.reset_index(inplace=True) 64 | val_ppl_df.reset_index(inplace=True) 65 | train_df = pd.read_csv(args.train) 66 | models_df = pd.read_csv(args.metrics) 67 | subset = f"split{info.loc[0,'Train Year']}_{info.loc[0,'Test Year']}_filter{info.loc[0,'Filter']}" 68 | models = list(set(models_df[models_df.model_name.notna()]['model_name'])) 69 | models.sort() 70 | if triaged: 71 | triaged_df = pd.read_csv(args.triaged) 72 | triaged_models = list(set(triaged_df[triaged_df.model_name.notna()]['model_name'])) 73 | triaged_models.sort() 74 | 75 | fig, ((ax2), (ax3), (ax4), (ax5), (ax6)) = plt.subplots(5, 1, sharex='col',figsize=(10, 16)) 76 | plt.rcParams.update({'legend.fontsize': 'x-large'}) 77 | ax2.tick_params(axis='y', which='major', labelsize='x-large') 78 | ax3.tick_params(axis='y', which='major', labelsize='x-large') 79 | ax4.tick_params(axis='y', which='major', labelsize='x-large') 80 | ax5.tick_params(axis='y', which='major', labelsize='x-large') 81 | ax6.tick_params(which='major', labelsize='x-large') 82 | 83 | ax2.plot(train_df[train_df['Epoch']<=int(args.epoch_cutoff)]['Epoch'],train_df[train_df['Epoch']<=int(args.epoch_cutoff)]['Perplexity'], label='Perplexity (Train)', color='blue') 84 | ax2.plot(val_ppl_df[val_ppl_df['Epoch']<=int(args.epoch_cutoff)]['Epoch'],val_ppl_df[val_ppl_df['Epoch']<=int(args.epoch_cutoff)]['Validation perplexity'], label='Perplexity (Validation)', color='red') 85 | ax2.set_ylim([0, 3]) 86 | ax2.set_yticks([0,1,2,3]) 87 | 88 | ppl_diff = [val_ppl_df.loc[n,'Validation perplexity'] - train_df.loc[n,'Perplexity'] for n in 
range(val_ppl_df.shape[0])] 89 | ppl_change = [val_ppl_df.loc[n,'Validation perplexity'] - val_ppl_df.loc[n-1,'Validation perplexity'] if n>0 else 0 for n in range(val_ppl_df.shape[0])] 90 | zipped = zip(ppl_diff, val_ppl_df['Epoch'], ppl_change) 91 | zipped = [x for x in zipped if x[1]<=int(args.epoch_cutoff)] 92 | unzipped = res = [[ i for i, j, k in zipped],[ j for i, j, k in zipped], [ k for i, j, k in zipped]] 93 | ax3.plot(unzipped[1],unzipped[0], label='Perplexity Validation - Train') 94 | ax3.plot(unzipped[1],unzipped[2], label='Delta Perplexity Validation') 95 | ax3.set_ylim([-1, 1]) 96 | ax3.set_yticks([-0.5, 0, 0.5]) 97 | 98 | y = 'Count' 99 | for x in [n for n in models if (type(n)==str) and (subset==n)]: 100 | mask = (models_df.model_name == x)&(models_df[y] != 0)&(models_df.epoch<=int(args.epoch_cutoff)) 101 | df_current = models_df[mask][['epoch',y]] 102 | df_current.sort_values(by='epoch', inplace=True) 103 | ax4.plot(df_current['epoch'], df_current[y], label = y, color='indigo') 104 | if triaged: 105 | for x in [n for n in triaged_models if (type(n)==str) and (subset==n)]: 106 | mask = (triaged_df.model_name == x)&(triaged_df[y] != 0)&(triaged_df.epoch<=int(args.epoch_cutoff)) 107 | df_current = triaged_df[mask][['epoch',y]] 108 | df_current.sort_values(by='epoch', inplace=True) 109 | ax4.plot(df_current['epoch'], df_current[y], label = f'Post-triage {y}', color='orchid') 110 | 111 | y = 'Scaffold Change Count' 112 | for x in [n for n in models if (type(n)==str) and (subset==n)]: 113 | mask = (models_df.model_name == x)&(models_df[y] != 0)&(models_df.epoch<=int(args.epoch_cutoff)) 114 | df_current = models_df[mask][['epoch',y]] 115 | df_current.sort_values(by='epoch', inplace=True) 116 | ax4.plot(df_current['epoch'], df_current[y], label = y, color='forestgreen') 117 | if triaged: 118 | for x in [n for n in triaged_models if (type(n)==str) and (subset==n)]: 119 | mask = (triaged_df.model_name == x)&(triaged_df[y] != 0)&(triaged_df.epoch<=int(args.epoch_cutoff)) 120 | df_current = triaged_df[mask][['epoch',y]] 121 | df_current.sort_values(by='epoch', inplace=True) 122 | ax4.plot(df_current['epoch'], df_current[y], label = f'Post-triage {y}', color='lightgreen') 123 | 124 | y = 'R-Group Change Count' 125 | for x in [n for n in models if (type(n)==str) and (subset==n)]: 126 | mask = (models_df.model_name == x)&(models_df[y] != 0)&(models_df.epoch<=int(args.epoch_cutoff)) 127 | df_current = models_df[mask][['epoch',y]] 128 | df_current.sort_values(by='epoch', inplace=True) 129 | ax5.plot(df_current['epoch'], df_current[y], label = y, color='black') 130 | if triaged: 131 | for x in [n for n in triaged_models if (type(n)==str) and (subset==n)]: 132 | mask = (triaged_df.model_name == x)&(triaged_df[y] != 0)&(triaged_df.epoch<=int(args.epoch_cutoff)) 133 | df_current = triaged_df[mask][['epoch',y]] 134 | df_current.sort_values(by='epoch', inplace=True) 135 | ax5.plot(df_current['epoch'], df_current[y], label = f'Post-triage {y}', color='gray') 136 | 137 | y = 'Unique Scaffolds' 138 | for x in [n for n in models if (type(n)==str) and (subset==n)]: 139 | mask = (models_df.model_name == x)&(models_df[y] != 0)&(models_df.epoch<=int(args.epoch_cutoff)) 140 | df_current = models_df[mask][['epoch',y]] 141 | df_current.sort_values(by='epoch', inplace=True) 142 | ax6.plot(df_current['epoch'], df_current[y], label = y, color='darkred') 143 | if triaged: 144 | for x in [n for n in triaged_models if (type(n)==str) and (subset==n)]: 145 | mask = (triaged_df.model_name == x)&(triaged_df[y] != 
0)&(triaged_df.epoch<=int(args.epoch_cutoff)) 146 | df_current = triaged_df[mask][['epoch',y]] 147 | df_current.sort_values(by='epoch', inplace=True) 148 | ax6.plot(df_current['epoch'], df_current[y], label = f'Post-triage {y}', color='lightcoral') 149 | 150 | y = 'New Scaffolds' 151 | for x in [n for n in models if (type(n)==str) and (subset==n)]: 152 | mask = (models_df.model_name == x)&(models_df[y] != 0)&(models_df.epoch<=int(args.epoch_cutoff)) 153 | df_current = models_df[mask][['epoch',y]] 154 | df_current.sort_values(by='epoch', inplace=True) 155 | ax6.plot(df_current['epoch'], df_current[y], label = y, color='magenta') 156 | if triaged: 157 | for x in [n for n in triaged_models if (type(n)==str) and (subset==n)]: 158 | mask = (triaged_df.model_name == x)&(triaged_df[y] != 0)&(triaged_df.epoch<=int(args.epoch_cutoff)) 159 | df_current = triaged_df[mask][['epoch',y]] 160 | df_current.sort_values(by='epoch', inplace=True) 161 | ax6.plot(df_current['epoch'], df_current[y], label = f'Post-triage {y}', color='violet') 162 | 163 | # put legends on subplots 164 | ax2.legend() 165 | ax3.legend() 166 | ax4.legend(loc='lower right') 167 | ax5.legend(loc='lower right') 168 | ax6.legend(loc='lower right') 169 | 170 | # remove vertical gap between subplots 171 | plt.subplots_adjust(hspace=.0) 172 | 173 | ax6.set_xlabel('Epoch Number', fontsize='x-large') 174 | suffix = '' if triaged else '_manuscript' 175 | plt.savefig(f"{args.outpng}/split{info.loc[0,'Train Year']}_{info.loc[0,'Test Year']}_filter{info.loc[0,'Filter']}_batch{info.loc[0,'Batch Size']}{suffix}.png") 176 | 177 | def main(): 178 | parser = argparse.ArgumentParser(description="Make training Curves for trained models") 179 | parser.add_argument("--in", dest="infile", help="err or out txt file from training", metavar="input.txt") 180 | parser.add_argument("--out", dest="outfile", help="output file", metavar="output.csv") 181 | parser.add_argument("--info", dest="info", help="CSV with run info", metavar='info.csv') 182 | parser.add_argument("--name", dest="name", help="Name training/validation metric") 183 | parser.add_argument("--outpng", dest="outpng", help="plot png destination") 184 | parser.add_argument("--parse_type", dest="parse_type") 185 | parser.add_argument("--val_acc", dest="val_acc", help="CSV with validation accuracy scores", metavar="val_acc.csv") 186 | parser.add_argument("--val_ppl", dest="val_ppl", help="CSV with validation perplexity scores", metavar="val_ppl.csv") 187 | parser.add_argument("--train", dest="train", help="CSV with training scores", metavar="train.csv") 188 | parser.add_argument("--metrics", dest="metrics", help="CSV with chemical scores", metavar="metrics.csv") 189 | parser.add_argument("--triaged", dest="triaged", help="CSV with chemical scores of triaged molecules", metavar="triaged.csv") 190 | parser.add_argument("--epoch_cutoff", dest="epoch_cutoff", help="epoch cutoff to graph up until", default=32) 191 | args = parser.parse_args() 192 | 193 | if args.parse_type == 'info': 194 | parsed_info(args) 195 | elif args.parse_type == 'val': 196 | parsed_val(args) 197 | elif args.parse_type == 'train': 198 | epoch_steps = parsed_training(args) 199 | elif args.parse_type == 'plot': 200 | plots(args) 201 | elif args.parse_type == 'plot_w_triaged': 202 | plots(args, triaged=True) 203 | 204 | 205 | if __name__=='__main__': 206 | main() 207 | -------------------------------------------------------------------------------- /slurm_scripts/dataset_split.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -o .../dataset_split.sh.out 3 | #SBATCH -e .../dataset_split.sh.err 4 | #SBATCH -n 1 5 | #SBATCH -t 10:00:00 6 | #SBATCH --mem-per-cpu=40G 7 | 8 | export PYTHONPATH=... 9 | export OE_LICENSE=.../oe_license.txt 10 | 11 | mkdir -p $RUN_DIR 12 | cd $RUN_DIR 13 | 14 | ## Dataset split 15 | echo "Starting dataset split ..." 16 | python ./scripts/dataset_split.py \ 17 | --in $DATA_DIR/${DATA_ID}_filtered.csv \ 18 | --timestamps "$TIMESTAMPS" \ 19 | --out $RUN_DIR \ 20 | --year_train $TRAIN_YEAR \ 21 | --year_test $TEST_YEAR \ 22 | --augment 23 | 24 | touch src-test.txt src-val.txt tgt-test.txt tgt-val.txt val.csv test.csv # if all data go to the training subset, this creates empty files for validation and testing 25 | echo "Finished dataset split" 26 | 27 | ## Generate a single file with all validation molecules, each occurring only once 28 | echo "Starting to generate a file with validation molecules ..." 29 | cat $RUN_DIR/src-val.txt $RUN_DIR/tgt-val.txt | sort | uniq > $RUN_DIR/val-unique.txt 30 | echo "Finished a file with validation molecules" 31 | -------------------------------------------------------------------------------- /slurm_scripts/model_scoring.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -o .../model_scoring.sh.out 3 | #SBATCH -e .../model_scoring.sh.err 4 | #SBATCH -n 1 5 | #SBATCH -t 4:00:00 6 | #SBATCH --mem-per-cpu=50G 7 | #SBATCH --gres=gpu:1 8 | 9 | #export PYTHONPATH=... 10 | export CUDA_VISIBLE_DEVICES=0 11 | 12 | mkdir -p $OUTPUT_DIR 13 | 14 | onmt_translate \ 15 | --model $MODEL_DIR/${MODEL_ID}_epoch_${EPOCH_NUM}.pt \ 16 | --src $SRC_DATA \ 17 | --tgt $TGT_DATA \ 18 | --output $OUTPUT_DIR/pred_selfies_epoch_${EPOCH_NUM}.txt \ 19 | --replace_unk \ 20 | --seed 1 \ 21 | --gpu 0 22 | -------------------------------------------------------------------------------- /slurm_scripts/molecular_properties_histograms.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -o .../plots.sh.out 3 | #SBATCH -e .../plots.sh.err 4 | #SBATCH -n 1 5 | #SBATCH -t 1:00:00 6 | #SBATCH --mem-per-cpu=50G 7 | #SBATCH --gres=gpu:1 8 | 9 | export PYTHONPATH=... 
10 | export CUDA_VISIBLE_DEVICES=0 11 | 12 | ## Get molecular properties of input molecules 13 | python ./scripts/molecular_properties.py \ 14 | --in $DATA_DIR/${DATA_ID}_mmpdb_input_nostereo.csv \ 15 | --out $DATA_DIR/${DATA_ID}_molecular_properties.csv \ 16 | --smi_col $SMI_COL 17 | 18 | python ./scripts/molecular_properties.py \ 19 | --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 20 | --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 21 | --smi_col 'structure' 22 | 23 | python ./scripts/molecular_properties.py \ 24 | --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 25 | --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 26 | --smi_col 'id' 27 | 28 | python ./scripts/generating_plots.py \ 29 | --in1 $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 30 | --in2 $DATA_DIR/${DATA_ID}_molecular_properties.csv \ 31 | --out $PLOT_DIR \ 32 | --type molecular_properties \ 33 | --smi_col $SMI_COL 34 | -------------------------------------------------------------------------------- /slurm_scripts/new_smirks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -o .../newsmirk_count.sh.out 3 | #SBATCH -e .../newsmirk_count.sh.err 4 | #SBATCH -n 1 5 | #SBATCH -t 1-00:00:00 6 | #SBATCH --mem-per-cpu=40G 7 | 8 | export PYTHONPATH=... 9 | export OE_LICENSE=.../oe_license.txt 10 | 11 | 12 | ## Get new smirks 13 | #shuf $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv | head -${SUBSET_SIZE} > $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_small.csv 14 | 15 | python ./scripts/new_smirks.py \ 16 | --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 17 | --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_mmpdb.csv \ 18 | --make_smi_df 19 | 20 | mmpdb fragment \ 21 | --delimiter comma \ 22 | --has-header $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_mmpdb.csv \ 23 | -o $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.fragments 24 | 25 | mmpdb index $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.fragments \ 26 | -o $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_pairs.csv \ 27 | --out 'csv' 28 | 29 | cat $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_pairs.csv | sort | uniq > $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_pairs_unique.csv 30 | 31 | python ./scripts/parsing_pairs.py \ 32 | --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_pairs_unique.csv \ 33 | --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_parsed.csv 34 | 35 | # make sure it's worked and if so, delete huge intermediate files 36 | if [ -s $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_parsed.csv ]; then 37 | rm $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_pairs.csv 38 | rm $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_pairs_unique.csv 39 | fi 40 | 41 | python ./scripts/new_smirks.py \ 42 | --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_parsed.csv \ 43 | --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_smirks.csv \ 44 | --in2 $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 45 | --get_smirks 46 | 47 | mkdir -p $PNG_DEST 48 | 49 | python ./scripts/new_smirks.py \ 50 | --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_smirks.csv \ 51 | --in2 $DATA_DIR/${DATA_ID}_counted.csv \ 52 | --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}_newsmirks.csv \ 53 | --new_smirks \ 54 | --threshold $THRESHOLD \ 55 | --png_dest $PNG_DEST 56 | -------------------------------------------------------------------------------- /slurm_scripts/predictions.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -o .../predictions.sh.out 3 | #SBATCH -e 
.../predictions.sh.err 4 | # #SBATCH -n 5 | #SBATCH -t 1:00:00 6 | #SBATCH --mem-per-cpu=25G 7 | #SBATCH --gres=gpu:1 8 | 9 | export PYTHONPATH=... 10 | export OE_LICENSE=.../oe_license.txt 11 | 12 | export CUDA_VISIBLE_DEVICES=0 13 | 14 | mkdir -p $OUTPUT_DIR 15 | 16 | echo 'Starting Translate ...' 17 | if [ ! -s $OUTPUT_DIR/pred_selfies_epoch_${EPOCH_NUM}.txt ]; then 18 | onmt_translate \ 19 | --model $MODEL_DIR/${MODEL_ID}_epoch_${EPOCH_NUM}.pt \ 20 | --src $RUN_DIR/val-unique.txt \ 21 | --output $OUTPUT_DIR/pred_selfies_epoch_${EPOCH_NUM}.txt \ 22 | --replace_unk \ 23 | --seed 1 \ 24 | --gpu 0 25 | fi 26 | echo 'Finished Translate' 27 | 28 | 29 | ## Convert selfies to smiles 30 | echo "Starting SELFILES to SMILES ..." 31 | python ./scripts/selfies_to_smiles.py \ 32 | --in1 $OUTPUT_DIR/pred_selfies_epoch_${EPOCH_NUM}.txt \ 33 | --in2 $RUN_DIR/val-unique.txt \ 34 | --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv 35 | echo "Finished SELFILES to SMILES" 36 | 37 | python ./scripts/scaffolding.py \ 38 | --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 39 | --out $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 40 | --column1 'structure' \ 41 | --column2 'id' 42 | 43 | ## Score model predictions 44 | python ./scripts/scoring.py \ 45 | --in $OUTPUT_DIR/pred_smiles_epoch_${EPOCH_NUM}.csv \ 46 | --metrics_table $METRICS_TABLE \ 47 | --training_data $RUN_DIR/train.csv \ 48 | --model ${MODEL_ID}_epoch_${EPOCH_NUM} \ 49 | --change_count \ 50 | --scaffolds 51 | 52 | # ## Generate pngs of predicted smiles 53 | # singularity run --nv $SINGULARITY_CONTAINER mkdir /gpfs/workspace/users/tysine/images/${MODEL_ID}_epoch_$EPOCH_NUM 54 | 55 | # singularity run --nv $SINGULARITY_CONTAINER shuf $DATA_DIR/pred_smiles_epoch_${EPOCH_NUM}_unique.csv | head -100 > $DATA_DIR/pred_smiles_epoch_${EPOCH_NUM}_100.csv 56 | 57 | # singularity run --nv $SINGULARITY_CONTAINER python /gpfs/workspace/users/tysine/Transformer/visualize_smiles.py \ 58 | # --in $DATA_DIR/pred_smiles_epoch_${EPOCH_NUM}_100.csv \ 59 | # --out /gpfs/workspace/users/tysine/images/${MODEL_ID}_epoch_$EPOCH_NUM 60 | -------------------------------------------------------------------------------- /slurm_scripts/preparing_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -o .../preparing_data.sh.out 3 | #SBATCH -e .../preparing_data.sh.err 4 | #SBATCH -n 1 5 | #SBATCH -t 20:00:00 6 | #SBATCH --mem-per-cpu=40G 7 | 8 | export PYTHONPATH=... 9 | export OE_LICENSE=.../oe_license.txt 10 | 11 | mkdir -p $DATA_DIR 12 | 13 | ## Preparing MMPDB input file 14 | python ./scripts/mmpdb_prep.py \ 15 | --in $DATA \ 16 | --out $DATA_DIR/${DATA_ID}_mmpdb_input.csv \ 17 | --smiles $SMI_COL \ 18 | --ids $ID_COL 19 | 20 | ## Clear stereochemistry 21 | python ./scripts/clear_stereochemistry.py \ 22 | --in $DATA_DIR/${DATA_ID}_mmpdb_input.csv \ 23 | --out $DATA_DIR/${DATA_ID}_mmpdb_input_nostereo.csv 24 | 25 | ## MMPDB Pairing 26 | echo "Starting fragmentation ..." 27 | mmpdb fragment \ 28 | --delimiter comma \ 29 | --has-header $DATA_DIR/${DATA_ID}_mmpdb_input_nostereo.csv \ 30 | -o $DATA_DIR/${DATA_ID}.fragments 31 | echo "Finished fragmentation" 32 | 33 | echo "Starting pairing..." 34 | mmpdb index $DATA_DIR/${DATA_ID}.fragments \ 35 | -o $DATA_DIR/${DATA_ID}_pairs.csv \ 36 | --out 'csv' 37 | echo "Finished pairing" 38 | 39 | echo "Starting parsing pairs file ..." 
40 | python ./scripts/parsing_pairs.py \ 41 | --in $DATA_DIR/${DATA_ID}_pairs.csv \ 42 | --out $DATA_DIR/${DATA_ID}_pairs_parsed.csv 43 | echo "Finished parsing" 44 | 45 | rm $DATA_DIR/${DATA_ID}_pairs.csv 46 | 47 | echo "Starting counting smirks ..." 48 | ## Counting smirks 49 | python ./scripts/counting_smirks.py \ 50 | --in $DATA_DIR/${DATA_ID}_pairs_parsed.csv\ 51 | --out $DATA_DIR/${DATA_ID}_counted.csv 52 | echo "Finished counting smirks" 53 | 54 | echo "Starting filtering data ..." 55 | ## Filtering data 56 | python ./scripts/filtering_data.py \ 57 | --in $DATA_DIR/${DATA_ID}_pairs_parsed.csv \ 58 | --all $DATA \ 59 | --smirks $DATA_DIR/${DATA_ID}_counted.csv \ 60 | --out $DATA_DIR/${DATA_ID}_filtered.csv \ 61 | --size $SAMPLE_SIZE \ 62 | --exclude $EXCLUDE 63 | echo "Finished filtering data" 64 | -------------------------------------------------------------------------------- /slurm_scripts/training_curves.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -o .../training_curves.sh.out 3 | #SBATCH -e .../training_curves.sh.err 4 | #SBATCH -n 1 5 | #SBATCH -t 1:00:00 6 | #SBATCH --mem-per-cpu=40G 7 | 8 | export PYTHONPATH=... 9 | 10 | mkdir -p $OUTDEST 11 | 12 | python ./scripts/training_curves.py \ 13 | --in $IN_FILE \ 14 | --out $OUTDEST/training_model_info_${MODEL_NUM}.csv \ 15 | --parse_type info 16 | 17 | grep 'Validation accuracy' $ERR_FILE > $OUTDEST/${NAME}_${MODEL_NUM}.err 18 | 19 | python ./scripts/training_curves.py \ 20 | --in $OUTDEST/${NAME}_${MODEL_NUM}.err \ 21 | --out $OUTDEST/${NAME}_${MODEL_NUM}.csv \ 22 | --name $NAME \ 23 | --parse_type val 24 | 25 | export NAME=validation_perplexity 26 | 27 | grep 'Validation perplexity' $ERR_FILE > $OUTDEST/${NAME}_${MODEL_NUM}.err 28 | 29 | python ./scripts/training_curves.py \ 30 | --in $OUTDEST/${NAME}_${MODEL_NUM}.err \ 31 | --out $OUTDEST/${NAME}_${MODEL_NUM}.csv \ 32 | --name $NAME \ 33 | --parse_type val 34 | 35 | export NAME=training 36 | 37 | grep 'Start training loop and validate' $ERR_FILE > $OUTDEST/${NAME}_${MODEL_NUM}.err 38 | grep 'acc:' $ERR_FILE >> $OUTDEST/${NAME}_${MODEL_NUM}.err 39 | 40 | python ./scripts/training_curves.py \ 41 | --in $OUTDEST/${NAME}_${MODEL_NUM}.err \ 42 | --out $OUTDEST/${NAME}_${MODEL_NUM}.csv \ 43 | --name $NAME \ 44 | --outpng $OUTDEST \ 45 | --parse_type train 46 | 47 | mkdir -p $PLOTDEST 48 | python /gpfs/workspace/users/tysine/Transformer/training_curves.py \ 49 | --val_acc $OUTDEST/validation_accuracy_${MODEL_NUM}.csv \ 50 | --val_ppl $OUTDEST/validation_perplexity_${MODEL_NUM}.csv \ 51 | --train $OUTDEST/training_${MODEL_NUM}.csv \ 52 | --info $OUTDEST/training_model_info_${MODEL_NUM}.csv \ 53 | --metrics $METRICS_TABLE \ 54 | --outpng $PLOTDEST \ 55 | --epoch_cutoff $EPOCH_CUTOFF \ 56 | --parse_type plot 57 | -------------------------------------------------------------------------------- /slurm_scripts/training_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -o .../training.sh.out 3 | #SBATCH -e .../training.sh.err 4 | #SBATCH -n 1 5 | #SBATCH -t 4-00:00:00 6 | #SBATCH --mem-per-cpu=50G 7 | #SBATCH --gres=gpu:1 8 | 9 | #export PYTHONPATH=... 
10 | export CUDA_VISIBLE_DEVICES=0 11 | 12 | mkdir -p $MODEL_DIR 13 | 14 | ## Calculate Train size, train steps and save steps 15 | export TRAIN_SIZE=$(cat $RUN_DIR/src-train.txt | wc -l) 16 | export EPOCH_STEPS=$(($TRAIN_SIZE/$BATCH_SIZE)) 17 | export TRAIN_STEPS=$(($EPOCH_STEPS*$TRAIN_EPOCHS)) 18 | export SAVE_STEPS=$(($EPOCH_STEPS*$SAVE_EPOCHS)) 19 | 20 | # if validation set is not empty, run validation at the end of each epoch; otherwise, don't run during training 21 | export VS=$(head $RUN_DIR/src-val.txt | wc -l) 22 | if [ $VS -eq 0 ]; then 23 | export VALID_STEPS=$(($TRAIN_STEPS+1)) 24 | else 25 | export VALID_STEPS=$EPOCH_STEPS 26 | fi 27 | 28 | ## Build config yaml 29 | cat << EOF > $RUN_DIR/config.yaml 30 | ## Where the vocab(s) will be written 31 | src_vocab: $RUN_DIR/vocab.src 32 | tgt_vocab: $RUN_DIR/vocab.tgt 33 | 34 | # Corpus opts: 35 | data: 36 | corpus_1: 37 | path_src: $RUN_DIR/src-train.txt 38 | path_tgt: $RUN_DIR/tgt-train.txt 39 | valid: 40 | path_src: $RUN_DIR/src-val.txt 41 | path_tgt: $RUN_DIR/tgt-val.txt 42 | EOF 43 | 44 | ## OpenNMT preprocessing data 45 | onmt_build_vocab \ 46 | -config $RUN_DIR/config.yaml \ 47 | -save_data $RUN_DIR/data \ 48 | -n_samples -1 49 | 50 | ## OpenNMT Training Model 51 | onmt_train \ 52 | -config $RUN_DIR/config.yaml \ 53 | -save_model $MODEL_DIR/$MODEL_ID \ 54 | -train_steps $TRAIN_STEPS \ 55 | -valid_steps $VALID_STEPS \ 56 | -save_checkpoint_steps $SAVE_STEPS \ 57 | -batch_size $BATCH_SIZE \ 58 | -world_size 1 \ 59 | -gpu_ranks 0 60 | 61 | ## Rename models from steps to epochs 62 | python ./scripts/renaming_models.py \ 63 | --models $MODEL_DIR \ 64 | --batch_size $BATCH_SIZE \ 65 | --train_size $TRAIN_SIZE \ 66 | --startswith "${MODEL_ID}_" 67 | -------------------------------------------------------------------------------- /target_data_split.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfizer-opensource/transform-molecules/de569e232698f17eca27d4f802792c90300509d5/target_data_split.zip --------------------------------------------------------------------------------