├── LICENSE
├── README.md
├── data
│   ├── gen_smi
│   │   └── README.md
│   ├── prepare_data.py
│   ├── protein.pdbqt
│   └── synthons
│       ├── data.csv
│       └── synthons.csv
├── environment.yaml
├── figure.py
├── images
│   └── figure.png
├── model.py
├── run.py
├── train_combiner.py
├── train_inpainting.py
└── utils.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 mywang1994 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ClickGen: Directed Exploration of Synthesizable Chemical Space Leading to the Rapid Synthesis of Novel and Active Lead Compounds via Modular Reactions and Reinforcement Learning 2 | 3 | ![overview of the architecture of ClickGen](/images/figure.png) 4 | 5 | ## Overview 6 | This repository contains the source code of ClickGen, a deep learning model that uses modular reactions such as click chemistry to assemble molecules and incorporates reinforcement learning to ensure that the proposed molecules display high diversity, novelty, and a strong binding tendency. 7 | 8 | ## Abstract 9 | 10 | Despite the vast potential of generative models, the low synthesizability of many generated molecules has severely restricted their impact in real-world scenarios. In response to this issue, we developed ClickGen, a deep learning model that uses modular reactions such as click chemistry to assemble molecules and incorporates reinforcement learning along with an inpainting technique to ensure that the proposed molecules display high diversity, novelty, and a strong binding tendency. We then conducted wet-lab validation of the molecules ClickGen proposed for PARP1. Thanks to the guaranteed high synthesizability and the model-generated synthetic routes provided for reference, we produced and tested the bioactivity of these novel compounds in just 20 days, much faster than the time frame typically expected when handling sufficiently novel molecules. In bioactivity assays, two lead compounds demonstrated superior anti-proliferative efficacy against cancer cell lines, low toxicity, and nanomolar-level inhibitory activity against PARP1. 
We anticipate that ClickGen and related models could signify a new paradigm in molecular generation, bringing AI- and automated-experimentation-driven closed-loop molecular design closer to reality. 11 | 12 | 13 | ## System Requirements 14 | 15 | ### Hardware requirements 16 | 17 | `ClickGen` is recommended for use on computers with more than 20 GB of VRAM or RAM. 18 | 19 | ### OS Requirements 20 | This package is supported for *Linux* and *Windows*. The package has been tested on the following systems: 21 | + Windows: Windows 11 23H2 22 | + Linux: Ubuntu 22.04 23 | 24 | ### Software requirements 25 | 26 | - Python == 3.7 27 | - pytorch >= 1.1.0 28 | - RDKit == 2020.09.5 29 | - openbabel >= 3.1.1 30 | - autodock vina (for python) [README](https://autodock-vina.readthedocs.io/en/latest/docking_python.html) 31 | 32 | 33 | 34 | If utilizing GPU-accelerated model training: 35 | - CUDA == 10.2 & cuDNN == 7.5 36 | 37 | ## Install from GitHub & create a new environment in conda 38 | ``` 39 | git clone https://github.com/mywang1994/cligen_gen 40 | cd cligen_gen 41 | conda env create -f environment.yaml 42 | conda activate click_gen 43 | ``` 44 | 45 | 46 | 47 | ## Running ClickGen 48 | ### 1. Prepare the synthons dataset 49 | The ClickGen model requires labeled reactants stored in `.csv` data files, as well as a protein structure that has been energy-minimized and repaired, saved in `.pdbqt` format. Finally, the initial synthon fragment must be supplied in a standardized SMILES format; the annotation method is detailed in `./data/prepare_data.py` and in the data preparation section of the article. 
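For intuition, the standardized synthon SMILES use numbered attachment-point tokens such as `[3*]` (the run example below uses the fragment `[3*]NC1CCCC(N[3*])CC1`). A small illustrative sketch, not part of the repository, that extracts these labels from a synthon string:

```python
import re

def attachment_points(synthon_smiles):
    """Return the numbered attachment-point labels (e.g. '3' for [3*])
    found in a synthon SMILES string."""
    return re.findall(r"\[(\d+)\*\]", synthon_smiles)

# The example fragment carries two type-3 attachment points:
print(attachment_points("[3*]NC1CCCC(N[3*])CC1"))  # ['3', '3']
```

The numeric label encodes which modular reaction a handle can participate in, so two synthons are combinable when their attachment-point types match.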
50 | 51 | ``` 52 | python ./data/prepare_data.py --input_path SMILES.csv # the path of the input SMILES .csv file 53 | --output_path output.csv # the path to save the synthons .csv file 54 | ``` 55 | 56 | ### 2. Train the Reaction-based combiner 57 | 58 | Training the Reaction-based combiner requires a SMILES dataset along with the synthon dataset obtained in step 1. You must define the number of positive and negative samples, as well as some fundamental model parameters, such as the learning rate. 59 | 60 | 61 | 62 | ``` 63 | python train_combiner.py --mol_p SMILES.csv # the path of the SMILES .csv file 64 | --syn_path synthons.csv # the path of the labeled synthons .csv file 65 | --num_p 100 # the number of positive samples 66 | --num_n 1000 # the number of negative samples 67 | --lr 1e-4 # the learning rate 68 | --epoch 80 # the number of training epochs 69 | ``` 70 | Ultimately, the model file will be stored in the `./data/model/` directory. 71 | 72 | 73 | 74 | 75 | ### 3. Train the Inpainting-based generator 76 | 77 | Training the Inpainting-based generator requires the dataset created in step 1, along with model parameters such as the embedding and hidden dimensions. Users can also flexibly configure the model's skip connections and attention mechanism via the command line, allowing the model to be adjusted to different needs. Hardware requirements are at least 20 GB of GPU memory, or CPU memory (not recommended due to slower training speed). 78 | 79 | ``` 80 | python train_inpainting.py --mol_p SMILES.csv # the path of the SMILES .csv file 81 | --embed_dim 64 # the embedding size 82 | --hid_dim 100 # the hidden dimension 83 | --skip_connection True # enable skip connections 84 | --attention True # enable the attention mechanism 85 | --lr 1e-4 # the learning rate 86 | --epoch 80 # the number of training epochs 87 | 88 | ``` 89 | Ultimately, the model file will be stored in the `./data/model/` directory. 
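As a rough intuition for the inpainting objective, the generator learns to reconstruct hidden parts of a molecule from the visible context. A simplified character-level sketch (illustrative only — the actual model operates on tokenized SMILES and is configured via the flags above; the function name and masking scheme here are assumptions):

```python
import random

def mask_smiles(smiles, mask_frac=0.15, mask_token="?"):
    # Randomly hide a fraction of characters; an inpainting model is
    # trained to fill the hidden positions back in from the context.
    chars = list(smiles)
    n_mask = max(1, int(len(chars) * mask_frac))
    for i in random.sample(range(len(chars)), n_mask):
        chars[i] = mask_token
    return "".join(chars)

random.seed(0)
print(mask_smiles("NC1CCCC(N)CC1"))  # one character replaced by '?'
```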
90 | 91 | 92 | 93 | 94 | ### 4. Run ClickGen 95 | 96 | 97 | To run the ClickGen model, you need the dataset obtained in step 1, as well as the Inpainting-based generator and Reaction-based combiner trained in steps 2 and 3. You also need the starting synthons (which can be omitted in inpainting mode), the corresponding protein target structure, and the input parameters for the model, such as the number of molecules to be generated and the parameters for the Inpainting-based generator and Reaction-based combiner. 98 | 99 | ``` 100 | python run.py --inpainting True/False # use the inpainting module 101 | --input [3*]NC1CCCC(N[3*])CC1 # the initial synthon fragment (not needed in inpainting mode) 102 | --syn_p synthons.csv # the path of the labeled synthons 103 | --protein ./data/parp1.pdb # the protein target structure 104 | --num_sims 10000 # simulation steps 105 | ``` 106 | 107 | Based on our tests, generating 10,000 molecules with ClickGen takes 0.5 to 1.5 hours, depending on the system and hardware configuration. 108 | 109 | ## License 110 | 111 | This project is covered under the MIT License. 112 | 113 | 114 | ## Need Help? 115 | 116 | If you encounter any issues, feel free to contact us. 117 | -------------------------------------------------------------------------------- /data/gen_smi/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | This directory contains the molecules we generated, which can be used to reproduce the charts in `figure.py`. 4 | 5 | Since the unzipped data is large (over 10 GB), you can download it from the following link: 6 | [Download Data](https://pan.baidu.com/s/1toD7SaxXYNbdDIOjNE14vQ?pwd=sctj) 7 | 8 | 9 | ### Data Structure 10 | 11 | 1. **Target Active Molecules**: The files `aa2ar.csv`, `sars.csv`, and `rock1.csv` contain SMILES representations of active molecules for each target. 12 | 13 | 2. 
**Generated Molecules by Baseline Models**: The following six files contain molecules generated by the two baseline models for the three targets. Each file includes a `['SMILES']` column: 14 | - `syn_sars.csv` 15 | - `bbar_sars.csv` 16 | - `syn_aa2ar.csv` 17 | - `bbar_aa2ar.csv` 18 | - `syn_rock.csv` 19 | - `bbar_rock.csv` 20 | 21 | 3. **Generated Molecule Files (PDB)**: There are 60,000 PDB files of generated molecules in total, produced by either `clickgen` or `clickgen-inpainting`; for each of the three targets, there are 10,000 molecules per model. Each file is annotated with properties such as synthesis steps, scores, synthetic intermediates, ligand efficiency (LE), and similarity to active molecules. 22 | 23 | -------------------------------------------------------------------------------- /data/prepare_data.py: -------------------------------------------------------------------------------- 1 | from utils import split_molecule 2 | 3 | import pandas as pd 4 | from rdkit import Chem 5 | from rdkit.Chem import rdChemReactions 6 | from tqdm import tqdm 7 | import argparse 8 | 9 | 10 | parser = argparse.ArgumentParser(description='Prepare synthons') 11 | 12 | 13 | parser.add_argument('--input_path', type=str, required=True, help='path to the input SMILES .csv file') 14 | parser.add_argument('--output_path', type=str, default='./data/synthons/synthons.csv', help='path to write the synthons library') 15 | 16 | 17 | 18 | 19 | args = parser.parse_args() 20 | 21 | all_synthons = set() 22 | 23 | data = pd.read_csv(args.input_path) 24 | 25 | for smiles in tqdm(data['SMILES']): 26 | synthons = split_molecule(smiles) 27 | all_synthons.update(synthons) 28 | 29 | 30 | 31 | 32 | pd.DataFrame({'Synthons': list(all_synthons)}).to_csv(args.output_path, index=False) -------------------------------------------------------------------------------- /data/synthons/data.csv: -------------------------------------------------------------------------------- 1 | 2 | Due to copyright issues with Enamine, you need to register and 
declare that the data use is for non-commercial purposes 3 | 4 | https://enamine.net/compound-collections/real-compounds/real-database-subsets 5 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: clickgen 2 | channels: 3 | - pyg 4 | - pytorch 5 | - nvidia 6 | - https://repo.anaconda.com/pkgs/main 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_kmp_llvm 12 | - alabaster=0.7.13=pyhd8ed1ab_0 13 | - babel=2.12.1=pyhd8ed1ab_1 14 | - blas=1.0=mkl 15 | - boost=1.74.0=py37h796e4cb_5 16 | - boost-cpp=1.74.0=h75c5d50_8 17 | - bottleneck=1.3.5=py37hda87dfa_0 18 | - brotlipy=0.7.0=py37h27cfd23_1003 19 | - bzip2=1.0.8=h7b6447c_0 20 | - ca-certificates=2023.05.30=h06a4308_0 21 | - cairo=1.16.0=h19f5f5c_2 22 | - certifi=2023.5.7=pyhd8ed1ab_0 23 | - cffi=1.15.1=py37h74dc2b5_0 24 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 25 | - colorama=0.4.6=pyhd8ed1ab_0 26 | - cryptography=38.0.1=py37h9ce1e76_0 27 | - cuda=11.6.2=0 28 | - cuda-cccl=11.6.55=hf6102b2_0 29 | - cuda-command-line-tools=11.6.2=0 30 | - cuda-compiler=11.6.2=0 31 | - cuda-cudart=11.6.55=he381448_0 32 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 33 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 34 | - cuda-cupti=11.6.124=h86345e5_0 35 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 36 | - cuda-driver-dev=11.6.55=0 37 | - cuda-gdb=11.8.86=0 38 | - cuda-libraries=11.6.2=0 39 | - cuda-libraries-dev=11.6.2=0 40 | - cuda-memcheck=11.8.86=0 41 | - cuda-nsight=11.8.86=0 42 | - cuda-nsight-compute=11.8.0=0 43 | - cuda-nvcc=11.6.124=hbba6d2d_0 44 | - cuda-nvdisasm=11.8.86=0 45 | - cuda-nvml-dev=11.6.55=haa9ef22_0 46 | - cuda-nvprof=11.8.87=0 47 | - cuda-nvprune=11.6.124=he22ec0a_0 48 | - cuda-nvrtc=11.6.124=h020bade_0 49 | - cuda-nvrtc-dev=11.6.124=h249d397_0 50 | - cuda-nvtx=11.6.124=h0630a44_0 51 | - cuda-nvvp=11.8.87=0 52 | - 
cuda-runtime=11.6.2=0 53 | - cuda-samples=11.6.101=h8efea70_0 54 | - cuda-sanitizer-api=11.8.86=0 55 | - cuda-toolkit=11.6.2=0 56 | - cuda-tools=11.6.2=0 57 | - cuda-visual-tools=11.6.2=0 58 | - cycler=0.11.0=pyhd8ed1ab_0 59 | - docutils=0.16=py37h89c1867_3 60 | - expat=2.4.9=h6a678d5_0 61 | - ffmpeg=4.3=hf484d3e_0 62 | - fftw=3.3.9=h27cfd23_1 63 | - fontconfig=2.14.1=hc2a2eb6_0 64 | - freetype=2.12.1=h4a9f257_0 65 | - gds-tools=1.4.0.31=0 66 | - giflib=5.2.1=h7b6447c_0 67 | - glib=2.69.1=h4ff587b_1 68 | - gmp=6.2.1=h295c915_3 69 | - gnutls=3.6.15=he1e5248_0 70 | - icu=70.1=h27087fc_0 71 | - idna=3.4=py37h06a4308_0 72 | - imagesize=1.4.1=pyhd8ed1ab_0 73 | - intel-openmp=2021.4.0=h06a4308_3561 74 | - jinja2=3.1.2=pyhd8ed1ab_1 75 | - jpeg=9e=h7f8727e_0 76 | - kiwisolver=1.4.2=py37h295c915_0 77 | - lame=3.100=h7b6447c_0 78 | - lcms2=2.12=h3be6417_0 79 | - ld_impl_linux-64=2.38=h1181459_1 80 | - lerc=3.0=h295c915_0 81 | - libcublas=11.11.3.6=0 82 | - libcublas-dev=11.11.3.6=0 83 | - libcufft=10.9.0.58=0 84 | - libcufft-dev=10.9.0.58=0 85 | - libcufile=1.4.0.31=0 86 | - libcufile-dev=1.4.0.31=0 87 | - libcurand=10.3.0.86=0 88 | - libcurand-dev=10.3.0.86=0 89 | - libcusolver=11.4.1.48=0 90 | - libcusolver-dev=11.4.1.48=0 91 | - libcusparse=11.7.5.86=0 92 | - libcusparse-dev=11.7.5.86=0 93 | - libdeflate=1.8=h7f8727e_5 94 | - libffi=3.3=he6710b0_2 95 | - libgcc-ng=12.2.0=h65d4601_19 96 | - libgfortran-ng=11.2.0=h00389a5_1 97 | - libgfortran5=11.2.0=h1234567_1 98 | - libiconv=1.16=h7f8727e_2 99 | - libidn2=2.3.2=h7f8727e_0 100 | - libnpp=11.8.0.86=0 101 | - libnpp-dev=11.8.0.86=0 102 | - libnvjpeg=11.9.0.86=0 103 | - libnvjpeg-dev=11.9.0.86=0 104 | - libpng=1.6.37=hbc83047_0 105 | - libstdcxx-ng=11.2.0=h1234567_1 106 | - libtasn1=4.16.0=h27cfd23_0 107 | - libtiff=4.4.0=hecacb30_2 108 | - libunistring=0.9.10=h27cfd23_0 109 | - libuuid=2.32.1=h7f98852_1000 110 | - libwebp=1.2.4=h11a3e52_0 111 | - libwebp-base=1.2.4=h5eee18b_0 112 | - libxcb=1.15=h7f8727e_0 113 | - 
libxml2=2.9.14=h22db469_4 114 | - libzlib=1.2.13=h166bdaf_4 115 | - llvm-openmp=15.0.5=he0ac6c6_0 116 | - lz4-c=1.9.3=h295c915_1 117 | - markupsafe=2.1.1=py37h7f8727e_0 118 | - matplotlib-base=3.4.3=py37h1058ff1_2 119 | - mkl=2021.4.0=h06a4308_640 120 | - mkl-service=2.4.0=py37h7f8727e_0 121 | - mkl_fft=1.3.1=py37hd3c417c_0 122 | - mkl_random=1.2.2=py37h51133e4_0 123 | - ncurses=6.3=h5eee18b_3 124 | - nettle=3.7.3=hbbd107a_1 125 | - nsight-compute=2022.3.0.22=0 126 | - numexpr=2.8.4=py37he184ba9_0 127 | - numpy=1.21.5=py37h6c91a56_3 128 | - numpy-base=1.21.5=py37ha15fc14_3 129 | - openbabel=3.1.1=py37h6aa62a1_3 130 | - openh264=2.1.1=h4ff587b_0 131 | - openssl=1.1.1u=h7f8727e_0 132 | - packaging=21.3=pyhd8ed1ab_0 133 | - pandas=1.3.5=py37h8c16a72_0 134 | - pcre=8.45=h9c3ff4c_0 135 | - pillow=9.2.0=py37hace64e9_1 136 | - pixman=0.40.0=h36c2ea0_0 137 | - pycairo=1.21.0=py37h0afab05_1 138 | - pycparser=2.21=pyhd3eb1b0_0 139 | - pyg=2.3.0=py37_torch_1.13.0_cu116 140 | - pygments=2.15.1=pyhd8ed1ab_0 141 | - pyopenssl=22.0.0=pyhd3eb1b0_0 142 | - pyparsing=3.0.9=pyhd8ed1ab_0 143 | - pysocks=1.7.1=py37_1 144 | - python=3.7.15=haa1d7c7_0 145 | - python-dateutil=2.8.2=pyhd8ed1ab_0 146 | - python_abi=3.7=2_cp37m 147 | - pytorch=1.13.0=py3.7_cuda11.6_cudnn8.3.2_0 148 | - pytorch-cuda=11.6=h867d48c_0 149 | - pytorch-mutex=1.0=cuda 150 | - pytz=2022.6=pyhd8ed1ab_0 151 | - rdkit=2022.03.2=py37hc52db9c_0 152 | - readline=8.2=h5eee18b_0 153 | - reportlab=3.5.68=py37h69800bb_1 154 | - requests=2.28.1=py37h06a4308_0 155 | - scikit-learn=1.0.2=py37h51133e4_1 156 | - scipy=1.7.3=py37h6c91a56_2 157 | - setuptools=65.5.0=py37h06a4308_0 158 | - six=1.16.0=pyhd3eb1b0_1 159 | - snowballstemmer=2.2.0=pyhd8ed1ab_0 160 | - sphinx=5.3.0=pyhd8ed1ab_0 161 | - sphinx_rtd_theme=1.2.2=pyha770c72_0 162 | - sphinxcontrib-applehelp=1.0.4=pyhd8ed1ab_0 163 | - sphinxcontrib-devhelp=1.0.2=py_0 164 | - sphinxcontrib-htmlhelp=2.0.1=pyhd8ed1ab_0 165 | - sphinxcontrib-jquery=4.1=pyhd8ed1ab_0 166 | - 
sphinxcontrib-jsmath=1.0.1=py_0 167 | - sphinxcontrib-qthelp=1.0.3=py_0 168 | - sphinxcontrib-serializinghtml=1.1.5=pyhd8ed1ab_2 169 | - sqlalchemy=1.3.24=py37h540881e_1 170 | - sqlite=3.39.3=h5082296_0 171 | - swig=4.0.2=hd3c618e_2 172 | - threadpoolctl=2.2.0=pyh0d69192_0 173 | - tk=8.6.12=h1ccaba5_0 174 | - torchaudio=0.13.0=py37_cu116 175 | - torchvision=0.14.0=py37_cu116 176 | - tornado=6.2=py37h540881e_0 177 | - tqdm=4.64.1=py37h06a4308_0 178 | - typing_extensions=4.3.0=py37h06a4308_0 179 | - urllib3=1.26.12=py37h06a4308_0 180 | - wheel=0.37.1=pyhd3eb1b0_0 181 | - xz=5.2.6=h5eee18b_0 182 | - zlib=1.2.13=h166bdaf_4 183 | - zstd=1.5.2=ha4553b6_0 184 | - pip: 185 | - absl-py==1.3.0 186 | - anyio==3.7.1 187 | - argon2-cffi==23.1.0 188 | - argon2-cffi-bindings==21.2.0 189 | - astunparse==1.6.3 190 | - attrs==23.1.0 191 | - backcall==0.2.0 192 | - beautifulsoup4==4.12.2 193 | - bleach==6.0.0 194 | - blessed==1.19.1 195 | - cachetools==5.2.0 196 | - comm==0.1.4 197 | - debugpy==1.7.0 198 | - decorator==5.1.1 199 | - defusedxml==0.7.1 200 | - entrypoints==0.4 201 | - exceptiongroup==1.1.3 202 | - fastjsonschema==2.18.0 203 | - fcd==1.1 204 | - flatbuffers==22.12.6 205 | - gast==0.4.0 206 | - google-auth==2.15.0 207 | - google-auth-oauthlib==0.4.6 208 | - google-pasta==0.2.0 209 | - gpustat==1.0.0 210 | - grpcio==1.51.1 211 | - guacamol==0.5.4 212 | - h5py==3.7.0 213 | - importlib-metadata==5.2.0 214 | - importlib-resources==5.12.0 215 | - ipykernel==6.16.2 216 | - ipython==7.34.0 217 | - ipython-genutils==0.2.0 218 | - ipywidgets==8.1.0 219 | - jedi==0.19.0 220 | - joblib==1.2.0 221 | - jsonschema==4.17.3 222 | - jupyter==1.0.0 223 | - jupyter-client==7.4.9 224 | - jupyter-console==6.6.3 225 | - jupyter-core==4.12.0 226 | - jupyter-server==1.24.0 227 | - jupyterlab-pygments==0.2.2 228 | - jupyterlab-widgets==3.0.8 229 | - keras==2.11.0 230 | - libclang==14.0.6 231 | - markdown==3.4.1 232 | - matplotlib-inline==0.1.6 233 | - mistune==3.0.1 234 | - nbclassic==1.0.0 235 
| - nbclient==0.7.4 236 | - nbconvert==7.6.0 237 | - nbformat==5.8.0 238 | - nest-asyncio==1.5.7 239 | - networkx==2.6.3 240 | - notebook==6.5.5 241 | - notebook-shim==0.2.3 242 | - nvidia-ml-py==11.495.46 243 | - oauthlib==3.2.2 244 | - opt-einsum==3.3.0 245 | - pandocfilters==1.5.0 246 | - parso==0.8.3 247 | - pexpect==4.8.0 248 | - pickleshare==0.7.5 249 | - pip==23.3.1 250 | - pkgutil-resolve-name==1.3.10 251 | - prometheus-client==0.17.1 252 | - prompt-toolkit==3.0.39 253 | - protobuf==3.19.6 254 | - psutil==5.9.4 255 | - ptyprocess==0.7.0 256 | - pyasn1==0.4.8 257 | - pyasn1-modules==0.2.8 258 | - pyrsistent==0.19.3 259 | - python-ternary==1.0.8 260 | - pyzmq==24.0.1 261 | - qtconsole==5.4.4 262 | - qtpy==2.4.0 263 | - rdkit-pypi==2022.9.3 264 | - requests-oauthlib==1.3.1 265 | - rsa==4.9 266 | - seaborn==0.12.2 267 | - send2trash==1.8.2 268 | - sniffio==1.3.0 269 | - soupsieve==2.4.1 270 | - tensorboard==2.11.0 271 | - tensorboard-data-server==0.6.1 272 | - tensorboard-plugin-wit==1.8.1 273 | - tensorflow==2.11.0 274 | - tensorflow-estimator==2.11.0 275 | - tensorflow-io-gcs-filesystem==0.29.0 276 | - termcolor==2.1.1 277 | - terminado==0.17.1 278 | - tinycss2==1.2.1 279 | - traitlets==5.9.0 280 | - vina==1.2.5 281 | - wcwidth==0.2.5 282 | - webencodings==0.5.1 283 | - websocket-client==1.6.1 284 | - werkzeug==2.2.2 285 | - widgetsnbextension==4.0.8 286 | - wrapt==1.14.1 287 | - zipp==3.11.0 288 | prefix: /root/anaconda3/envs/py37 289 | -------------------------------------------------------------------------------- /figure.py: -------------------------------------------------------------------------------- 1 | ####Here is a script that can reproduce the statistical charts in the article, including those appearing in the main text. Please note the following: 2 | ##1.The script requires gen_smi data to run. 3 | ##It is recommended to run and save each chart individually, as running multiple charts at once may cause bugs. 
4 | ##The tab20 color palette in sns.set() may render differently across computers. If there are difficulties with chart colors, please try a different color palette. 5 | 6 | 7 | import os 8 | import re 9 | import random 10 | import joblib 11 | import numpy as np 12 | import pandas as pd 13 | import sklearn 14 | from sklearn.datasets import make_blobs 15 | from sklearn.manifold import TSNE 16 | import matplotlib as mpl 17 | import matplotlib.pyplot as plt 18 | import seaborn as sns 19 | from rdkit import Chem 20 | from rdkit.Chem import AllChem 21 | from scscore.standalone_model_numpy import SCScorer 22 | from syba.syba import Syba 23 | from gasa import GASA 24 | import RAscore as ra 25 | from vina import Vina 26 | 27 | 28 | 29 | sns.set(palette='tab20_r') 30 | pdb_folder = "./data/gen_smi/" 31 | 32 | 33 | data = [] 34 | 35 | 36 | def pdb_to_smiles(pdb_path): 37 | mol = Chem.MolFromPDBFile(pdb_path, removeHs=True) 38 | return Chem.MolToSmiles(mol) if mol else None 39 | 40 | patterns = { 41 | 'STEP': r'STEP:\s*([\d.]+)', 42 | 'SCORE': r'SCORE:\s*([\d.]+)', 43 | 'LE': r'LE:\s*([\d.]+)', 44 | 'SIM': r'SIM:\s*([\d.]+)', 45 | 'syn1': r'syn1:\s*([\d.]+)', 46 | 'syn2': r'syn2:\s*([\d.]+)', 47 | 'syn3': r'syn3:\s*([\d.]+)', 48 | 'syn4': r'syn4:\s*([\d.]+)', 49 | 'syn5': r'syn5:\s*([\d.]+)', 50 | 'residue1': r'residue1:\s*([\d.]+)', 51 | 'residue2': r'residue2:\s*([\d.]+)', 52 | } 53 | 54 | 55 | for filename in os.listdir(pdb_folder): 56 | if filename.endswith('.pdb'): 57 | pdb_path = os.path.join(pdb_folder, filename) 58 | 59 | 60 | with open(pdb_path, 'r') as file: 61 | file_content = file.read() 62 | 63 | tag_values = {} 64 | for tag, pattern in patterns.items(): 65 | match = re.search(pattern, file_content) 66 | tag_values[tag] = float(match.group(1)) if match else None 67 | 68 | 69 | smiles = pdb_to_smiles(pdb_path) 70 | if smiles: 71 | tag_values['SMILES'] = smiles 72 | data.append(tag_values) 73 | 74 | # Convert the data to a DataFrame 75 | df = pd.DataFrame(data) 
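A self-contained illustration of the tag extraction loop above, assuming the generated PDB files carry annotation lines such as `STEP: 3` and `SCORE: 7.5` (only two tags shown; the full script handles the whole `patterns` dict):

```python
import re

tag_patterns = {
    "STEP": r"STEP:\s*([\d.]+)",
    "SCORE": r"SCORE:\s*([\d.]+)",
}

def parse_tags(text):
    # First match per tag, converted to float; None when the tag is absent.
    out = {}
    for tag, pattern in tag_patterns.items():
        m = re.search(pattern, text)
        out[tag] = float(m.group(1)) if m else None
    return out

print(parse_tags("REMARK STEP: 3\nREMARK SCORE: 7.5"))
# {'STEP': 3.0, 'SCORE': 7.5}
```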
76 | 77 | 78 | #################################################fig03################### 79 | 80 | df_0 = pd.read_csv('./data/gen_smi/syn_rock.csv') 81 | df_1 = pd.read_csv('./data/gen_smi/bbar_rock.csv') 82 | smi_ckg = df['SMILES'][0:10000] 83 | smi_ckgi = df['SMILES'][10000:20000] 84 | rock1 = pd.read_csv('./data/gen_smi/rock1.csv') 85 | 86 | 87 | 88 | 89 | df1_0 = pd.read_csv('./data/gen_smi/syn_sars.csv') 90 | df1_1 = pd.read_csv('./data/gen_smi/bbar_sars.csv') 91 | smi_ckg_1 = df['SMILES'][20000:30000] 92 | smi_ckgi_1 = df['SMILES'][30000:40000] 93 | sars = pd.read_csv('./data/gen_smi/sars.csv') 94 | 95 | 96 | 97 | df2_0 = pd.read_csv('./data/gen_smi/syn_aa2ar.csv') 98 | df2_1 = pd.read_csv('./data/gen_smi/bbar_aa2ar.csv') 99 | smi_ckg_2 = df['SMILES'][40000:50000] 100 | smi_ckgi_2 = df['SMILES'][50000:60000] 101 | aa2ar = pd.read_csv('./data/gen_smi/aa2ar.csv') 102 | 103 | 104 | 105 | sc_scorer = SCScorer() 106 | sc_scorer.restore() 107 | 108 | 109 | syba_model = Syba() 110 | syba_model.fit_default() 111 | 112 | 113 | ra_model = joblib.load("RAscore_model.pkl") 114 | 115 | 116 | smiles_lists = { 117 | "synnet": df1_0['SMILES'], 118 | "bbar": df1_1['SMILES'], 119 | "click_gen": smi_ckg_1, 120 | "click_gen_inpainting": smi_ckgi_1 121 | } 122 | 123 | 124 | results = [] 125 | 126 | 127 | for list_name, smiles_list in smiles_lists.items(): 128 | for smiles in smiles_list: 129 | mol = Chem.MolFromSmiles(smiles) 130 | 131 | 132 | sc_score = sc_scorer.get_score_from_smi(smiles)[1] if mol else np.nan # get_score_from_smi returns (smiles, score) 133 | 134 | 135 | ra_score = ra_model.predict([mol]) if mol else np.nan 136 | 137 | 138 | syba_score = syba_model.predict(mol) if mol else np.nan 139 | 140 | 141 | try: 142 | gasa_score = GASA(smiles)[0] 143 | except Exception as e: 144 | print(f"GASA failed for {smiles}, error: {e}") 145 | gasa_score = np.nan 146 | 147 | 148 | results.append({ 149 | "List": list_name, 150 | "SMILES": smiles, 151 | "SC-SCORE": sc_score, 152 | "RA-SCORE": ra_score, 153 | "SYBA": syba_score, 154 | "GASA": gasa_score 155 | }) 156 | 157
| df_scores = pd.DataFrame(results) # use a separate name so the original df of generated molecules is not shadowed 158 | 159 | 160 | 161 | plt.subplot(2, 2, 1) 162 | sns.kdeplot(data=df_scores, x="SC-SCORE", hue="List") 163 | plt.title("SC-SCORE Distribution (KDE)") 164 | plt.xlabel("SC-SCORE") 165 | plt.ylabel("Density") 166 | 167 | # RA-SCORE - Percentage Distribution 168 | plt.subplot(2, 2, 2) 169 | ra_data = df_scores.groupby("List")["RA-SCORE"].value_counts(normalize=True).unstack().fillna(0) * 100 170 | ra_data.plot(kind="bar", stacked=True, ax=plt.gca()) 171 | plt.title("RA-SCORE Percentage Distribution") 172 | plt.xlabel("List") 173 | plt.ylabel("Percentage (%)") 174 | 175 | # GASA - Percentage Distribution 176 | plt.subplot(2, 2, 3) 177 | gasa_data = df_scores.groupby("List")["GASA"].value_counts(normalize=True).unstack().fillna(0) * 100 178 | gasa_data.plot(kind="bar", stacked=True, ax=plt.gca()) 179 | plt.title("GASA Percentage Distribution") 180 | plt.xlabel("List") 181 | plt.ylabel("Percentage (%)") 182 | 183 | # SYBA - KDE Plot 184 | plt.subplot(2, 2, 4) 185 | sns.kdeplot(data=df_scores, x="SYBA", hue="List") 186 | plt.title("SYBA Distribution (KDE)") 187 | plt.xlabel("SYBA") 188 | plt.ylabel("Density") 189 | 190 | 191 | plt.tight_layout() 192 | plt.savefig('figure3.pdf') 193 | plt.show() 194 | 195 | 196 | 197 | 198 | #################################################fig04################### 199 | 200 | 201 | def smiles_to_ecfp6(smiles_list): 202 | ecfp6_fingerprints = [] 203 | for smiles in smiles_list: 204 | mol = Chem.MolFromSmiles(smiles) 205 | if mol: 206 | fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=3, nBits=1024) 207 | ecfp6_fingerprints.append(fp) 208 | return ecfp6_fingerprints 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | df_0 = pd.read_csv('./data/gen_smi/syn_rock.csv') 221 | df_1 = pd.read_csv('./data/gen_smi/bbar_rock.csv') 222 | smi_ckg = df['SMILES'][0:10000] 223 | smi_ckgi = df['SMILES'][10000:20000] 224 | 225 | rock1 = pd.read_csv('./data/gen_smi/rock1.csv') 226 | 227 | df1_0 = 
pd.read_csv('./data/gen_smi/syn_sars.csv') 228 | df1_1 = pd.read_csv('./data/gen_smi/bbar_sars.csv') 229 | smi_ckg_1 = df['SMILES'][20000:30000] 230 | smi_ckgi_1 = df['SMILES'][30000:40000] 231 | sars = pd.read_csv('./data/gen_smi/sars.csv') 232 | 233 | df2_0 = pd.read_csv('./data/gen_smi/syn_aa2ar.csv') 234 | df2_1 = pd.read_csv('./data/gen_smi/bbar_aa2ar.csv') 235 | smi_ckg_2 = df['SMILES'][40000:50000] 236 | smi_ckgi_2 = df['SMILES'][50000:60000] 237 | aa2ar = pd.read_csv('./data/gen_smi/aa2ar.csv') 238 | 239 | 240 | 241 | rock1_ecfp6 = smiles_to_ecfp6(rock1['smiles']) 242 | sars_ecfp6 = smiles_to_ecfp6(sars['smiles']) 243 | aa2ar_ecfp6 = smiles_to_ecfp6(aa2ar['smiles']) 244 | 245 | df_0_ecfp6 = smiles_to_ecfp6(df_0['SMILES']) 246 | df_1_ecfp6 = smiles_to_ecfp6(df_1['SMILES']) 247 | smi_ckg_ecfp6 = smiles_to_ecfp6(smi_ckg) 248 | smi_ckgi_ecfp6 = smiles_to_ecfp6(smi_ckgi) 249 | 250 | 251 | 252 | def tsne_transform(fingerprints): 253 | tsne = TSNE(n_components=2, random_state=42) 254 | return tsne.fit_transform(np.asarray(fingerprints)) # bit vectors must be a numeric array for t-SNE 255 | 256 | 257 | 258 | datasets = [ 259 | ("rock1", rock1_ecfp6, df_0_ecfp6, df_1_ecfp6, smi_ckg_ecfp6, smi_ckgi_ecfp6), 260 | ("sars", sars_ecfp6, df_0_ecfp6, df_1_ecfp6, smi_ckg_ecfp6, smi_ckgi_ecfp6), 261 | ("aa2ar", aa2ar_ecfp6, df_0_ecfp6, df_1_ecfp6, smi_ckg_ecfp6, smi_ckgi_ecfp6) 262 | ] 263 | 264 | # Set up subplots 265 | fig, axes = plt.subplots(3, 4, figsize=(20, 15)) 266 | fig.suptitle("t-SNE Distribution of Different Datasets", fontsize=16) 267 | 268 | 269 | 270 | for i, (target_name, target_data, df_0_data, df_1_data, smi_ckg_data, smi_ckgi_data) in enumerate(datasets): 271 | 272 | 273 | combined_data = target_data + df_0_data + df_1_data + smi_ckg_data + smi_ckgi_data 274 | tsne_result = tsne_transform(combined_data) 275 | 276 | 277 | 278 | num_target = len(target_data) 279 | num_df_0 = len(df_0_data) 280 | num_df_1 = len(df_1_data) 281 | num_smi_ckg = len(smi_ckg_data) 282 | num_smi_ckgi = len(smi_ckgi_data) 283 | 284 | 285 | 
286 | axes[i, 0].scatter(tsne_result[:num_target, 0], tsne_result[:num_target, 1], label=target_name, alpha=0.6) 287 | axes[i, 0].scatter(tsne_result[num_target:num_target + num_df_0, 0], tsne_result[num_target:num_target + num_df_0, 1], label='synnet', alpha=0.6) 288 | axes[i, 1].scatter(tsne_result[num_target + num_df_0:num_target + num_df_0 + num_df_1, 0], tsne_result[num_target + num_df_0:num_target + num_df_0 + num_df_1, 1], label='bbar', alpha=0.6) 289 | axes[i, 2].scatter(tsne_result[num_target + num_df_0 + num_df_1:num_target + num_df_0 + num_df_1 + num_smi_ckg, 0], tsne_result[num_target + num_df_0 + num_df_1:num_target + num_df_0 + num_df_1 + num_smi_ckg, 1], label='click_gen', alpha=0.6) 290 | axes[i, 3].scatter(tsne_result[num_target + num_df_0 + num_df_1 + num_smi_ckg:, 0], tsne_result[num_target + num_df_0 + num_df_1 + num_smi_ckg:, 1], label='click_gen_inpainting', alpha=0.6) 291 | 292 | for j in range(4): 293 | axes[i, j].set_title(f"{target_name} with Collection {j+1}") 294 | axes[i, j].legend() 295 | 296 | plt.tight_layout() 297 | plt.savefig('figure04.pdf') 298 | plt.show() 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | #################################################fig05################### 307 | 308 | 309 | 310 | 311 | 312 | # Function to calculate molecular weight from syn1-syn5 313 | def calculate_molecular_weight(row): 314 | return row['syn1'] + row['syn2'] + row['syn3'] + row['syn4'] + row['syn5'] 315 | 316 | # Calculate molecular weight for each entry 317 | df['MolecularWeight'] = df.apply(calculate_molecular_weight, axis=1) 318 | 319 | # Split the data for each target 320 | targets = { 321 | "ROCK": df.iloc[:20000], 322 | "SARS": df.iloc[20000:40000], 323 | "AA2AR": df.iloc[40000:] 324 | } 325 | 326 | # Plotting 327 | for target_name, target_data in targets.items(): 328 | # Split target data into ckg and ckg_i 329 | ckg = target_data.iloc[:10000] 330 | ckg_i = target_data.iloc[10000:] 331 | 332 | # First plot: STEP distribution for ckg and ckg_i 333 | plt.figure(figsize=(12, 6)) 334 | sns.kdeplot(ckg['STEP'], label=f'{target_name} - ckg STEP', fill=True) 335 | sns.kdeplot(ckg_i['STEP'], label=f'{target_name} - ckg_i STEP', fill=True) 336 | plt.title(f'STEP 
Distribution for {target_name}') 337 | plt.xlabel('STEP') 338 | plt.ylabel('Density') 339 | plt.legend() 340 | plt.show() 341 | 342 | # Second plot: Molecular weight distribution for syn1 to syn5 343 | plt.figure(figsize=(12, 6)) 344 | sns.kdeplot(ckg['MolecularWeight'], label=f'{target_name} - ckg Molecular Weight', fill=True) 345 | sns.kdeplot(ckg_i['MolecularWeight'], label=f'{target_name} - ckg_i Molecular Weight', fill=True) 346 | plt.title(f'Molecular Weight Distribution (syn1-syn5) for {target_name}') 347 | plt.xlabel('Molecular Weight') 348 | plt.ylabel('Density') 349 | plt.legend() 350 | plt.savefig('figure05.pdf') 351 | plt.show() 352 | 353 | #################################################fig06################### 354 | 355 | def smiles_to_ecfp6(smiles): 356 | mol = Chem.MolFromSmiles(smiles) 357 | return AllChem.GetMorganFingerprintAsBitVect(mol, radius=3, nBits=1024) if mol else None 358 | 359 | # Generate ECFP6 fingerprints for all SMILES entries 360 | df['Fingerprint'] = df['SMILES'].apply(smiles_to_ecfp6) 361 | 362 | # Split data into three targets: ROCK, SARS, and AA2AR 363 | targets = { 364 | "ROCK": df.iloc[:20000], 365 | "SARS": df.iloc[20000:40000], 366 | "AA2AR": df.iloc[40000:] 367 | } 368 | 369 | # Function to apply t-SNE on fingerprints 370 | def apply_tsne(fingerprints): 371 | tsne = TSNE(n_components=2, random_state=42) 372 | return tsne.fit_transform(fingerprints) 373 | 374 | # Plot for each target 375 | for target_name, target_data in targets.items(): 376 | # Split target data into ckg and ckg_i 377 | ckg = target_data.iloc[:10000] 378 | ckg_i = target_data.iloc[10000:] 379 | 380 | # First Plot: t-SNE Distribution colored by SCORE 381 | fingerprints_combined = ckg['Fingerprint'].tolist() + ckg_i['Fingerprint'].tolist() 382 | tsne_result = apply_tsne([list(fp) for fp in fingerprints_combined]) 383 | 384 | plt.figure(figsize=(12, 6)) 385 | plt.scatter(tsne_result[:10000, 0], tsne_result[:10000, 1], c=ckg['SCORE'], cmap='viridis', 
label='ckg', alpha=0.7) 386 | plt.scatter(tsne_result[10000:, 0], tsne_result[10000:, 1], c=ckg_i['SCORE'], cmap='plasma', label='ckg_i', alpha=0.7) 387 | plt.colorbar(label='SCORE') 388 | plt.title(f't-SNE ECFP6 Distribution of {target_name} with Score Gradient') 389 | plt.xlabel("t-SNE Dimension 1") 390 | plt.ylabel("t-SNE Dimension 2") 391 | plt.legend() 392 | plt.show() 393 | 394 | # Second Plot: LE distribution for ckg and ckg_i 395 | plt.figure(figsize=(12, 6)) 396 | sns.kdeplot(ckg['LE'], label='ckg LE', fill=True) 397 | sns.kdeplot(ckg_i['LE'], label='ckg_i LE', fill=True) 398 | plt.title(f'LE Distribution for {target_name}') 399 | plt.xlabel('LE') 400 | plt.ylabel('Density') 401 | plt.legend() 402 | plt.show() 403 | 404 | # Third Plot: Jointplot of SCORE vs SIM 405 | plt.figure(figsize=(8, 8)) 406 | sns.jointplot(data=target_data, x='SCORE', y='SIM', kind='scatter', palette='viridis', alpha=0.7) 407 | plt.suptitle(f'SCORE vs SIM Jointplot for {target_name}', y=1.02) 408 | plt.xlabel('SCORE') 409 | plt.ylabel('SIM') 410 | plt.show() 411 | 412 | #################################################fig07################### 413 | 414 | sars_df = df.iloc[20000:40000] 415 | smiles_list = sars_df['SMILES'].tolist() 416 | 417 | 418 | v = Vina(sf_name='vina') 419 | 420 | 421 | receptor_path = './data/protein.pdbqt' 422 | v.set_receptor(receptor_path) 423 | 424 | center = [, , ] 425 | box_size = [80, 80, 80] 426 | 427 | # Prepare output directories 428 | os.makedirs('docked_conformations', exist_ok=True) 429 | os.makedirs('original_conformations', exist_ok=True) 430 | 431 | # Function to convert SMILES to 3D conformations and prepare for docking 432 | def smiles_to_3d_pdbqt(smiles, name): 433 | mol = Chem.MolFromSmiles(smiles) 434 | if mol: 435 | mol = Chem.AddHs(mol) 436 | AllChem.EmbedMolecule(mol, randomSeed=42) 437 | AllChem.UFFOptimizeMolecule(mol) 438 | pdb_path = f'original_conformations/{name}.pdb' 439 | pdbqt_path = f'docked_conformations/{name}.pdbqt' 440 | 
Chem.MolToPDBFile(mol, pdb_path) 441 | next(readfile('pdb', pdb_path)).write('pdbqt', pdbqt_path, overwrite=True)  # RDKit cannot write PDBQT; convert with Open Babel 442 | return pdb_path, pdbqt_path 443 | return None, None 444 | from openbabel.pybel import readfile  # Open Babel handles the PDB <-> PDBQT conversions 445 | 446 | rmsd_values_model1 = [] 447 | rmsd_values_model2 = [] 448 | 449 | for i, smiles in enumerate(smiles_list): 450 | pdb_path, pdbqt_path = smiles_to_3d_pdbqt(smiles, f'ligand_{i}') 451 | if pdb_path and pdbqt_path: 452 | 453 | v.set_ligand_from_file(pdbqt_path) 454 | 455 | 456 | v.compute_vina_maps(center=center, box_size=box_size) 457 | 458 | 459 | v.dock(exhaustiveness=8, n_poses=1) 460 | v.write_poses(f'docked_conformations/ligand_{i}_docked.pdbqt', n_poses=1, overwrite=True) 461 | 462 | 463 | original_mol = Chem.MolFromPDBFile(pdb_path) 464 | docked_pdb = f'docked_conformations/ligand_{i}_docked.pdb' 465 | next(readfile('pdbqt', f'docked_conformations/ligand_{i}_docked.pdbqt')).write('pdb', docked_pdb, overwrite=True)  # convert the docked pose back to PDB so RDKit can read it 466 | docked_mol = Chem.MolFromPDBFile(docked_pdb) 467 | if original_mol and docked_mol: 468 | rmsd = AllChem.GetBestRMS(Chem.RemoveHs(original_mol), Chem.RemoveHs(docked_mol))  # strip hydrogens so atom matching succeeds 469 | if i < 10000: 470 | rmsd_values_model1.append(rmsd) 471 | else: 472 | rmsd_values_model2.append(rmsd) 473 | 474 | # Plot RMSD Distributions for two models 475 | plt.figure(figsize=(12, 6)) 476 | sns.kdeplot(rmsd_values_model1, label='ckg RMSD', fill=True) 477 | sns.kdeplot(rmsd_values_model2, label='ckg_i RMSD', fill=True) 478 | plt.title('RMSD Distribution for Two Models in SARS Target') 479 | plt.xlabel('RMSD') 480 | plt.ylabel('Density') 481 | 482 | plt.legend() 483 | 484 | plt.savefig('figure07_1.pdf') 485 | plt.show() 486 | 487 | 488 | 489 | 490 | 491 | 492 | 493 | 494 | 495 | sars_df = df.iloc[20000:40000] 496 | 497 | # Separate data into two models 498 | ckg = sars_df.iloc[:10000] 499 | ckg_i = sars_df.iloc[10000:] 500 | 501 | # Calculate the proportion of '1's for residue1 and residue2 in each model 502 | proportion_residue1_ckg = ckg['residue1'].mean() # Mean of 1s gives the proportion in binary data 503 | proportion_residue2_ckg = ckg['residue2'].mean() 504 | proportion_residue1_ckg_i = ckg_i['residue1'].mean() 505 | proportion_residue2_ckg_i = ckg_i['residue2'].mean() 506 |
507 | # Create a DataFrame for easy plotting 508 | proportions_df = pd.DataFrame({ 509 | 'Model': ['ckg', 'ckg', 'ckg_i', 'ckg_i'], 510 | 'Residue': ['residue1', 'residue2', 'residue1', 'residue2'], 511 | 'Proportion': [proportion_residue1_ckg, proportion_residue2_ckg, 512 | proportion_residue1_ckg_i, proportion_residue2_ckg_i] 513 | }) 514 | 515 | # Plot the proportions 516 | plt.figure(figsize=(10, 6)) 517 | sns.barplot(data=proportions_df, x='Residue', y='Proportion', hue='Model') 518 | plt.title('Proportion of his41 and cys145 for SARS Target Models') 519 | plt.xlabel('Residue') 520 | plt.ylabel('Proportion') 521 | plt.legend(title='Model') 522 | plt.savefig('figure07_2.pdf') 523 | plt.show() 524 | -------------------------------------------------------------------------------- /images/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mywang1994/cligen_gen/817abe972a3fccb2b9a42885b7890308ec65e2dd/images/figure.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | from utils import * 6 | import torch.utils.data 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | import pandas as pd 11 | 12 | 13 | # Reaction-based combiner 14 | class combiner(nn.Module): 15 | def __init__(self, vocab_size, embedding_dim, hidden_dim, synthon_hidden_dim): 16 | super(combiner, self).__init__() 17 | self.embedding = nn.Embedding(vocab_size, embedding_dim) 18 | self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True) 19 | self.fc1 = nn.Linear(hidden_dim, synthon_hidden_dim) 20 | self.fc2 = nn.Linear(synthon_hidden_dim, synthon_hidden_dim) 21 | self.fc3 = nn.Linear(synthon_hidden_dim, 1) 22 |
self.relu = nn.ReLU() 23 | self.sigmoid = nn.Sigmoid() 24 | 25 | def forward(self, x): 26 | x = self.embedding(x) 27 | _, h_n = self.rnn(x) 28 | x = h_n[-1] 29 | x = self.relu(self.fc1(x)) 30 | x = self.relu(self.fc2(x)) 31 | x = self.sigmoid(self.fc3(x)) 32 | return x 33 | 34 | class pre_dataset_combiner(Dataset): 35 | def __init__(self, smiles, labels, char_to_idx): 36 | self.smiles = smiles 37 | self.labels = labels 38 | self.char_to_idx = char_to_idx 39 | 40 | def __len__(self): 41 | return len(self.smiles) 42 | 43 | def __getitem__(self, idx): 44 | smiles_seq = self.smiles_to_seq(self.smiles[idx]) 45 | padded_seq = self.pad_sequence(smiles_seq) 46 | label = self.labels[idx] 47 | return torch.tensor(padded_seq, dtype=torch.long), torch.tensor(label, dtype=torch.float) 48 | 49 | def smiles_to_seq(self, smile): 50 | return [self.char_to_idx[char] for char in smile] 51 | 52 | def pad_sequence(self, seq): 53 | seq += [0] * (100 - len(seq)) 54 | return seq[:100] 55 | 56 | 57 | #inpainting generator 58 | class inpainting(nn.Module): 59 | def __init__(self, vocab_size=39, embedding_dim=64, hidden_dim=512, de_hidden_dim=128,latent_dim=512,device=0, skip=[0,1,2,3,4], attention=[0,1,2,3,4] ): 60 | super(inpainting, self).__init__() 61 | self.device=device 62 | self.embedding = nn.Embedding(vocab_size, embedding_dim) 63 | 64 | self.skip = skip 65 | self.attention = attention 66 | self.CA = ContextualAttention(ksize=3, stride=1, rate=2, softmax_scale=10, two_input=False, use_cuda=True, device_ids=device) 67 | self.encoder_stage1_conv1 = make_layers(100, 64, kernel_size=1, stride=1, padding=0, bias=False, norm=False, activation=True, is_relu=False) 68 | self.encoder_stage1_conv2 = make_layers(64, 128, kernel_size=1, stride=1, padding=0, bias=False, norm=False, activation=True, is_relu=False) 69 | self.encoder_stage2 = nn.Sequential(convolutional_block([128, 64, 64, 256], norm=False),identity_block([256, 64, 64, 256], norm=False),identity_block([256, 64, 64, 256], 
norm=False)) 70 | self.encoder_stage3 = nn.Sequential(convolutional_block([256, 128, 128, 512]),identity_block([512, 128, 128, 512]),identity_block([512, 128, 128, 512]),identity_block([512, 128, 128, 512])) 71 | self.encoder_stage4 = nn.Sequential(convolutional_block([512, 256, 256, 1024]),identity_block([1024, 256, 256, 1024]),identity_block([1024, 256, 256, 1024]),identity_block([1024, 256, 256, 1024])) 72 | self.encoder_stage5 = nn.Sequential(convolutional_block([1024, 512, 512, 1024]),identity_block([1024, 512, 512, 1024]),identity_block([1024, 512, 512, 1024]),identity_block([1024, 512, 512, 1024]),identity_block([1024, 512, 512, 1024]) 73 | ) 74 | self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True) 75 | self.rnn1 = nn.LSTM(embedding_dim, hidden_dim, batch_first=True) 76 | self.hidden_to_mean = nn.Linear(hidden_dim, latent_dim) 77 | self.hidden_to_logvar = nn.Linear(hidden_dim, latent_dim) 78 | self.fc0 = nn.Linear(1024 * 512, latent_dim) 79 | self.latent_to_hidden = nn.Linear(latent_dim, de_hidden_dim) 80 | self.deconv1 = make_layers_transpose(latent_dim, 256, kernel_size=4, stride=2, padding=1) 81 | self.deconv2 = make_layers_transpose(256, 128, kernel_size=4, stride=2, padding=1) 82 | self.deconv3 = make_layers_transpose(128, 64, kernel_size=4, stride=2, padding=1) 83 | self.deconv4 = make_layers_transpose(64, embedding_dim, kernel_size=4, stride=2, padding=1) 84 | self.lstm = nn.LSTM(embedding_dim, de_hidden_dim, batch_first=True) 85 | self.hidden_to_vocab = nn.Linear(de_hidden_dim, vocab_size) 86 | self.embedding = nn.Embedding(vocab_size, embedding_dim) 87 | self.reduction = nn.Linear(1024, de_hidden_dim) 88 | self.de_hidden_dim = de_hidden_dim 89 | 90 | # Fully connected layers 91 | self.fc1 = nn.Linear(32 * 8192, 1024) # Adjust the input size to match the flattened output size 92 | self.fc2 = nn.Linear(1024, 256) 93 | self.fc3 = nn.Linear(256, vocab_size) 94 | self.BCT = BCT_P(device=device) 95 | self.decoder_stage2_conv1 = 
make_layers(1024, 512, kernel_size=1, stride=1, padding=0, bias=False, norm=False, activation=True, is_relu=False) 96 | self.decoder_stage2_conv2 = make_layers(512, 1024, kernel_size=1, stride=1, padding=0, bias=False, norm=False, activation=True, is_relu=False) 97 | 98 | self.feature_out = make_layers(512, 1024, kernel_size=1, stride=1, padding=0, norm=False, activation=True, is_relu=False) 99 | 100 | self.GRB5 = GRB(1024,1) 101 | self.decoder_stage5 = nn.Sequential(identity_block([1024, 512, 512, 1024], is_relu=True),identity_block([1024, 512, 512, 1024], is_relu=True),make_layers_transpose(1024, 1024, kernel_size=4, stride=2, padding=1, bias=False, norm=True, activation=True, is_relu=True) 102 | ) 103 | 104 | self.linear = nn.Linear(1024, 512) 105 | self.SHC4 = SHC(1024) 106 | if 4 in self.skip: 107 | self.SHC4_mid = SHC(1024) 108 | self.skip4 = nn.Sequential( 109 | nn.InstanceNorm1d(1024, affine=True), 110 | nn.ReLU() 111 | ) 112 | self.GRB4 = GRB(1024,2) 113 | self.decoder_stage4 = nn.Sequential(identity_block([1024, 256, 256, 1024], is_relu=True),identity_block([1024, 256, 256, 1024], is_relu=True), identity_block([1024, 256, 256, 1024], is_relu=True), make_layers_transpose(1024, 512, kernel_size=4, stride=2, padding=1, bias=False, norm=True, activation=True, is_relu=True) 114 | ) 115 | self.SHC3 = SHC(512) 116 | if 3 in self.skip: 117 | self.SHC3_mid = SHC(512) 118 | self.skip3 = nn.Sequential( 119 | nn.InstanceNorm1d(512, affine=True), 120 | nn.ReLU() 121 | ) 122 | self.GRB3 = GRB(512,4) 123 | self.decoder_stage3 = nn.Sequential( 124 | identity_block([512, 128, 128, 512], is_relu=True), 125 | identity_block([512, 128, 128, 512], is_relu=True), 126 | identity_block([512, 128, 128, 512], is_relu=True), 127 | make_layers_transpose(512, 256, kernel_size=4, stride=2, padding=1, bias=False, norm=True, activation=True, is_relu=True) 128 | ) 129 | 130 | self.SHC2 = SHC(256, norm=False) 131 | if 2 in self.skip: 132 | self.SHC2_mid = SHC(256, norm=False) 133 | 
self.skip2 = nn.ReLU() 134 | self.GRB2 = GRB(256, 4, norm=False) 135 | self.decoder_stage2 = nn.Sequential( 136 | identity_block([256, 64, 64, 256], is_relu=True, norm=False), 137 | identity_block([256, 64, 64, 256], is_relu=True, norm=False), 138 | identity_block([256, 64, 64, 256], is_relu=True, norm=False), 139 | identity_block([256, 64, 64, 256], is_relu=True, norm=False), 140 | make_layers_transpose(256, 128, kernel_size=4, stride=2, padding=1, bias=False, norm=False, activation=True, is_relu=True) 141 | ) 142 | 143 | self.SHC1 = SHC(128, norm=False) 144 | if 1 in self.skip: 145 | self.SHC1_mid = SHC(128, norm=False) 146 | self.skip1 = nn.ReLU() 147 | self.decoder_stage1 = make_layers_transpose(128, 64, kernel_size=4, stride=2, padding=1, bias=False, norm=False, activation=True, is_relu=True) 148 | 149 | self.SHC0 = SHC(64, norm=False) 150 | if 0 in self.skip: 151 | self.SHC0_mid = SHC(64, norm=False) 152 | self.skip0 = nn.ReLU() 153 | self.decoder_stage0 = nn.Sequential( 154 | nn.ConvTranspose1d(64, 1024, kernel_size=3, stride=2, padding=1, bias=False), 155 | nn.Sigmoid() 156 | ) 157 | 158 | def encode(self, x): 159 | shortcut = [] 160 | 161 | 162 | x = self.embedding(x) 163 | 164 | 165 | x = self.encoder_stage1_conv1(x) 166 | 167 | shortcut.append(x) 168 | x = self.encoder_stage1_conv2(x) 169 | 170 | shortcut.append(x) 171 | x = self.encoder_stage2(x) 172 | 173 | 174 | shortcut.append(x) 175 | x = self.encoder_stage3(x) 176 | 177 | shortcut.append(x) 178 | x = self.encoder_stage4(x) 179 | 180 | shortcut.append(x) 181 | x, h_n= self.rnn(x) 182 | 183 | shortcut.append(x) 184 | 185 | return x, shortcut 186 | def decode_smi(self, x, shortcut): 187 | 188 | out = self.GRB5(x) 189 | 190 | out = self.decoder_stage5(out) 191 | if 4 in self.skip: 192 | out = torch.split(out, 32, dim=2) 193 | out = list(out) 194 | if (4 in self.attention): 195 | sc_l = [shortcut[4][0]] 196 | 197 | sc_r = [shortcut[4][1]] 198 | 199 | 200 | sc_m = self.CA(out[0], out[2], out[1], [sc_l, 
sc_r]) 201 | out[1] = self.skip4(self.SHC4_mid(torch.cat((out[1],sc_m[0]),2), out[1])) 202 | out[0] = self.skip4(self.SHC4(torch.cat((out[0],shortcut[4][0]),2), shortcut[4][0])) 203 | out[2] = self.skip4(self.SHC4(torch.cat((out[2],shortcut[4][1]),2), shortcut[4][1])) 204 | out = torch.cat((out),2) 205 | 206 | out = self.GRB4(out) 207 | out = self.decoder_stage4(out) 208 | 209 | if 3 in self.skip: 210 | out = list(torch.split(out, 64,dim=2)) 211 | if (3 in self.attention): 212 | sc_l = [shortcut[3][0]] 213 | sc_r = [shortcut[3][1]] 214 | sc_m = self.CA(out[0], out[2], out[1], [sc_l, sc_r]) 215 | out[1] = self.skip3(self.SHC3_mid(torch.cat((out[1],sc_m[0]),2), out[1])) 216 | out[0] = self.skip3(self.SHC3(torch.cat((out[0],shortcut[3][0]),2), shortcut[3][0])) 217 | out[2] = self.skip3(self.SHC3(torch.cat((out[2],shortcut[3][1]),2), shortcut[3][1])) 218 | out = torch.cat((out),2) 219 | out = self.GRB3(out) 220 | out = self.decoder_stage3(out) 221 | 222 | if 2 in self.skip: 223 | out = list(torch.split(out, 128,dim=2)) 224 | if (2 in self.attention): 225 | sc_l = [shortcut[2][0]] 226 | sc_r = [shortcut[2][1]] 227 | sc_m = self.CA(out[0], out[2], out[1], [sc_l, sc_r]) 228 | out[1] = self.skip2(self.SHC2_mid(torch.cat((out[1],sc_m[0]),2), out[1])) 229 | out[0] = self.skip2(self.SHC2(torch.cat((out[0],shortcut[2][0]),2), shortcut[2][0])) 230 | out[2] = self.skip2(self.SHC2(torch.cat((out[2],shortcut[2][1]),2), shortcut[2][1])) 231 | out = torch.cat((out),2) 232 | out = self.GRB2(out) 233 | out = self.decoder_stage2(out) 234 | 235 | if 1 in self.skip: 236 | out = list(torch.split(out, 256,dim=2)) 237 | if (1 in self.attention): 238 | sc_l = [shortcut[1][0]] 239 | sc_r = [shortcut[1][1]] 240 | sc_m = self.CA(out[0], out[2], out[1], [sc_l, sc_r]) 241 | out[1] = self.skip1(self.SHC1_mid(torch.cat((out[1],sc_m[0]),2), out[1])) 242 | out[0] = self.skip1(self.SHC1(torch.cat((out[0],shortcut[1][0]),2), shortcut[1][0])) 243 | out[2] = 
self.skip1(self.SHC1(torch.cat((out[2],shortcut[1][1]),2), shortcut[1][1])) 244 | out = torch.cat((out),2) 245 | out = self.decoder_stage1(out) 246 | 247 | if 0 in self.skip: 248 | out = list(torch.split(out, 512,dim=2)) 249 | if (0 in self.attention): 250 | sc_l = [shortcut[0][0]] 251 | sc_r = [shortcut[0][1]] 252 | sc_m = self.CA(out[0], out[2], out[1], [sc_l, sc_r]) 253 | out[1] = self.skip0(self.SHC0_mid(torch.cat((out[1],sc_m[0]),2), out[1])) 254 | 255 | out[0] = self.skip0(self.SHC0(torch.cat((out[0],shortcut[0][0]),2), shortcut[0][0])) 256 | out[2] = self.skip0(self.SHC0(torch.cat((out[2],shortcut[0][1]),2), shortcut[0][1])) 257 | out = torch.cat((out),2) 258 | out = self.decoder_stage0(out) 259 | 260 | return out 261 | 262 | def forward(self, x1, x2, only_encode=False): 263 | shortcut = [[] for i in range(6)] 264 | x1, shortcut_x1 = self.encode(x1) 265 | for i in range(6): 266 | shortcut[i].append(shortcut_x1[i]) 267 | if only_encode: 268 | return x1 269 | 270 | x2, shortcut_x2= self.encode(x2) 271 | for i in range(6): 272 | shortcut[i].append(shortcut_x2[i]) 273 | 274 | 275 | out, f1, f2 = self.BCT(x1, x2) 276 | 277 | out = shortcut[5][0] + out + shortcut[5][1] 278 | 279 | out = self.decode_smi(out, shortcut) 280 | 281 | 282 | #weight = gaussian_weight(out.size(1),out.size(2)) 283 | #bias = gaussian_bias(out.size(1)) 284 | weight = torch.randn(int(out.size(1)/2),out.size(2)).cuda(self.device) 285 | bias = torch.randn(int(out.size(1)/2)).cuda(self.device) 286 | out = F.linear(out,weight,bias) 287 | out = self.decoder_stage2_conv1(out) 288 | out = self.decoder_stage2_conv2(out) 289 | out = out.reshape((out.size(0), -1)) 290 | out = self.fc0(out) 291 | out = out.unsqueeze(2) 292 | 293 | h = self.deconv1(out) 294 | h = self.deconv2(h) 295 | h = self.deconv3(h) 296 | h = self.deconv4(h) 297 | h = h.permute(0, 2, 1) 298 | h_flat = h.contiguous().view(out.size(0), -1) 299 | 300 | 301 | 302 | h0 = h_flat.view(1, out.size(0), -1) 303 | h0=self.reduction(h0) 304 | 
c0 = torch.zeros_like(h0) 305 | inputs = torch.zeros(out.size(0), 100, dtype=torch.long, device=self.device) 306 | inputs = self.embedding(inputs) 307 | output, _ = self.lstm(inputs, (h0, c0)) 308 | logits = self.hidden_to_vocab(output) 309 | 310 | return logits, f1, f2 311 | 312 | 313 | 314 | 315 | 316 | class identity_block(nn.Module): 317 | def __init__(self, channels, norm=True, is_relu=False): 318 | super(identity_block, self).__init__() 319 | 320 | self.conv1 = make_layers(channels[0], channels[1], kernel_size=1, stride=1, padding=0, bias=False, norm=norm, activation=True, is_relu=is_relu) 321 | self.conv2 = make_layers(channels[1], channels[2], kernel_size=1, stride=1, padding=0, bias=False, norm=norm, activation=True, is_relu=is_relu) 322 | self.conv3 = make_layers(channels[2], channels[3], kernel_size=1, stride=1, padding=0, bias=False, norm=norm, activation=False) 323 | self.output = nn.ReLU() if is_relu else nn.LeakyReLU(negative_slope=0.2) 324 | 325 | def forward(self,x): 326 | shortcut = x 327 | x = self.conv1(x) 328 | x = self.conv2(x) 329 | x = self.conv3(x) 330 | x = x + shortcut 331 | 332 | x = self.output(x) 333 | return x 334 | 335 | class convolutional_block(nn.Module): 336 | def __init__(self, channels, norm=True, is_relu=False): 337 | super(convolutional_block, self).__init__() 338 | 339 | self.conv1 = make_layers(channels[0], channels[1], kernel_size=1, stride=1, padding=0, bias=False, norm=norm, activation=True, is_relu=is_relu) 340 | self.conv2 = make_layers(channels[1], channels[2], kernel_size=1, stride=1, padding=0, bias=False, norm=norm, activation=True, is_relu=is_relu) 341 | self.conv3 = make_layers(channels[2], channels[3], kernel_size=1, stride=1, padding=0, bias=False, norm=norm, activation=False) 342 | self.shortcut_path = make_layers(channels[0], channels[3], kernel_size=1, stride=1, padding=0, bias=False, norm=norm, activation=False) 343 | self.output = nn.ReLU() if is_relu else nn.LeakyReLU(negative_slope=0.2) 344 | 345 | def 
forward(self,x): 346 | shortcut = x 347 | x = self.conv1(x) 348 | x = self.conv2(x) 349 | x = self.conv3(x) 350 | shortcut = self.shortcut_path(shortcut) 351 | x = x + shortcut 352 | x = self.output(x) 353 | return x 354 | 355 | class SHC(nn.Module): 356 | def __init__(self, channel, norm=True): 357 | super(SHC, self).__init__() 358 | 359 | self.conv1 = make_layers(channel, int(channel/2), kernel_size=1, stride=1, padding=0, norm=norm, activation=True, is_relu=True) 360 | self.conv2 = make_layers(int(channel/2), int(channel/2), kernel_size=3, stride=1, padding=1, norm=norm, activation=True, is_relu=True) 361 | self.conv3 = make_layers(int(channel/2), channel, kernel_size=1, stride=1, padding=0, norm=norm, activation=False) 362 | 363 | def forward(self, x, shortcut): 364 | x = self.conv1(x) 365 | x = self.conv2(x) 366 | x = self.conv3(x) 367 | x = torch.cat((x, shortcut),2) 368 | return x 369 | 370 | class GRB(nn.Module): 371 | def __init__(self, channel, dilation, norm=True): 372 | super(GRB, self).__init__() 373 | 374 | self.path1 = nn.Sequential( 375 | make_layers(channel, int(channel/2), kernel_size=1, stride=1, padding=0, dilation=dilation, norm=norm, activation=True, is_relu=True), 376 | make_layers(int(channel/2), channel, kernel_size=1, stride=1, padding=0, dilation=dilation, norm=norm, activation=False) 377 | ) 378 | self.path2 = nn.Sequential( 379 | make_layers(channel, int(channel/2), kernel_size=1, stride=1, padding=0, dilation=dilation, norm=norm, activation=True, is_relu=True), 380 | make_layers(int(channel/2), channel, kernel_size=1, stride=1, padding=0, dilation=dilation, norm=norm, activation=False) 381 | ) 382 | self.output = nn.ReLU() 383 | 384 | def forward(self, x): 385 | x1 = self.path1(x) 386 | x2 = self.path2(x) 387 | 388 | x = x + x1 + x2 389 | x = self.output(x) 390 | return x 391 | 392 | class BCT_P(nn.Module): 393 | def __init__(self, size=[512, 4], split=4, pred_step=4, device=0): 394 | super(BCT_P, self).__init__() 395 | 396 | 
self.channel, self.width = size 397 | self.height = 1 398 | 399 | self.LSTM_encoder_1 = nn.LSTM(512, 512, num_layers=2, batch_first=True) 400 | self.LSTM_decoder_1 = nn.LSTM(512, 512, num_layers=2, batch_first=True) 401 | self.LSTM_decoder_2 = nn.LSTM(512, 1024, num_layers=2, batch_first=True) 402 | 403 | 404 | 405 | self.dec_feat = make_layers(128, 1024, kernel_size=1, stride=1, padding=0, norm=False, activation=True, is_relu=False) 406 | self.split = split 407 | self.pred_step = pred_step 408 | self.device = device 409 | def forward(self, x1, x2): 410 | 411 | batch_size = x1.size(0) 412 | init_hidden = ( 413 | Variable(torch.zeros(2, batch_size, 512)).cuda(self.device), 414 | Variable(torch.zeros(2, batch_size, 512)).cuda(self.device) 415 | ) 416 | init_hidden_1 = ( 417 | Variable(torch.zeros(2, batch_size, 128)).cuda(self.device), 418 | Variable(torch.zeros(2, batch_size, 128)).cuda(self.device) 419 | ) 420 | 421 | # Split the input tensors along the channel dimension 422 | 423 | split_size = self.channel // self.split 424 | 425 | x1_splits = torch.split(x1, split_size, dim=1) 426 | x2_splits = torch.split(x2, split_size, dim=1) 427 | 428 | x1_split_reversed = [split.flip(dims=[1]) for split in x1_splits] 429 | x2_split_reversed = [split.flip(dims=[1]) for split in x2_splits] 430 | # Encode feature from x2 (left->right) 431 | en_hidden = init_hidden 432 | for i in range(self.split): 433 | split_input = x2_splits[i] 434 | 435 | 436 | en_out, en_hidden = self.LSTM_encoder_1(split_input, en_hidden) 437 | hidden_x2 = en_hidden 438 | 439 | 440 | 441 | # Encode feature from x1 (right->left) 442 | en_hidden = init_hidden 443 | for i in reversed(range(self.split)): 444 | split_input = x1_split_reversed[i] 445 | en_out, en_hidden = self.LSTM_encoder_1(split_input, en_hidden) 446 | hidden_x1_reversed = en_hidden 447 | 448 | # Decode feature from x1 (left->right) 449 | de_hidden = init_hidden 450 | for i in range(self.split): 451 | split_input = x1_splits[i] 452 | de_out, 
de_hidden = self.LSTM_decoder_1(split_input, de_hidden) 453 | 454 | 455 | de_out, de_hidden = self.LSTM_decoder_1(de_out, de_hidden) 456 | de_out = self.dec_feat(de_out) 457 | 458 | x1_out = x1 + de_out 459 | 460 | 461 | 462 | 463 | de_hidden = hidden_x1_reversed 464 | for i in reversed(range(self.split)): 465 | split_input = x2_split_reversed[i] 466 | de_out, de_hidden = self.LSTM_decoder_1(split_input, de_hidden) 467 | de_out = self.dec_feat(de_out) 468 | x2_out = de_out + x2 469 | 470 | 471 | for i in range(self.split): 472 | de_out, de_hidden = self.LSTM_decoder_1(de_out, de_hidden) 473 | x2_out = de_out + x2_out 474 | 475 | out = x1_out + x2_out 476 | 477 | 478 | 479 | return out, x1_out, x2_out 480 | 481 | class ContextualAttention(nn.Module): 482 | def __init__(self, ksize=3, stride=1, rate=1, softmax_scale=10, two_input=True, weight_func='cos', use_cuda=False, device_ids=None): 483 | super(ContextualAttention, self).__init__() 484 | self.ksize = ksize 485 | self.stride = stride 486 | self.rate = rate 487 | 488 | self.softmax_scale = softmax_scale 489 | 490 | self.two_input = two_input 491 | self.use_cuda = use_cuda 492 | self.device_ids = device_ids 493 | if weight_func == 'cos': 494 | self.weight_func = cos_function_weight 495 | elif weight_func == 'gaussian': 496 | self.weight_func = gaussian_weight 497 | 498 | def forward(self, left, right, mid, shortcut, mask=None): 499 | 500 | if self.two_input == False: 501 | 502 | left = torch.cat((left,right),2) 503 | for i in range(len(shortcut[0])): 504 | shortcut[0][i] = torch.cat((shortcut[0][i], shortcut[1][i]),2) 505 | 506 | 507 | raw_int_ls = list(shortcut[0][0].size()) 508 | raw_int_ms = list(shortcut[1][0].size()) 509 | 510 | if self.two_input: 511 | raw_int_rs = list(shortcut[1][0].size()) 512 | 513 | raw_l = [item[0] for item in shortcut] 514 | 515 | 516 | if self.two_input: 517 | raw_r = [item[1] for item in shortcut]  # was `raw_r = raw_l = ...`, which clobbered the left-side features 518 | 519 | raw_l = [raw_l[i].view(raw_int_ls[0], raw_int_ls[1], -1) for i in
range(len(raw_l))] 520 | raw_l_groups = [torch.split(raw_l[i], 1, dim=2) for i in range(len(raw_l))] 521 | 522 | if self.two_input: 523 | raw_r = [raw_r[i].view(raw_int_rs[0], raw_int_rs[1], -1) for i in range(len(raw_r))] 524 | raw_r_groups = [torch.split(raw_r[i], 1, dim=2) for i in range(len(raw_r))] 525 | 526 | 527 | left = F.interpolate(left, scale_factor=1, mode='nearest') 528 | 529 | if self.two_input: 530 | right = F.interpolate(right, scale_factor=1, mode='nearest') 531 | mid = F.interpolate(mid, scale_factor=1, mode='nearest') 532 | int_ls = list(left.size()) 533 | if self.two_input: 534 | int_rs = list(right.size()) 535 | int_mids = list(mid.size()) 536 | mid_groups = torch.split(mid, 2, dim=2) 537 | 538 | 539 | left = left.view(int_ls[0], int_ls[1], -1) 540 | l_groups = torch.split(left, 2, dim=2) 541 | if self.two_input: 542 | right = right.view(int_rs[0], int_rs[1], -1) 543 | r_groups = torch.split(right, 2, dim=2) 544 | batch = [i for i in range(raw_int_ls[0])] 545 | 546 | y_l = [[] for i in range(len(shortcut[0]))] 547 | y_r = [[] for i in range(len(shortcut[0]))] 548 | y = [[] for i in range(len(shortcut[0]))] 549 | 550 | 551 | weight = self.weight_func(raw_int_ls[0], raw_int_ls[2], device=self.device_ids) 552 | scale = self.softmax_scale 553 | 554 | 555 | if self.two_input == False: 556 | r_groups = l_groups 557 | 558 | 559 | 560 | for xi, li, ri, batch_idx in zip(mid_groups, l_groups, r_groups, batch): 561 | 562 | escape_NaN = torch.FloatTensor([1e-4]) 563 | if self.use_cuda: 564 | escape_NaN = escape_NaN.cuda(self.device_ids) 565 | 566 | yi = [] 567 | xi = F.pad(xi, (1, 0)) 568 | yi.append(F.conv1d(xi, li, stride=1)) 569 | 570 | if self.two_input: 571 | yi.append(F.conv1d(xi, ri, stride=1)) 572 | 573 | 574 | yi = [F.softmax(yi[i]*scale, dim=1) for i in range(len(yi))] 575 | 576 | 577 | for i in range(len(shortcut[0])): 578 | li_center = raw_l_groups[i][batch_idx] 579 | 580 | current_dim = yi[0].shape[1] 581 | target_dim = li_center.shape[1] 582 
| pad_size = target_dim - current_dim 583 | 584 | 585 | yi[0] = F.pad(yi[0], (0, 0, 0, pad_size), 'constant', 0) 586 | 587 | 588 | if self.two_input: 589 | ri_center = raw_r_groups[i][batch_idx] 590 | y_l[i].append(torch.cat((yi[0], li_center), dim=2)) 591 | if self.two_input: 592 | y_r[i].append(torch.cat((yi[0], ri_center), dim=2)) 593 | 594 | for i in range(len(shortcut[0])): 595 | 596 | 597 | y_l[i] = torch.cat(y_l[i], dim=2).contiguous() 598 | pad_size = raw_int_ms[2] - y_l[i].shape[2] 599 | y_l[i] = F.pad(y_l[i], (0, pad_size), 'constant', 0) if pad_size > 0 else y_l[i] 600 | 601 | 602 | if self.two_input: 603 | y_r[i] = torch.cat(y_r[i], dim=2).contiguous() 604 | y_r[i] = F.pad(y_r[i], (0, pad_size), 'constant', 0) if pad_size > 0 else y_r[i] 605 | else: 606 | y[i] = y_l[i] 607 | 608 | return y 609 | 610 | def collate_fn(batch, vocab_size): 611 | lefts, rights, targets = zip(*batch) 612 | lefts = nn.utils.rnn.pad_sequence(lefts, batch_first=True, padding_value=vocab_size) 613 | rights = nn.utils.rnn.pad_sequence(rights, batch_first=True, padding_value=vocab_size) 614 | targets = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=vocab_size) 615 | return lefts, rights, targets 616 | 617 | 618 | def reparameterize(mean, logvar): 619 | std = torch.exp(0.5 * logvar) 620 | eps = torch.randn_like(std) 621 | return mean + eps * std 622 | 623 | 624 | 625 | 626 | def smiles_to_indices(smiles, char_to_idx): 627 | return [char_to_idx[char] for char in smiles] 628 | 629 | 630 | 631 | 632 | 633 | 634 | 635 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import pandas as pd 6 | import torch 7 | from torch.utils.data import DataLoader, Dataset 8 | from utils import * 9 | from rdkit import Chem  # AllChem is not importable from the rdkit top level; it is imported from rdkit.Chem below 10 | import random 11 | from rdkit import
Chem 12 | from rdkit.Chem import AllChem 13 | import os 14 | import numpy as np 15 | from openbabel.pybel import * 16 | import argparse 17 | from model import inpainting, combiner, pre_dataset_combiner 18 | 19 | 20 | 21 | def str2bool(v): 22 | if isinstance(v, bool): 23 | return v 24 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 25 | return True 26 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 27 | return False 28 | else: 29 | raise argparse.ArgumentTypeError('Boolean value expected.') 30 | 31 | 32 | 33 | parser = argparse.ArgumentParser(description='ClickGen: Directed Exploration of Synthesizable Chemical Space Leading to the Rapid Synthesis of Novel and Active Lead Compounds via Modular Reactions and Reinforcement Learning') 34 | 35 | 36 | 37 | parser.add_argument('--syn_p', type=str, default='./data/synthons/synthons.csv', help='the path of synthons library') 38 | parser.add_argument('--input', type=str, help='start fragment', default='[3*]NC1CCCC(N[3*])CC1', required=False) 39 | parser.add_argument('--protein', type=str, default='./data/protein.pdbqt', help='protein, PARP1 PDBQT format') 40 | parser.add_argument('--inpainting', type=str2bool, nargs='?', const=True, default=True, help="Inpainting mode (default: True)") 41 | parser.add_argument('--num_sims', type=int, default=10000, help='Number of simulation steps', required=False) 42 | parser.add_argument('--embed_dim', type=int, default=64, help='embedding_dim', required=False) 43 | parser.add_argument('--hid_dim', type=int, default=256, help='hidden_dim', required=False) 44 | parser.add_argument('--syn_dim', type=int, default=128, help='synthon_hidden_dim', required=False) 45 | 46 | 47 | 48 | 49 | args = parser.parse_args() 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | embedding_dim = args.embed_dim 58 | hidden_dim = args.hid_dim 59 | synthon_hidden_dim = args.syn_dim 60 | syn_path = args.syn_p 61 | char_set = tokened(syn_path) 62 | char_to_idx = {char: idx for idx, char in enumerate(char_set)} 63 | vocab_size
= len(char_to_idx) 64 | idx_to_char = {v: k for k, v in char_to_idx.items()} 65 | 66 | 67 | def smi_to_seq(smi, char_to_idx): 68 | smis = smi.replace("Cl", "X").replace("Br", "Y").replace("[nH]", "Z") 69 | sequence = [] 70 | for char in smis: 71 | if char in char_to_idx: 72 | sequence.append(char_to_idx[char]) 73 | else: 74 | print(f"Unrecognized character in SMILES: {char}") 75 | return sequence 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | def try_react(syn1, syn2, reaction_smarts): 84 | mol1 = Chem.MolFromSmiles(syn1) 85 | mol2 = Chem.MolFromSmiles(syn2) 86 | rxn = AllChem.ReactionFromSmarts(reaction_smarts) 87 | products = rxn.RunReactants((mol1, mol2)) if mol1 and mol2 else None  # guard against unparsable SMILES 88 | if products: 89 | return Chem.MolToSmiles(products[0][0]) 90 | return None 91 | 92 | def combine_syns(input, num, module): 93 | df = pd.read_csv(syn_path) 94 | sample_syns = random.sample(df['Synthons'].tolist(), min(num, len(df))) 95 | if module: 96 | processed_smi_list = [] 97 | for smi in sample_syns: 98 | l, m, r = process_synthons_dataset(smi) 99 | smi_symbols = assemble_smiles_with_symbols(l, m, r) 100 | processed_smi_list.append(smi_symbols) 101 | sample_syns = processed_smi_list 102 | successful_products = [] 103 | for syn_smis in sample_syns: 104 | product = try_react(input, syn_smis, '[*:1]C(=O)O[3*].[*:2]N[3*]>>[*:1]C(=O)N[*:2]') 105 | if product: 106 | successful_products.append(product) 107 | else: 108 | product = try_react(input, syn_smis, '[1*]/C=C(\[2*])[*:1].[2*]/N=N\N([1*])[*:2]>>[*:1]c1cn([*:2])nn1') 109 | if product: 110 | successful_products.append(product) 111 | else: 112 | if product := try_react(input, syn_smis, "[*].[*]>>[*]-[*]"): successful_products.append(product)  # the fallback product was previously computed but never stored 113 | return successful_products 114 | 115 | def predict_syns(input, module=None): 116 | module = args.inpainting  # the CLI flag decides whether the inpainting generator is used 117 | if module: 118 | l_m, _m, r_m = process_synthons_dataset(input, char_to_idx, vocab_size) 119 | m_m,_f1,_f2= inpainting(l_m,r_m) 120 | l_smi, m_smi, r_smi = ind2smi(l_m)[0], ind2smi(m_m)[0], ind2smi(r_m)[0] 121 | input = assemble_smiles(l_smi, m_smi, r_smi) 122 |
smi_list = combine_syns(input, 1000, module) 123 | combine_model = combiner(vocab_size, embedding_dim, hidden_dim, synthon_hidden_dim) 124 | combine_model.load_state_dict(torch.load('./data/model/combiner.pth')) 125 | combine_model.eval() 126 | inpainting_model = inpainting(vocab_size=vocab_size) 127 | inpainting_model.load_state_dict(torch.load('./data/model/inpainting.pth')) 128 | inpainting_model.eval() 129 | 130 | test_dataset = pre_dataset_combiner(smi_list, [0] * len(smi_list), char_to_idx) 131 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) 132 | 133 | smiles_probabilities = [] 134 | with torch.no_grad(): 135 | for i, (inputs, _) in enumerate(test_loader): 136 | if module: 137 | parts = disassemble_smiles_with_symbols(inputs) 138 | m_pred, _f1, _f2 = inpainting_model(parts[0], parts[2]) 139 | m_smi = ind2smi(m_pred, idx_to_char)[0] 140 | inputs = assemble_smiles(parts[0], m_smi, parts[2]) 141 | outputs = combine_model(inputs) 142 | probability = outputs.item() 143 | smiles_probabilities.append((smi_list[i], probability)) 144 | return smiles_probabilities 145 | 146 | def smi_to_sdf(smi, out_path): 147 | mol = Chem.MolFromSmiles(smi) 148 | mol = Chem.AddHs(mol) 149 | 150 | AllChem.EmbedMolecule(mol, AllChem.ETKDG()) 151 | AllChem.UFFOptimizeMolecule(mol) 152 | sdf_filename = out_path 153 | writer = Chem.SDWriter(sdf_filename) 154 | writer.write(mol) 155 | writer.close() 156 | 157 | def roulette(select_list): 158 | ''' 159 | roulette (fitness-proportionate) selection 160 | ''' 161 | sum_val = sum(select_list) 162 | random_val = random.random() 163 | probability = 0 164 | if sum_val != 0: 165 | for i in range(len(select_list)): 166 | probability += select_list[i] / sum_val 167 | if probability >= random_val: 168 | return i 169 | # fall through only on floating-point shortfall 170 | return len(select_list) - 1 171 | else: 172 | return random.choice(range(len(select_list))) 173 | 174 | 175 | DIC = [] 176 | SOC = [] 177 | N_IDX = 0 178 | 179 | class State(): 180 | 181 | def __init__(self, input, cho=None, sta=[], choices=[], start=True): 182 | self.input = input 183 | self.start = start 184 | 
self.cho = cho; self.score = 0 185 | self.states = sta + [self.score] 186 | self.choices = choices + [self.cho] 187 | 188 | def is_terminal(self, sdf): 189 | suppl = Chem.SDMolSupplier(sdf) 190 | for mol in suppl: 191 | if mol is not None: 192 | non_hydrogen_count = sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() != 1) 193 | if non_hydrogen_count > 1: 194 | return True 195 | return False 196 | 197 | def next_state(self): 198 | if self.start: 199 | syn_pro = predict_syns(self.input) 200 | smiles_scores = [(smiles, vina_dock(smiles)) for smiles in 201 | [item[0] for item in sorted(syn_pro, key=lambda x: x[1], reverse=True)[:10]]] 202 | smi_to_sdf(max(smiles_scores, key=lambda item: item[1])[0], './log/output.sdf') 203 | 204 | DIC.append(smiles_scores) 205 | else: 206 | # self.input already holds a SMILES string 207 | syn_pro = predict_syns(self.input) 208 | smiles_scores = [(smiles, vina_dock(smiles)) for smiles in 209 | [item[0] for item in sorted(syn_pro, key=lambda x: x[1], reverse=True)[:1]]] 210 | smi_to_sdf(max(smiles_scores, key=lambda item: item[1])[0], './log/output.sdf') 211 | DIC.append(smiles_scores) 212 | return smiles_scores 213 | 214 | class Node(): 215 | def __init__(self, state, parent=None, reward=0, node_id=None, best_score=0): 216 | self.visits = 0 217 | self.reward = reward 218 | self.state = state 219 | self.children = [] 220 | self.parent = parent 221 | self.longest_path = 0 222 | self.node_id = node_id 223 | self.best_score = best_score 224 | def add_child(self, child_state, node_id, best_score=0, qed=0): 225 | child = Node(child_state, node_id=node_id, parent=self, best_score=best_score) 226 | self.children.append(child) 227 | 228 | def update(self, reward): 229 | self.reward += reward 230 | self.visits += 1 231 | 232 | def fully_expanded(self, num_moves_lambda): 233 | num_moves = len(DIC) 234 | if num_moves_lambda is not None: 235 | num_moves = num_moves_lambda(self) 236 | if len(self.children) == num_moves: 237 | return True 238 | return False 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | def UCTSEARCH(budget, root, start_score=0, num_moves_lambda=None): 249 | # Begin the MCTS 250 | for iter in 
range(int(budget)): 251 | front = TREEPOLICY(root, start_score, num_moves_lambda) 252 | BACKUP2(front) 253 | 254 | 255 | def TREEPOLICY(node, start_score, num_moves_lambda=None): 256 | # Choose whether to expand the node based on the status of the current node 257 | while node.state.is_terminal('./log/output.sdf') == False: 258 | if len(node.children) == 0: 259 | node = EXPAND(node, start_score) 260 | else: 261 | node = BESTCHILD(node, start_score) 262 | return node 263 | 264 | 265 | def EXPAND(node, start_score): 266 | global N_IDX 267 | # Sample successor states of the current node and add them to the tree 268 | new_states = node.state.next_state() 269 | if not new_states: 270 | return node 271 | scores = [] 272 | for nextmove in new_states: 273 | next = State(nextmove[0], sta=node.state.states, start=False) 274 | N_IDX += 1 275 | best_score = min(start_score, nextmove[1]) 276 | scores.append(abs(nextmove[1])) 277 | node.add_child(next, node_id=N_IDX, best_score=best_score) 278 | return node.children[roulette(scores)] 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | def BESTCHILD(node, start_score): 290 | 291 | # Select child nodes based on the node's UCB 292 | scores = [] 293 | for c in node.children: 294 | 295 | exploit = start_score - c.best_score 296 | explore = math.sqrt(2.0 * math.log(node.visits + 1) / float(c.visits + 0.000001)) 297 | 298 | score = exploit + 1 / (2 * math.sqrt(2.0)) * explore 299 | scores.append(score) 300 | idx = roulette(scores) 301 | 302 | 303 | 304 | 305 | return node.children[idx] 306 | 307 | 308 | def DEFAULTPOLICY(node): 309 | state = node.state 310 | num_states = 0 311 | 312 | while 
state.is_terminal('./log/output.sdf') == False: 313 | state = state.next_state() 314 | num_states += 1 315 | if state.type == 1: 316 | if num_states != 0: 317 | num_states -= 1 318 | num_nodes = len(state.states) - num_states 319 | print(state.type) 320 | return state.states, num_nodes, num_states 321 | 322 | 323 | def BACKUP2(node): 324 | son_node = node 325 | parent_node = node 326 | while parent_node != None: 327 | parent_node.visits += 1 328 | if len(parent_node.children) == 0: 329 | x = parent_node 330 | parent_node = node.parent 331 | son_node = x 332 | else: 333 | if parent_node.best_score > son_node.best_score: 334 | parent_node.best_score = son_node.best_score 335 | x = parent_node 336 | parent_node = parent_node.parent 337 | son_node = x 338 | 339 | 340 | def BACKUP(node, states, num_nodes): 341 | i = 1 342 | if node.longest_path == 0: 343 | node.longest_path = len(states) 344 | while node != None: 345 | node.visits += 1 346 | best_score = min(states[num_nodes - i:]) 347 | i += 1 348 | if best_score < node.best_score: 349 | node.best_score = best_score 350 | reward = max(best_score, 0) 351 | else: 352 | reward = 0 353 | if best_score < np.mean([score for batch in DIC for _, score in batch]): 354 | SOC.append(best_score) 355 | node.reward += reward 356 | node = node.parent 357 | return 358 | 359 | if __name__ == "__main__": 360 | 361 | 362 | frag = args.input 363 | score = vina_dock(frag) 364 | ipts = [args.input, score] 365 | current_node = Node(State(args.input), reward=ipts[1]) 366 | result = UCTSEARCH(args.num_sims, current_node, start_score=ipts[1]) 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | -------------------------------------------------------------------------------- /train_combiner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import Dataset, DataLoader 5 | import random 6 | import pandas as pd 7 | from rdkit import Chem 8 | from rdkit.Chem import AllChem, DataStructs 9 | from 
utils import * 10 | from model import combiner, pre_dataset_combiner 11 | 12 | 13 | parser = argparse.ArgumentParser(description='Training the reaction-based combiner..........') 14 | parser.add_argument('--mol_p', type=str, help='the path of molecular dataset', default='./data/synthons/data.csv') 15 | parser.add_argument('--syn_p', type=str, help='the path of synthons library', default='./data/synthons/synthons.csv') 16 | 17 | parser.add_argument('--num_p', type=int, help='number of positive_samples', default=100) 18 | parser.add_argument('--num_n', type=int, help='number of negative_samples', default=1000) 19 | parser.add_argument('--lr', type=float, help='learning rate', default=1e-4) 20 | parser.add_argument('--epoch', type=int, help='training epochs', default=80) 21 | 22 | args = parser.parse_args() 23 | 24 | 25 | #build molecular dataset 26 | df_m = pd.read_csv(args.mol_p) 27 | # filter on raw SMILES; two-character atoms are mapped to single tokens at encoding time 28 | smiles_db = df_m["SMILES"].tolist() 29 | smiles_db = filter_invalid_molecules(smiles_db) 30 | 31 | #build synthon dataset 32 | df_s = pd.read_csv(args.syn_p) 33 | synthon_db = df_s["Synthons"].tolist() 34 | 35 | char_set = tokened(args.mol_p) 36 | char_to_idx = {char: idx for idx, char in enumerate(char_set)} 37 | 38 | 39 | 40 | def contains_functional_group(mol, smarts): 41 | patt = Chem.MolFromSmarts(smarts) 42 | return mol.HasSubstructMatch(patt) 43 | 44 | def decompose_molecule(mol, decomp_rules): 45 | frags = [] 46 | for rule in decomp_rules: 47 | rxn = AllChem.ReactionFromSmarts(rule) 48 | ps = rxn.RunReactants((mol,)) 49 | for products in ps: 50 | frags.append(products) 51 | return frags 52 | 53 | def calc_tanimoto(mol1, mol2): 54 | fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2) 55 | fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2) 56 | return DataStructs.TanimotoSimilarity(fp1, fp2) 57 | 58 | 59 | def apply_reaction(smarts_reaction, reactants): 60 | rxn = 
AllChem.ReactionFromSmarts(smarts_reaction) 61 | products = rxn.RunReactants(reactants) 62 | return [Chem.MolToSmiles(product[0]) for product in products] 63 | 64 | 65 | #generate positive_samples and negative_samples 66 | def generate_samples(mol, synthon_db, threshold_positive=0.7, threshold_negative=0.4, n_positive=100, n_negative=1000): 67 | positive_samples = [] 68 | negative_samples = [] 69 | for synthon in synthon_db: 70 | synthon_mol = Chem.MolFromSmiles(synthon) 71 | similarity = calc_tanimoto(mol, synthon_mol) 72 | if similarity >= threshold_positive: 73 | positive_samples.append(synthon) 74 | elif similarity <= threshold_negative: 75 | negative_samples.append(synthon) 76 | return random.sample(positive_samples, min(n_positive, len(positive_samples))), random.sample(negative_samples, min(n_negative, len(negative_samples))) 77 | 78 | 79 | 80 | 81 | def train_combiner(model, data_loader, criterion, optimizer, device, num_epochs=10): 82 | model.train() 83 | for epoch in range(num_epochs): 84 | running_loss = 0.0 85 | for inputs, labels in data_loader: 86 | inputs, labels = inputs.to(device), labels.to(device) 87 | optimizer.zero_grad() 88 | outputs = model(inputs) 89 | loss = criterion(outputs, labels.float().unsqueeze(1)) 90 | loss.backward() 91 | optimizer.step() 92 | running_loss += loss.item() 93 | print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(data_loader)}') 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | def main(): 103 | embedding_dim = 128 104 | hidden_dim = 256 105 | synthon_hidden_dim = 128 106 | vocab_size = len(char_to_idx) 107 | 108 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 109 | model = combiner(vocab_size, embedding_dim, hidden_dim, synthon_hidden_dim).to(device) 110 | criterion = nn.BCELoss() 111 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 112 | 113 | decomp_rules_amide = ["[*:1]C(=O)N[*:2]>>[*:1]C(=O)O[3*].[*:2]N[3*]"] 114 | decomp_rules_triazole = ["[*:1]c1cn([*:2])nn1>>[1*]/C=C(\\[2*])[*:1].[2*]/N=N\\N([1*])[*:2]"] 115 | amide_reaction = 
'[*:1]C(=O)O[3*].[*:2]N[3*]>>[*:1]C(=O)N[*:2]' 116 | triazole_reaction = '[1*]/C=C(\\[2*])[*:1].[2*]/N=N\\N([1*])[*:2]>>[*:1]c1cn([*:2])nn1' 117 | 118 | for e in range(args.epoch): 119 | selected_smiles = random.sample(smiles_db, min(10000, len(smiles_db))) 120 | training_smiles = [] 121 | training_labels = [] 122 | for smile in selected_smiles: 123 | mol = Chem.MolFromSmiles(smile) 124 | if contains_functional_group(mol, "C(=O)N") or contains_functional_group(mol, "c1cnnn1"): 125 | fragments = decompose_molecule(mol, decomp_rules_amide + decomp_rules_triazole) 126 | for fragment in fragments: 127 | frag_smile = Chem.MolToSmiles(fragment[0]) 128 | positive_samples, negative_samples = generate_samples(fragment[0], synthon_db, n_positive=args.num_p, n_negative=args.num_n) 129 | for sample in positive_samples: 130 | sample_mol = Chem.MolFromSmiles(sample); combined_smile = None 131 | if contains_functional_group(fragment[0], "C(=O)O"): 132 | combined_smile = apply_reaction(amide_reaction, (fragment[0], sample_mol)) 133 | elif contains_functional_group(fragment[0], "/C=C/"): 134 | combined_smile = apply_reaction(triazole_reaction, (fragment[0], sample_mol)) 135 | if combined_smile: 136 | training_smiles.extend(combined_smile) 137 | training_labels.extend([1] * len(combined_smile)) 138 | for sample in negative_samples: 139 | sample_mol = Chem.MolFromSmiles(sample); combined_smile = None 140 | if contains_functional_group(fragment[0], "C(=O)O"): 141 | combined_smile = apply_reaction(amide_reaction, (fragment[0], sample_mol)) 142 | elif contains_functional_group(fragment[0], "/C=C/"): 143 | combined_smile = apply_reaction(triazole_reaction, (fragment[0], sample_mol)) 144 | if combined_smile: 145 | training_smiles.extend(combined_smile) 146 | training_labels.extend([0] * len(combined_smile)) 147 | 148 | dataset = pre_dataset_combiner(training_smiles, training_labels, char_to_idx) 149 | data_loader = DataLoader(dataset, batch_size=256, shuffle=True) 150 | train_combiner(model, data_loader, criterion, optimizer, device, num_epochs=1) 151 | torch.save(model.state_dict(), 
'./data/combiner.pth') 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /train_inpainting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from torch.autograd import Variable 7 | from utils import SMILESDataset, tokened, gaussian_weight 8 | import argparse 9 | import pandas as pd 10 | from model import inpainting 11 | from utils import * 12 | 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Training the inpainting model.....') 16 | parser.add_argument('--mol_p', type=str, help='the path of molecular dataset', default='./data/synthons/data.csv') 17 | parser.add_argument('--embed_dim', type=int, default=64, help='embedding_dim') 18 | parser.add_argument('--hid_dim', type=int, default=256, help='hidden_dim') 19 | parser.add_argument('--skip_connection', type=int, help='skip connection', nargs='+', default=[0,1,2,3,4]) 20 | parser.add_argument('--attention', type=int, help='attention mechanism', nargs='+', default=[1]) 21 | parser.add_argument('--lr', type=float, help='learning rate', default=1e-4) 22 | parser.add_argument('--epoch', type=int, help='training epochs', default=80) 23 | 24 | 25 | args = parser.parse_args() 26 | 27 | 28 | 29 | 30 | 31 | 32 | #build training dataset 33 | df_m = pd.read_csv(args.mol_p) 34 | # keep raw SMILES here; SMILESDataset parses them with RDKit 35 | smiles_db = df_m["SMILES"].tolist() 36 | smiles_db = filter_invalid_molecules(smiles_db) 37 | char_set = tokened(args.mol_p) 38 | char_to_idx = {char: idx for idx, char in enumerate(char_set)} 39 | vocab_size = len(char_to_idx) 40 | 41 | 42 | def collate_fn(batch, vocab_size): 43 | lefts, rights, targets = zip(*batch) 44 | lefts = nn.utils.rnn.pad_sequence(lefts, batch_first=True, 
padding_value=vocab_size) 45 | rights = nn.utils.rnn.pad_sequence(rights, batch_first=True, padding_value=vocab_size) 46 | targets = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=vocab_size) 47 | return lefts, rights, targets 48 | 49 | 50 | 51 | # Training 52 | def train(model, train_loader): 53 | 54 | model.train() 55 | mse = nn.MSELoss(reduction='none').cuda(0) 56 | 57 | running_rec = 0.0 58 | running_cons = 0.0 59 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 60 | 61 | for batch_idx, (l_mol, r_mol, m_mol) in enumerate(train_loader): 62 | 63 | batchSize = l_mol.shape[0] 64 | mol_len = l_mol.shape[1] 65 | l_mol, r_mol, m_mol = Variable(l_mol).cuda(0), Variable(r_mol).cuda(0), Variable(m_mol).cuda(0) 66 | 67 | ## Generate mid-molecules 68 | mol_pred, F_lmol, F_rmol = model(l_mol, r_mol) 69 | 70 | # Reconstruction Loss, weighted towards the known left/right ends 71 | weight = gaussian_weight(batchSize, mol_len, device=0) 72 | mask = weight + weight.flip(1) 73 | rec_loss = (mask * mse(mol_pred, m_mol)).mean() * batchSize 74 | 75 | #Consistency Loss 76 | cons_loss = (mse(F_lmol[0], F_rmol[0]) + mse(F_lmol[1], F_rmol[1]) + mse(F_lmol[2], F_rmol[2])).mean() * batchSize 77 | 78 | gen_loss = rec_loss + cons_loss 79 | running_rec += rec_loss.item() 80 | running_cons += cons_loss.item() 81 | 82 | 83 | if (batch_idx % 3) != 0: 84 | optimizer.zero_grad() 85 | gen_loss.backward() 86 | optimizer.step() 87 | 88 | 89 | def main(): 90 | model = inpainting(vocab_size, embedding_dim=args.embed_dim, hidden_dim=args.hid_dim, skip=args.skip_connection, attention=args.attention).cuda(0) 91 | 92 | dataset = SMILESDataset(smiles_db, char_to_idx) 93 | train_loader = DataLoader(dataset, batch_size=256, shuffle=True, collate_fn=lambda x: collate_fn(x, vocab_size)) 94 | 95 | for epoch in range(args.epoch): 96 | train(model, train_loader) 97 | torch.save(model.state_dict(), './data/inpainting.pth') 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | 103 | 104 | 105 | 106 | 
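The reconstruction loss in `train()` above multiplies the per-position MSE by `gaussian_weight` plus its flip, so the known left and right synthons dominate the loss while the inpainted middle is down-weighted. A minimal pure-Python sketch of that weighting (mirroring `gaussian_weight` from utils.py, without the torch/CUDA plumbing; names here are illustrative only):

```python
import math

def gaussian_weight(batch_size, seq_len):
    # Position-dependent weights: 1.0 at index 0, decaying with a
    # Gaussian of variance (seq_len / 4)**2, as in utils.gaussian_weight.
    var = (seq_len / 4) ** 2
    row = [math.exp(-(float(i)) ** 2 / (2 * var)) for i in range(seq_len)]
    return [row[:] for _ in range(batch_size)]

w = gaussian_weight(2, 8)
# The mask used in train() adds the flipped weights, emphasising both ends
# of the sequence (the known fragments) over the middle being inpainted.
mask = [[a + b for a, b in zip(row, reversed(row))] for row in w]
```

The resulting mask is symmetric, largest at the two ends and smallest near the centre, which is what lets the consistency loss focus the model on reconciling the left- and right-derived features.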
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from tqdm import tqdm 5 | 6 | 7 | import pandas as pd 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.utils.data import Dataset, DataLoader 13 | from torch.autograd import Variable 14 | 15 | from vina import Vina 16 | 17 | from openbabel import pybel as pyb 18 | from openbabel import openbabel 19 | 20 | from rdkit import Chem 21 | from rdkit.Chem import rdChemReactions, Descriptors, rdMolTransforms, rdMolDescriptors, rdmolops 22 | 23 | 24 | 25 | 26 | def split_molecule(smiles): 27 | 28 | molecule = Chem.MolFromSmiles(smiles) 29 | if molecule is None: 30 | return [] 31 | 32 | Chem.SanitizeMol(molecule) 33 | synthons = [molecule] 34 | amide_smarts = "*NC(*)=O" 35 | triazole_smarts = "*c1cn(*)nn1" 36 | 37 | changed = True 38 | while changed: 39 | changed = False 40 | new_synthons = [] 41 | for synthon in synthons: 42 | amide_bonds = synthon.GetSubstructMatches(Chem.MolFromSmarts(amide_smarts)) 43 | 44 | triazole_rings = synthon.GetSubstructMatches(Chem.MolFromSmarts(triazole_smarts)) 45 | 46 | if amide_bonds or triazole_rings: 47 | changed = True 48 | if amide_bonds: 49 | reaction_smarts = "[*:1]C(=O)N[*:2]>>[*:1]C(=O)O[3*].[*:2]N[3*]" 50 | elif triazole_rings: 51 | reaction_smarts = r"[*:1]c1cn([*:2])nn1>>[1*]/C=C(\[2*])[*:1].[2*]/N=N\N([1*])[*:2]" 52 | reaction = rdChemReactions.ReactionFromSmarts(reaction_smarts) 53 | try: 54 | products = reaction.RunReactants((synthon,)) 55 | for product in products: 56 | for mol in product: 57 | new_synthons.append(mol) 58 | except: 59 | pass 60 | else: 61 | new_synthons.append(synthon) 62 | synthons = new_synthons 63 | 64 | final_synthons = [Chem.MolToSmiles(mol) for mol in synthons if mol is not None] 65 | return final_synthons 66 | 67 | 68 | def 
synthon_prepare(smi_path): 69 | all_synthons = set() 70 | data = pd.read_csv(smi_path) 71 | for smiles in tqdm(data['SMILES']): 72 | synthons = split_molecule(smiles) 73 | all_synthons.update(synthons) 74 | pd.DataFrame({'Synthons': list(all_synthons)}).to_csv('synthons.csv', index=False) 75 | 76 | 77 | def tokened(file_path): 78 | df = pd.read_csv(file_path, usecols=[0]) 79 | df.columns = ['SMILES'] 80 | df['SMILES'] = df['SMILES'].str.replace('Cl', 'X', regex=False).str.replace('Br', 'Y', regex=False).str.replace('[nH]', 'Z', regex=False) 81 | all_tokens = set() 82 | for smiles in df['SMILES']: 83 | tokens = set(smiles) 84 | all_tokens.update(tokens) 85 | 86 | return all_tokens 87 | 88 | def ind2smi(output, idx_to_char): 89 | # Get the indices of the maximum values along the last dimension 90 | indices = torch.argmax(output, dim=-1) 91 | 92 | # Convert indices to characters 93 | smiles_list = [] 94 | for seq in indices: 95 | smiles = ''.join(idx_to_char[idx.item()] for idx in seq) 96 | smiles_list.append(smiles) 97 | return smiles_list 98 | 99 | def make_layers(in_channel, out_channel, kernel_size, stride, padding, dilation=1, bias=True, norm=True, activation=True, is_relu=False): 100 | layer = [] 101 | layer.append(nn.Conv1d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)) 102 | if norm: 103 | layer.append(nn.InstanceNorm1d(out_channel, affine=True)) 104 | if activation: 105 | if is_relu: 106 | layer.append(nn.ReLU()) 107 | else: 108 | layer.append(nn.LeakyReLU(negative_slope=0.2)) 109 | return nn.Sequential(*layer) 110 | 111 | def make_layers_transpose(in_channel, out_channel, kernel_size, stride, padding, dilation=1, bias=True, norm=True, activation=True, is_relu=False): 112 | layer = [] 113 | layer.append(nn.ConvTranspose1d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)) 114 | if norm: 115 | layer.append(nn.InstanceNorm1d(out_channel, affine=True)) 116 | if activation: 117 | if 
is_relu: 118 | layer.append(nn.ReLU()) 119 | else: 120 | layer.append(nn.LeakyReLU(negative_slope=0.2)) 121 | return nn.Sequential(*layer) 122 | 123 | def smiles_to_mol(smiles): 124 | return Chem.MolFromSmiles(smiles) 125 | 126 | def split_molecule_parts(mol): # splits a Mol into left/middle/right fragments; distinct from split_molecule above, which cuts at reaction bonds 127 | def get_num_atoms(fragment): 128 | return fragment.GetNumAtoms() 129 | 130 | if mol is None: 131 | raise ValueError("Invalid molecule input.") 132 | 133 | cut_bonds = [] 134 | for bond in mol.GetBonds(): 135 | if not bond.IsInRing(): 136 | cut_bonds.append(bond.GetIdx()) 137 | 138 | def atom_diff(frag_atoms): 139 | return max(frag_atoms) - min(frag_atoms) 140 | 141 | best_fragments = None 142 | smallest_diff = float('inf') 143 | best_cut_bonds = None 144 | 145 | # Try to find the best cut 146 | for i in range(len(cut_bonds)): 147 | for j in range(i + 1, len(cut_bonds)): 148 | try: 149 | frags = Chem.FragmentOnBonds(mol, [cut_bonds[i], cut_bonds[j]], addDummies=True, dummyLabels=[(0, 0), (1, 1)]) 150 | frags = Chem.GetMolFrags(frags, asMols=True, sanitizeFrags=True) 151 | 152 | if len(frags) == 3: 153 | frag_atoms = [get_num_atoms(frag) for frag in frags] 154 | diff = atom_diff(frag_atoms) 155 | if diff < smallest_diff: 156 | smallest_diff = diff 157 | best_fragments = frags 158 | best_cut_bonds = [cut_bonds[i], cut_bonds[j]] 159 | except Exception as e: 160 | print(f"Error processing bonds {cut_bonds[i]} and {cut_bonds[j]}: {e}") 161 | continue 162 | 163 | if best_fragments is None: 164 | raise ValueError("Could not find a suitable cut to split the molecule into three parts.") 165 | 166 | # Determine which fragment is left, middle, and right 167 | atom_indices = [frag.GetAtoms()[0].GetIdx() for frag in best_fragments] 168 | sorted_indices = sorted(range(len(atom_indices)), key=lambda k: atom_indices[k]) 169 | 170 | left = best_fragments[sorted_indices[0]] 171 | middle = best_fragments[sorted_indices[1]] 172 | right = best_fragments[sorted_indices[2]] 173 | 174 | return left, middle, right 175 | 176 | 177 | class SMILESDataset(Dataset): 178 | def __init__(self, smiles_list, char_to_idx): 179 | self.smiles_list = smiles_list 180 | self.char_to_idx = char_to_idx 181 | 182 | def __len__(self): 183 | return len(self.smiles_list) 184 | 185 | def __getitem__(self, idx): 186 | smiles = self.smiles_list[idx] 187 | mol = smiles_to_mol(smiles) 188 | 189 | left, middle, right = split_molecule_parts(mol) 190 | 191 | left_smiles = Chem.MolToSmiles(left) 192 | right_smiles = Chem.MolToSmiles(right) 193 | middle_smiles = Chem.MolToSmiles(middle) 194 | 195 | 196 | left_indices = smiles_to_indices(left_smiles, self.char_to_idx) 197 | right_indices = smiles_to_indices(right_smiles, self.char_to_idx) 198 | middle_indices = smiles_to_indices(middle_smiles, self.char_to_idx) 199 | 200 | 201 | 202 | left_indices = pad_sequence(left_indices) 203 | right_indices = pad_sequence(right_indices) 204 | middle_indices = pad_sequence(middle_indices) 205 | 206 | return (torch.tensor(left_indices, dtype=torch.long), 207 | torch.tensor(right_indices, dtype=torch.long), 208 | torch.tensor(middle_indices, dtype=torch.long)) 209 | 210 | def pad_sequence(seq): 211 | seq += [0] * (100 - len(seq)) 212 | return seq[:100] 213 | def smiles_to_indices(smiles, char_to_idx): 214 | return [char_to_idx[char] for char in smiles if char in char_to_idx] 215 | 216 | 217 | def split_input_synthons(smiles): 218 | mol = Chem.MolFromSmiles(smiles) 219 | if not mol: 220 | raise ValueError("Invalid SMILES string") 221 | 222 | # Find potential cutting points (bonds not in rings) 223 | cut_bonds = [] 224 | for bond in mol.GetBonds(): 225 | if not bond.IsInRing(): 226 | cut_bonds.append(bond.GetIdx()) 227 | 228 | if len(cut_bonds) < 2: 229 | raise ValueError("Not enough non-ring bonds to cut") 230 | 231 | # Randomly select two bonds to cut (sorted pair, equivalent to resampling until ordered) 232 | cut1, cut2 = sorted(random.sample(cut_bonds, 2)) 233 | 234 | 235 | 236 | # Cut the molecule 237 | frags = Chem.FragmentOnBonds(mol, [cut1, cut2], 
addDummies=True, dummyLabels=[(0, 0), (1, 1)]) 238 | frags = Chem.GetMolFrags(frags, asMols=True, sanitizeFrags=True) 239 | 240 | if len(frags) != 3: 241 | raise ValueError("Failed to split SMILES into three parts") 242 | 243 | return frags 244 | 245 | def process_synthons_dataset(smi, char_to_idx, vocab_size): 246 | 247 | try: 248 | frags = split_input_synthons(smi) 249 | left, middle, right = [Chem.MolToSmiles(frag) for frag in frags] 250 | except ValueError as e: 251 | print(f"Skipping SMILES {smi}: {e}") 252 | return None, None, None 253 | 254 | left_indices = smiles_to_indices(left, char_to_idx) 255 | middle_indices = smiles_to_indices(middle, char_to_idx) 256 | right_indices = smiles_to_indices(right, char_to_idx) 257 | 258 | left_tensor = pad_sequence(left_indices) 259 | middle_tensor = pad_sequence(middle_indices) 260 | right_tensor = pad_sequence(right_indices) 261 | 262 | left_tensor = torch.tensor(left_tensor) 263 | middle_tensor = torch.tensor(middle_tensor) 264 | right_tensor = torch.tensor(right_tensor) 265 | 266 | left_final = nn.utils.rnn.pad_sequence([left_tensor], batch_first=True, padding_value=vocab_size) 267 | middle_final = nn.utils.rnn.pad_sequence([middle_tensor], batch_first=True, padding_value=vocab_size) 268 | right_final = nn.utils.rnn.pad_sequence([right_tensor], batch_first=True, padding_value=vocab_size) 269 | 270 | return left_final, middle_final, right_final 271 | 272 | 273 | 274 | 275 | 276 | def filter_invalid_molecules(smiles_list): 277 | 278 | valid_smiles = [] 279 | for smiles in smiles_list: 280 | mol = Chem.MolFromSmiles(smiles) 281 | if mol is not None: 282 | valid_smiles.append(smiles) 283 | return valid_smiles 284 | 285 | def reduce_sum(x, axis=None, keepdim=False): 286 | if axis is None: axis = range(len(x.shape)) 287 | elif isinstance(axis, int): axis = [axis] 288 | for i in sorted(axis, reverse=True): 289 | x = torch.sum(x, dim=i, keepdim=keepdim) 290 | return x 291 | 292 | def cos_function_weight(batchSize, imgSize, device): 293 | weight = torch.ones((imgSize, 
imgSize)) 294 | for i in range(imgSize): 295 | weight[:, i] = (1. + math.cos(math.pi * i / float(imgSize-1))) * 0.5 296 | weight = weight.view(1,1,imgSize,imgSize).repeat(batchSize,1,1,1) 297 | return Variable(weight).cuda(device) 298 | 299 | def gaussian_weight(size1, size2, device=0): 300 | weight = torch.ones((size1, size2)) 301 | var = (size2/4)**2 302 | for i in range(size2): 303 | weight[:, i] = math.exp(-(float(i))**2/(2*var)) 304 | weight = weight.view(size1,size2) 305 | return Variable(weight).cuda(device) 306 | 307 | def gaussian_bias(size, device=0): 308 | bias = torch.ones((size)) 309 | var = (size/4)**2 310 | for i in range(size): 311 | bias[i] = math.exp(-(float(i))**2/(2*var)) 312 | return Variable(bias).cuda(device) 313 | 314 | 315 | 316 | def padding_smi(smiles_seq): 317 | seq = smiles_seq + [0] * (100 - len(smiles_seq)) 318 | padded_seq = seq[:100] 319 | return torch.tensor(padded_seq, dtype=torch.long) 320 | 321 | 322 | 323 | def assemble_smiles(left, middle, right): 324 | 325 | left_mol = Chem.MolFromSmiles(left) 326 | middle_mol = Chem.MolFromSmiles(middle) 327 | right_mol = Chem.MolFromSmiles(right) 328 | 329 | 330 | if left_mol is None or middle_mol is None or right_mol is None: 331 | raise ValueError("One of the SMILES strings could not be converted to a molecule.") 332 | 333 | 334 | left_frag = Chem.MolToSmiles(left_mol, isomericSmiles=True) 335 | middle_frag = Chem.MolToSmiles(middle_mol, isomericSmiles=True) 336 | right_frag = Chem.MolToSmiles(right_mol, isomericSmiles=True) 337 | 338 | 339 | combined_frag = left_frag + '.' + middle_frag + '.' 
+ right_frag 340 | # combining via CombineMols below is equivalent to parsing the '.'-joined combined_frag 341 | 342 | 343 | combined_mol = rdmolops.CombineMols(left_mol, middle_mol) 344 | combined_mol = rdmolops.CombineMols(combined_mol, right_mol) 345 | 346 | assembled_smiles = Chem.MolToSmiles(combined_mol, isomericSmiles=True) 347 | return assembled_smiles 348 | 349 | def assemble_smiles_with_symbols(smiles1, smiles2, smiles3): 350 | mol1 = Chem.MolFromSmiles(smiles1) 351 | mol2 = Chem.MolFromSmiles(smiles2) 352 | mol3 = Chem.MolFromSmiles(smiles3) 353 | 354 | # Combine molecules without modifying them 355 | combined = Chem.CombineMols(mol1, mol2) 356 | combined = Chem.CombineMols(combined, mol3) 357 | 358 | # Manual assembly: keep the [0*] marker at the end of the left part and the [1*] marker at the start of the right part; strip both from the middle 359 | combined_smiles = smiles1 + smiles2.replace("[0*]", "").replace("[1*]", "") + smiles3 360 | 361 | return combined_smiles 362 | 363 | def disassemble_smiles_with_symbols(combined_smiles): 364 | # Use the markers to find the split points 365 | parts = combined_smiles.split('[0*]') 366 | left = parts[0] + '[0*]' 367 | remaining = parts[1].split('[1*]') 368 | middle = '[0*]' + remaining[0] + '[1*]' 369 | right = '[1*]' + remaining[1] 370 | 371 | return left, middle, right 372 | 373 | def pdb2pdbqt(input_pdb, output_pdbqt): 374 | obConversion = openbabel.OBConversion() 375 | obConversion.SetInAndOutFormats("pdb", "pdbqt") 376 | mol = openbabel.OBMol() 377 | obConversion.ReadFile(mol, input_pdb) 378 | obConversion.WriteFile(mol, output_pdbqt) 379 | 380 | 381 | def vina_dock(lig, save_path='./log/docked.pdbqt'): 382 | v = Vina(sf_name='vina') 383 | v.set_receptor('./data/protein.pdbqt') 384 | mymol = pyb.readstring("smi", lig) 385 | mymol.make3D() 386 | mymol.write(format='pdbqt', filename='./log/lig.pdbqt', overwrite=True) 387 | v.set_ligand_from_file('./log/lig.pdbqt') 388 | v.compute_vina_maps(center=[15.190, 53.903, 16.917], box_size=[20, 20, 20]) 389 | v.dock(exhaustiveness=32, n_poses=20) 390 | 
v.write_poses(save_path, n_poses=1, overwrite=True) 391 | return float(v.energies(n_poses=1)[0][0]) --------------------------------------------------------------------------------
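The MCTS in run.py selects children stochastically through `roulette`, i.e. fitness-proportionate selection over the children's scores. A standalone sketch of that selection rule (a simplified illustration, not the exact run.py code), including the uniform fallback for an all-zero score list and a guard against floating-point shortfall:

```python
import random

def roulette(select_list):
    # Fitness-proportionate selection: pick index i with probability
    # select_list[i] / sum(select_list); fall back to a uniform choice
    # when every score is zero.
    total = sum(select_list)
    if total == 0:
        return random.randrange(len(select_list))
    threshold = random.random()
    cumulative = 0.0
    for i, score in enumerate(select_list):
        cumulative += score / total
        if cumulative >= threshold:
            return i
    return len(select_list) - 1  # guard against rounding shortfall

random.seed(0)
picks = [roulette([1.0, 0.0, 3.0]) for _ in range(1000)]
```

With weights 1:0:3, index 2 should be chosen roughly three times as often as index 0, and index 1 never; this is what biases the tree expansion toward children with better (more negative) docking scores once the scores are mapped to positive weights.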