├── .gitignore
├── Data
│   ├── raw_data.md
│   └── train_data
│       └── readme.md
├── LICENSE
├── README.md
├── codonbert_env.yaml
├── data
│   ├── example_data
│   │   ├── codon_seq.fasta
│   │   ├── codon_seq.npy
│   │   ├── predict.fasta
│   │   └── protein_seq.fasta
│   └── test_data
│       └── protein_seq.fasta
├── figures
│   └── 2024.01.24-figure1.jpg
├── models
│   └── kidney_1_1_CodonBert_model_20230726_320_model_param.pt
├── predict.py
├── scripts
│   ├── codon_bert_pytorch.py
│   ├── config_function.py
│   ├── data_formatting.py
│   ├── data_processing.py
│   ├── fasta_to_npy.py
│   └── get_metrics.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.logs
*.pyc
--------------------------------------------------------------------------------
/Data/raw_data.md:
--------------------------------------------------------------------------------
Three files are used in the downstream analysis. Their names and download links are listed below.
- [gencode.v43.pc_transcripts.fa.gz](https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_43/gencode.v43.pc_transcripts.fa.gz)
- [gencode.v43.pc_translations.fa.gz](https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_43/gencode.v43.pc_translations.fa.gz)
- [transcript_rna_tissue.tsv](https://www.proteinatlas.org/download/transcript_rna_tissue.tsv.zip)

The contents of `data/train_data` are generated from these files by `data_processing.py`.
--------------------------------------------------------------------------------
/Data/train_data/readme.md:
--------------------------------------------------------------------------------
# generated by data_processing.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 FPPGroup

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CodonBERT

This is the code for the article [_CodonBert: a BERT-based architecture tailored for codon optimization using the cross-attention mechanism_](https://doi.org/10.1093/bioinformatics/btae330). CodonBERT is a flexible deep-learning model for codon optimization, inspired by ProteinBERT (Brandes et al., 2022), to which we made two crucial modifications.
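Conceptually, the query/key/value arrangement described below can be sketched with PyTorch's stock `nn.MultiheadAttention`. This is a toy illustration only, not the repository's implementation (see the `CrossAttention` class in `scripts/codon_bert_pytorch.py` for the real one); the dimensions are made up:

```python
import torch
from torch import nn

dim, n_res = 128, 512
protein_repr = torch.randn(1, n_res, dim)  # query: amino-acid track
codon_repr = torch.randn(1, n_res, dim)    # key/value: (masked) codon track

attn = nn.MultiheadAttention(embed_dim=dim, num_heads=8, batch_first=True)
out, _ = attn(query=protein_repr, key=codon_repr, value=codon_repr)
print(out.shape)  # torch.Size([1, 512, 128])
```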
As for the architecture, (1) the right-side network was rebuilt to match the encoder on the left side; (2) codon tokens are used as both keys and values in the cross-attention mechanism, while the protein sequence serves as the query (a conceptual sketch of this arrangement is shown below). In this way, CodonBERT learns codon usage preferences and contextual combination preferences from randomly masked codon tokens.

CodonBERT requires amino acid sequences in FASTA format as input and predicts optimized codon sequences. Four trained models based on high-TPM data (with various proportions of JCAT-optimized sequences) were built. Users can directly run `predict.py` to perform codon optimization. Notably, we provide `train.py` so that developers can train a custom model on their own data. In the current version, the model hyperparameters can only be modified in the source code. A graphical user interface is under development (as of Apr. 2024). In the meantime, we are processing tissue-specific data to build a tissue-specific optimization tool.

![figure](https://github.com/FPPGroup/CodonBERT/blob/main/figures/2024.01.24-figure1.jpg)


## Table of Contents

- [CodonBERT](#codonbert)
  - [Table of Contents](#table-of-contents)
  - [Installation](#installation)
    - [Dependencies](#dependencies)
  - [Usage](#usage)
    - [For users](#for-users)
    - [For developers](#for-developers)
  - [Citation](#citation)


## Installation
We recommend `conda` to manage the computing environment. Model training and prediction are based on Python and PyTorch. The calculation of CAI and MFE is based on EMBOSS v6.6.0 (Olson, 2002) and ViennaRNA v2.6.4 (Lorenz et al., 2011). The environment has been tested on Ubuntu. On macOS, EMBOSS and ViennaRNA cannot be installed this way directly.

Here are the dependencies:
```
conda create -n codonbert_env python=3.10 -y
conda activate codonbert_env
conda install bioconda::emboss # not available for macos-arm64
pip install torch --index-url https://download.pytorch.org/whl/cu118 # users should pick the PyTorch build matching their CUDA setup
pip install ViennaRNA==2.6.4 biopython==1.81 einops==0.6.0 numpy==1.26.4 pandas==2.2.0 scikit-learn==1.2.1 tensorboardx==2.6 tqdm==4.65.0
```


Alternatively, users can create the environment from the provided file:

```bash
conda env create -f codonbert_env.yaml -n codonbert_env
conda activate codonbert_env
```

Download the source code:
```bash
git clone https://github.com/FPPGroup/CodonBERT.git
cd CodonBERT
```

## Usage
The code in this repository can be used for model training and prediction.

### For users

For `predict.py`, the user only needs to provide the paths of the `*.pt` file (model weights), the protein sequence FASTA file, and the output mRNA FASTA file.

```bash
python predict.py -m $path_to_MODEL_WEIGHTS -f $path_to_Amino_Acid_FASTA -o $path_to_output
```

The weights of the trained models are stored in `models/` (e.g., `models/kidney_1_1_CodonBert_model_20230726_320_model_param.pt`). Besides the main output, `predict.py` also writes a `*_fix.fasta` file in which any codon that does not translate back to the input amino acid is replaced with a fixed codon for that amino acid. Users can test the code with the following command:

```bash
## test command line
python predict.py -m models/kidney_1_1_CodonBert_model_20230726_320_model_param.pt -f data/example_data/protein_seq.fasta -o data/example_data/optimized.fasta
```
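As a quick, optional sanity check (this snippet is an addition of ours, not a repository script), Biopython, which is already a dependency, can verify that each optimized CDS still encodes the input protein:

```python
from Bio import SeqIO
from Bio.Seq import Seq

# assumes the example paths used in the test command above
proteins = {r.id: str(r.seq) for r in SeqIO.parse("data/example_data/protein_seq.fasta", "fasta")}
for rec in SeqIO.parse("data/example_data/optimized.fasta", "fasta"):
    aa = str(Seq(str(rec.seq)).translate()).rstrip("*")  # drop a trailing stop, if present
    assert aa == proteins[rec.id], f"{rec.id}: optimized CDS does not encode the input protein"
print("All optimized sequences encode their input proteins.")
```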
Moreover, the CAI and MFE calculations are already integrated into this repository, so users can assess the numeric metrics of the optimized codon sequences:

```bash
python scripts/get_metrics.py -f data/example_data/optimized.fasta -o data/example_data/optimized_metrics.csv
```


### For developers

CodonBERT is meant to be retrained easily and flexibly, so developers only need to focus on data collection. Developers can use and revise `train.py` to retrain a codon optimization model. The detailed architecture is stored in `scripts/codon_bert_pytorch.py`, and the usual hyperparameters can be modified in `train.py`. Users can contact us directly for further help.

```bash
python train.py -t $path_trainset_fasta -v $path_validset_fasta -o $path_to_save_model_weights
```


### Data processing in our paper
1. Screen by tissue and TPM
   - Select the codon sequences of a specific tissue whose TPM values meet the condition TPM > 5

2. Compute the length distribution and filter the sequences by length
   - Based on the results of step 1, keep codon sequences ranging in length from 200 to 2000 nt

3. Calculate the MFE and CAI metrics of the codon sequences
   - Based on the results of step 2 and the configured environment, calculate the CAI and MFE of each sequence and store them as a CSV
   - Draw a scatter plot of the CAI and MFE values and store it in PDF format
   - Load the saved CSV and draw histograms along the CAI and MFE axes, which together with the scatter plot form a marginal-histogram figure

4. Filter by the MFE and CAI statistics
   - Based on the results of step 3, select the codon sequences whose CAI and MFE values meet the conditions (CAI > 0.7, MFE < -200)

5. Optimize the mRNAs of the resulting data set following the idea of JCAT (a minimal sketch of this mapping follows this list)
   - Based on the results of step 4, convert the codon sequences to amino acid sequences
   - Optimize the amino acid sequences with the JCat method

6. Compute the MFE and CAI statistics of the JCAT-optimized results and filter them (reusing the scripts of steps 3-4)
   - Based on the results of step 5, select the JCAT-optimized codon sequences that meet the same conditions (CAI > 0.7, MFE < -200)

7. Build the training and validation sets
   - Based on the results of step 6, the first 10,000 codon sequences are used to build the training set, and the remaining 973 sequences are used as the validation set during model training
   - Four training sets were constructed with different ratios of screened real sequences to JCAT-optimized results: 1:0, 1:0.2, 1:0.5 and 1:1
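For reference, the one-to-one mapping used in step 5 is implemented in `scripts/data_processing.py` as a fixed amino-acid-to-codon dictionary. Below is a minimal standalone sketch of that mapping (the dictionary is copied from the script; the wrapper function name is ours):

```python
# One fixed high-usage codon per amino acid (X = stop), as in scripts/data_processing.py
keys = 'GAVLIPFWYCMSTNQDEKRHX'
values = 'GGCGCCGTGCTGATCCCCTTCTGGTACTGCATGAGCACCAACCAGGACGAGAAGCGCCACTAG'
jcat_dict = {keys[i]: values[i * 3:i * 3 + 3] for i in range(len(keys))}

def jcat_optimize(aa_seq: str) -> str:
    """Map every amino acid of `aa_seq` to its fixed codon."""
    return ''.join(jcat_dict[aa] for aa in aa_seq)

print(jcat_optimize('MKV'))  # ATGAAGGTG
```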
The whole pipeline can be run with:

```bash
python ./scripts/data_processing.py -t $path_transcript_rna_tissue_tsv_file -l $path_gencode_v43_pc_translations_fa_gz_file -c $path_gencode_v43_pc_transcripts_fa_gz_file -o $path_output
```



## Citation

Zilin Ren, Lili Jiang, Yaxin Di, Dufei Zhang, Jianli Gong, Jianting Gong, Qiwei Jiang, Zhiguo Fu, Pingping Sun, Bo Zhou, Ming Ni, CodonBERT: a BERT-based architecture tailored for codon optimization using the cross-attention mechanism, Bioinformatics, Volume 40, Issue 7, July 2024, btae330, https://doi.org/10.1093/bioinformatics/btae330

## Reference
Brandes,N. et al. (2022) ProteinBERT: a universal deep-learning model of protein sequence and function. Bioinformatics, 38, 2102–2110.

Lorenz,R. et al. (2011) ViennaRNA Package 2.0. Algorithms for Molecular Biology, 6, 26.

Olson,S.A. (2002) EMBOSS opens up sequence analysis. European Molecular Biology Open Software Suite. Brief Bioinform, 3, 87–91.
--------------------------------------------------------------------------------
/codonbert_env.yaml:
--------------------------------------------------------------------------------

name: codonbert_env
channels:
  - gurobi
  - pytorch
  - bioconda/label/cf201901
  - bioconda
  - defaults
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/fastai/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - bottleneck=1.3.5=py310ha9d4c09_0
  - brotlipy=0.7.0=py310h7f8727e_1002
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.01.10=h06a4308_0
  - certifi=2023.5.7=py310h06a4308_0
  - cffi=1.15.1=py310h74dc2b5_0
  - codonw=1.4.4=hec16e2b_4
  - cryptography=39.0.1=py310h9ce1e76_0
  - cudatoolkit=10.2.89=hfd86e86_1
  - emboss=6.6.0=h6debe1e_0
  - expat=2.4.4=h295c915_0
  - ffmpeg=4.3=hf484d3e_0
  - fontconfig=2.13.1=h6c09931_0
  - freetype=2.11.0=h70c0345_0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - gurobi=9.5.2=py310_0
  - icu=58.2=he6710b0_3
  - idna=3.4=py310h06a4308_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - jpeg=9e=h7f8727e_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libdeflate=1.8=h7f8727e_5
  - libffi=3.3=he6710b0_2
  - libgcc-ng=11.2.0=h1234567_1
  - libgd=2.3.3=h695aa2c_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.4.0=hecacb30_0
  - libunistring=0.9.10=h27cfd23_0
  - 
libuuid=1.0.3=h7f8727e_2 58 | - libwebp=1.2.2=h55f646e_0 59 | - libwebp-base=1.2.2=h7f8727e_0 60 | - libxml2=2.9.14=h74e7548_0 61 | - lz4-c=1.9.3=h295c915_1 62 | - mkl=2021.4.0=h06a4308_640 63 | - mkl-service=2.4.0=py310h7f8727e_0 64 | - mkl_fft=1.3.1=py310hd6ae3a3_0 65 | - mkl_random=1.2.2=py310h00e6091_0 66 | - ncurses=6.3=h5eee18b_3 67 | - nettle=3.7.3=hbbd107a_1 68 | - numexpr=2.8.3=py310hcea2de6_0 69 | - openh264=2.1.1=h4ff587b_0 70 | - openssl=1.1.1t=h7f8727e_0 71 | - packaging=21.3=pyhd3eb1b0_0 72 | - pip=22.1.2=py310h06a4308_0 73 | - pycparser=2.21=pyhd3eb1b0_0 74 | - pyopenssl=23.0.0=py310h06a4308_0 75 | - pyparsing=3.0.9=py310h06a4308_0 76 | - pysocks=1.7.1=py310h06a4308_0 77 | - python=3.10.4=h12debd9_0 78 | - python-dateutil=2.8.2=pyhd3eb1b0_0 79 | - pytorch=1.12.0=py3.10_cuda10.2_cudnn7.6.5_0 80 | - pytorch-mutex=1.0=cuda 81 | - pytz=2022.1=py310h06a4308_0 82 | - readline=8.1.2=h7f8727e_1 83 | - setuptools=63.4.1=py310h06a4308_0 84 | - six=1.16.0=pyhd3eb1b0_1 85 | - sqlite=3.39.2=h5082296_0 86 | - tk=8.6.12=h1ccaba5_0 87 | - torchaudio=0.12.0=py310_cu102 88 | - torchvision=0.13.0=py310_cu102 89 | - tzdata=2022a=hda174b7_0 90 | - viennarna=2.3.3=hfc679d8_2 91 | - wheel=0.37.1=pyhd3eb1b0_0 92 | - xz=5.2.5=h7f8727e_1 93 | - zlib=1.2.12=h7f8727e_2 94 | - zstd=1.5.2=ha4553b6_0 95 | - pip: 96 | - absl-py==1.2.0 97 | - asttokens==2.0.8 98 | - astunparse==1.6.3 99 | - backcall==0.2.0 100 | - bertviz==1.4.0 101 | - biopython==1.81 102 | - boto3==1.28.49 103 | - botocore==1.31.49 104 | - cachetools==5.2.0 105 | - charset-normalizer==2.1.1 106 | - contourpy==1.0.7 107 | - cycler==0.11.0 108 | - debugpy==1.6.3 109 | - decorator==5.1.1 110 | - einops==0.6.0 111 | - entrypoints==0.4 112 | - executing==1.0.0 113 | - filelock==3.12.4 114 | - flatbuffers==2.0.7 115 | - fonttools==4.39.4 116 | - fsspec==2023.9.1 117 | - gast==0.4.0 118 | - google-auth==2.11.1 119 | - google-auth-oauthlib==0.4.6 120 | - google-pasta==0.2.0 121 | - grpcio==1.49.1 122 | - h5py==3.7.0 123 | - huggingface-hub==0.17.1 124 | - ipykernel==6.15.3 125 | - ipython==8.5.0 126 | - jedi==0.18.1 127 | - jmespath==1.0.1 128 | - joblib==1.2.0 129 | - jupyter-client==7.3.5 130 | - jupyter-core==4.11.1 131 | - keras==2.10.0 132 | - keras-preprocessing==1.1.2 133 | - kiwisolver==1.4.4 134 | - libclang==14.0.6 135 | - markdown==3.4.1 136 | - markupsafe==2.1.1 137 | - matplotlib==3.7.1 138 | - matplotlib-inline==0.1.6 139 | - nest-asyncio==1.5.5 140 | - numpy==1.23.5 141 | - oauthlib==3.2.1 142 | - opt-einsum==3.3.0 143 | - pandas==1.5.3 144 | - parso==0.8.3 145 | - pexpect==4.8.0 146 | - pickleshare==0.7.5 147 | - pillow==9.4.0 148 | - prompt-toolkit==3.0.31 149 | - protobuf==3.20.3 150 | - psutil==5.9.2 151 | - ptyprocess==0.7.0 152 | - pure-eval==0.2.2 153 | - pyasn1==0.4.8 154 | - pyasn1-modules==0.2.8 155 | - pygments==2.13.0 156 | - pyyaml==6.0.1 157 | - pyzmq==24.0.0 158 | - regex==2023.8.8 159 | - requests==2.28.1 160 | - requests-oauthlib==1.3.1 161 | - rsa==4.9 162 | - s3transfer==0.6.2 163 | - safetensors==0.3.3 164 | - scikit-learn==1.2.1 165 | - scipy==1.10.1 166 | - seaborn==0.12.2 167 | - sentencepiece==0.1.99 168 | - stack-data==0.5.0 169 | - tensorboard==2.12.0 170 | - tensorboard-data-server==0.7.0 171 | - tensorboard-plugin-wit==1.8.1 172 | - tensorboardx==2.6 173 | - tensorflow-estimator==2.10.0 174 | - tensorflow-io-gcs-filesystem==0.27.0 175 | - termcolor==2.0.1 176 | - threadpoolctl==3.1.0 177 | - tokenizers==0.13.3 178 | - tornado==6.2 179 | - tqdm==4.65.0 180 | - traitlets==5.4.0 181 | - transformers==4.33.2 
182 | - typing-extensions==4.3.0 183 | - urllib3==1.26.12 184 | - wcwidth==0.2.5 185 | - werkzeug==2.2.2 186 | - wrapt==1.14.1 187 | -------------------------------------------------------------------------------- /data/example_data/codon_seq.fasta: -------------------------------------------------------------------------------- 1 | >RSV-Pre-F-GCN4 2 | ATGGAGCTGCTGATCCTGAAGACCAACGCCATCACCGCCATCCTGGCCGCGGTGACCCTGTGTTTCGCCAGCAGCCAGAACATCACCGAGGAGTTCTACCAGAGCACCTGCAGCGCGGTGAGCAAGGGCTACCTGAGCGCCCTGCGCACAGGCTGGTACACATCTGTGATCACCATCGAGCTGAGCAACATCAAGGAGAACAAGTGCAACGGCACCGACGCCAAGGTGAAGCTGATCAAGCAGGAGCTGGACAAGTACAAGTCGGCCGTGACTGAGCTGCAGCTGCTGATGCAGAGCACCCCCGCCACCAACAACAAGTTCCTGGGCTTCCTGCAGGGCGTGGGCAGCGCCATCGCCAGTGGCATCGCGGTCAGCAAAGTGCTGCACCTGGAAGGAGAAGTGAACAAGATCAAGTCTGCCCTGCTGAGCACCAACAAGGCTGTGGTGTCTCTGAGCAACGGCGTGAGCGTGCTGACCAGCAAGGTGCTGGACCTGAAGAACTACATCGACAAGCAGCTGCTGCCCATCGTGAACAAGCAGAGCTGCAGCATCAGCAACATCGAGACCGTGATCGAGTTTCAGCAGAAGAACAACCGCCTGCTGGAGATCACCCGCGAGTTCAGTGTGAACGCTGGAGTGACCACCCCTGTGTCCACCTACATGCTGACCAACAGCGAGCTGCTGAGCCTGATCAACGACATGCCCATCACCAACGACCAGAAGAAGCTGATGAGCAACAACGTGCAGATCGTGCGGCAGCAGAGCTACAGCATCATGAGCATCATCAAGGAGGAGGTGCTGGCCTACGTGGTGCAGCTGCCTCTGTATGGCGTGATCGACACCCCCTGCTGGAAGCTGCACACCAGCCCCCTGTGCACCACCAACACCAAGGAGGGCAGCAACATCTGCCTGACCCGCACCGACCGCGGCTGGTACTGTGACAACGCCGGCAGCGTGAGCTTCTTCCCCCTGGCAGAGACCTGCAAGGTGCAGAGCAACAGAGTGTTCTGTGACACCATGAACAGCCTGACCCTGCCCAGCGAGGTGAACCTGTGCAATATTGACATTTTCAACCCCAAGTATGACTGCAAGATCATGACCAGCAAGACCGACGTGAGCAGCAGTGTGATCACCTCTCTGGGCGCCATCGTCAGCTGCTACGGGAAGACCAAGTGCACCGCCAGCAACAAGAACCGCGGCATCATCAAGACCTTCAGCAACGGCTGCGACTACGTGAGCAACAAGGGCGTGGACACCGTGAGCGTGGGCAACACCCTGTACTACGTGAACAAGCAGGAAGGCAAGAGCCTGTACGTGAAGGGAGAGCCCATCATCAACTTCTACGACCCCCTGGTGTTCCCCAGTGACGAGTTCGACGCCAGCATCAGCCAGGTGAACGAGAAGATCAACGGCAGCCTGGCCTTCATCCGCAAGAGCGATGAGAAGCTGCACAATGTGGAGGACAAGATTGAGGAGATCCTGAGCAAGATCTACCACATCGAGAACGAGATCGCCCGCATCAAGAAGCTCATTGGAGAGGCT 3 | -------------------------------------------------------------------------------- /data/example_data/codon_seq.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FPPGroup/CodonBERT/cd0a88c8655579b1f01d11d3e7de7777b8bbf0c4/data/example_data/codon_seq.npy -------------------------------------------------------------------------------- /data/example_data/predict.fasta: -------------------------------------------------------------------------------- 1 | >RSV-Pre-F-GCN4 2 | 
ATGGAGCTGCTGATCCTGAAGACCAACGCCATCACCGCCATCCTGGCCGCGGTGACCCTGTGTTTCGCCAGCAGCCAGAACATCACCGAGGAGTTCTACCAGAGCACCTGCAGCGCGGTGAGCAAGGGCTACCTGAGCGCCCTGCGCACAGGCTGGTACACATCTGTGATCACCATCGAGCTGAGCAACATCAAGGAGAACAAGTGCAACGGCACCGACGCCAAGGTGAAGCTGATCAAGCAGGAGCTGGACAAGTACAAGTCGGCCGTGACTGAGCTGCAGCTGCTGATGCAGAGCACCCCCGCCACCAACAACAAGTTCCTGGGCTTCCTGCAGGGCGTGGGCAGCGCCATCGCCAGTGGCATCGCGGTCAGCAAAGTGCTGCACCTGGAAGGAGAAGTGAACAAGATCAAGTCTGCCCTGCTGAGCACCAACAAGGCTGTGGTGTCTCTGAGCAACGGCGTGAGCGTGCTGACCAGCAAGGTGCTGGACCTGAAGAACTACATCGACAAGCAGCTGCTGCCCATCGTGAACAAGCAGAGCTGCAGCATCAGCAACATCGAGACCGTGATCGAGTTTCAGCAGAAGAACAACCGCCTGCTGGAGATCACCCGCGAGTTCAGTGTGAACGCTGGAGTGACCACCCCTGTGTCCACCTACATGCTGACCAACAGCGAGCTGCTGAGCCTGATCAACGACATGCCCATCACCAACGACCAGAAGAAGCTGATGAGCAACAACGTGCAGATCGTGCGGCAGCAGAGCTACAGCATCATGAGCATCATCAAGGAGGAGGTGCTGGCCTACGTGGTGCAGCTGCCTCTGTATGGCGTGATCGACACCCCCTGCTGGAAGCTGCACACCAGCCCCCTGTGCACCACCAACACCAAGGAGGGCAGCAACATCTGCCTGACCCGCACCGACCGCGGCTGGTACTGTGACAACGCCGGCAGCGTGAGCTTCTTCCCCCTGGCAGAGACCTGCAAGGTGCAGAGCAACAGAGTGTTCTGTGACACCATGAACAGCCTGACCCTGCCCAGCGAGGTGAACCTGTGCAATATTGACATTTTCAACCCCAAGTATGACTGCAAGATCATGACCAGCAAGACCGACGTGAGCAGCAGTGTGATCACCTCTCTGGGCGCCATCGTCAGCTGCTACGGGAAGACCAAGTGCACCGCCAGCAACAAGAACCGCGGCATCATCAAGACCTTCAGCAACGGCTGCGACTACGTGAGCAACAAGGGCGTGGACACCGTGAGCGTGGGCAACACCCTGTACTACGTGAACAAGCAGGAAGGCAAGAGCCTGTACGTGAAGGGAGAGCCCATCATCAACTTCTACGACCCCCTGGTGTTCCCCAGTGACGAGTTCGACGCCAGCATCAGCCAGGTGAACGAGAAGATCAACGGCAGCCTGGCCTTCATCCGCAAGAGCGATGAGAAGCTGCACAATGTGGAGGACAAGATTGAGGAGATCCTGAGCAAGATCTACCACATCGAGAACGAGATCGCCCGCATCAAGAAGCTCATTGGAGAGGCT 3 | -------------------------------------------------------------------------------- /data/example_data/protein_seq.fasta: -------------------------------------------------------------------------------- 1 | >RSV-Pre-F-GCN4 2 | MELLILKTNAITAILAAVTLCFASSQNITEEFYQSTCSAVSKGYLSALRTGWYTSVITIELSNIKENKCNGTDAKVKLIKQELDKYKSAVTELQLLMQSTPATNNKFLGFLQGVGSAIASGIAVSKVLHLEGEVNKIKSALLSTNKAVVSLSNGVSVLTSKVLDLKNYIDKQLLPIVNKQSCSISNIETVIEFQQKNNRLLEITREFSVNAGVTTPVSTYMLTNSELLSLINDMPITNDQKKLMSNNVQIVRQQSYSIMSIIKEEVLAYVVQLPLYGVIDTPCWKLHTSPLCTTNTKEGSNICLTRTDRGWYCDNAGSVSFFPLAETCKVQSNRVFCDTMNSLTLPSEVNLCNIDIFNPKYDCKIMTSKTDVSSSVITSLGAIVSCYGKTKCTASNKNRGIIKTFSNGCDYVSNKGVDTVSVGNTLYYVNKQEGKSLYVKGEPIINFYDPLVFPSDEFDASISQVNEKINGSLAFIRKSDEKLHNVEDKIEEILSKIYHIENEIARIKKLIGEA -------------------------------------------------------------------------------- /figures/2024.01.24-figure1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FPPGroup/CodonBERT/cd0a88c8655579b1f01d11d3e7de7777b8bbf0c4/figures/2024.01.24-figure1.jpg -------------------------------------------------------------------------------- /models/kidney_1_1_CodonBert_model_20230726_320_model_param.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FPPGroup/CodonBERT/cd0a88c8655579b1f01d11d3e7de7777b8bbf0c4/models/kidney_1_1_CodonBert_model_20230726_320_model_param.pt -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.utils.data as data 5 | import argparse 6 | from scripts.config_function import * 7 | from scripts.codon_bert_pytorch import CodonBERT 8 | 9 | def set_args(): 10 | parser = argparse.ArgumentParser(description='CodonBert: mRNA optimizaton') 11 | parser.add_argument('-m', '--model', help='the model path', required=True, 
type=str) 12 | parser.add_argument('-f', '--fasta', help='the seq fasta',required = True, type = str) 13 | parser.add_argument('-o', '--output_path', help='the save path',required = True, type = str) 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | if __name__ == '__main__': 19 | args = set_args() 20 | model_path = args.model 21 | seq_file = args.fasta 22 | model_output_save_path = args.output_path 23 | model_output_fix_save_path = args.output_path.split('.')[0]+"_fix.fasta" 24 | 25 | Seq_name = [] 26 | Seq = [] 27 | seq_num = 0 28 | for seqName,seq in readFa(seq_file): 29 | Seq_name.append(seqName) 30 | Seq.append(seq) 31 | seqLen = len(seq) 32 | seq_num += 1 33 | AA_seq_dataset = np.array(Seq, dtype=object) 34 | AA_int_list = [] 35 | for idx in range(len(AA_seq_dataset)): 36 | AA_seq = tokenize_seq(AA_seq_dataset[idx], seq_len, aa_to_token_index, additional_token_to_index) 37 | AA_int_list.append(AA_seq) 38 | 39 | test_AA_int = np.array(AA_int_list) 40 | test_mask = create_mask(AA_seq_dataset, seq_len) 41 | test_AA_int = test_AA_int.astype(int) 42 | test_DNA_int = torch.zeros((test_AA_int.shape[0], test_AA_int.shape[1])) 43 | test_mask_int = test_mask.astype(int) 44 | test_AA_float_tensor = torch.tensor(test_AA_int, dtype=torch.int64) 45 | test_DNA_float_tensor = torch.tensor(test_DNA_int, dtype=torch.int64) 46 | test_mask_bool_tensor = torch.tensor(test_mask_int, dtype=torch.bool) 47 | test_torch_dataset = data.TensorDataset(test_AA_float_tensor, test_DNA_float_tensor, test_mask_bool_tensor) 48 | test_train_iter = data.DataLoader(dataset=test_torch_dataset, batch_size=1, shuffle=False, num_workers=2) 49 | 50 | model = CodonBERT( 51 | num_tokens = 25, 52 | num_annotation_class = 67, 53 | dim = 128, 54 | dim_global = 128, 55 | depth = 6, 56 | narrow_conv_kernel = 9, 57 | wide_conv_kernel = 9, 58 | wide_conv_dilation = 5, 59 | attn_heads = 8, 60 | attn_dim_head = 64, 61 | local_to_global_attn = False, 62 | local_self_attn = True, 63 | num_global_tokens = 2, 64 | glu_conv = False 65 | ) 66 | model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)) 67 | 68 | AA_seq_result_all = [] 69 | DNA_annotation_result_all = [] 70 | AA_annotation_result_all = [] 71 | DNA_input_all = [] 72 | AA_input_all = [] 73 | DNA_acc_all = [] 74 | AA_acc_all = [] 75 | GC_con_all = [] 76 | for ind, (src, trg, mask) in enumerate(test_train_iter): 77 | zero_trg = trg.clone() 78 | zero_trg[:,:]=0 79 | seq_logits, annotation_logits = model(src, zero_trg, mask = mask) 80 | seq_logits = seq_logits[mask] 81 | seq_labels = src[mask] 82 | annotation_logits = annotation_logits[mask] 83 | annotation_labels = trg[mask] 84 | seq_logits_array = np.array(seq_logits.detach().numpy()) 85 | result_seq = np.argmax(seq_logits_array, axis=1) 86 | annotation_logits_array = np.array(annotation_logits.detach().numpy()) 87 | result_annotation = np.argmax(annotation_logits_array, axis=1) 88 | AA_seq_result = seq_pre_to_AA(result_seq) 89 | DNA_annotation_result, AA_annotation_result = annotation_pre_to_AA(result_annotation) 90 | AA_input = encode_AA_to_AA(seq_labels) 91 | AA_encode_annotation_result = AA_to_encode_AA(AA_annotation_result) 92 | same_number = 0 93 | single_acc = 0 94 | for idx in range(len(AA_encode_annotation_result)): 95 | if seq_labels[idx] == AA_encode_annotation_result[idx]: 96 | same_number = same_number + 1 97 | single_acc = same_number/len(AA_encode_annotation_result) 98 | AA_acc_all.append(single_acc) 99 | GC_con_single = GC_con(DNA_annotation_result) 100 | 
GC_con_all.append(GC_con_single) 101 | AA_seq_result_all.append(AA_seq_result) 102 | DNA_annotation_result_all.append(DNA_annotation_result) 103 | AA_annotation_result_all.append(AA_annotation_result) 104 | AA_input_all.append(AA_input) 105 | 106 | fasta_path = model_output_save_path 107 | fp = open(fasta_path,'w') 108 | for i in range(len(DNA_annotation_result_all)): 109 | seq_name = '>'+Seq_name[i]+"\n" 110 | seq = DNA_annotation_result_all[i] 111 | fp.write(seq_name) 112 | fp.write(seq+"\n") 113 | fp.close() 114 | 115 | AA_input_compare_AA_annotation_result = [] 116 | AA_input_compare_AA_annotation_result_detail = [] 117 | AA_change = [[] for i in range (len(AA_input_all))] 118 | AA_change_idx = [[] for i in range (len(AA_input_all))] 119 | AA_acc_metri = 0 120 | length_all = 0 121 | false_number = 0 122 | for i in range(0, len(AA_input_all)): 123 | length_all = length_all + len(AA_input_all[i]) 124 | if AA_input_all[i] == AA_annotation_result_all[i]: 125 | AA_input_compare_AA_annotation_result.append(True) 126 | AA_input_compare_AA_annotation_result_detail.append(None) 127 | else: 128 | AA_input_compare_AA_annotation_result.append(False) 129 | single_detail = [] 130 | for j in range(0, len(AA_input_all[i])): 131 | if AA_input_all[i][j]!=AA_annotation_result_all[i][j]: 132 | AA_change[i].append(AA_input_all[i][j]) 133 | AA_change_idx[i].append(j) 134 | false_number = false_number + 1 135 | single_detail.append(str(j)+':'+str(AA_input_all[i][j])+'_'+str(AA_annotation_result_all[i][j])) 136 | AA_input_compare_AA_annotation_result_detail.append(single_detail) 137 | AA_acc_metri = 1-(false_number/length_all) 138 | 139 | fix_DNA_annotation_result_all = [[] for i in range (len(DNA_annotation_result_all))] 140 | 141 | for seq_idx in range(len(AA_change_idx)): 142 | result_DNA_seq = DNA_annotation_result_all[seq_idx] 143 | fix_DNA_one_codon = result_DNA_seq 144 | for change_idx in range(len(AA_change_idx[seq_idx])): 145 | false_local = AA_change_idx[seq_idx][change_idx] 146 | input_AA_true = AA_input_all[seq_idx][AA_change_idx[seq_idx][change_idx]] 147 | fix_codon = fix_AA_codon[input_AA_true] 148 | if false_local == 0: 149 | fix_DNA_one_codon = fix_codon + fix_DNA_one_codon[((false_local+1)*3):] 150 | else: 151 | fix_DNA_one_codon = fix_DNA_one_codon[:((false_local*3))] + fix_codon + fix_DNA_one_codon[((false_local+1)*3):] 152 | fix_DNA_annotation_result_all[seq_idx].append(fix_DNA_one_codon) 153 | 154 | 155 | fasta_path = model_output_fix_save_path 156 | fp = open(fasta_path,'w') 157 | for i in range(len(fix_DNA_annotation_result_all)): 158 | seq_name = '>'+Seq_name[i]+"\n" 159 | seq = fix_DNA_annotation_result_all[i] 160 | fp.write(seq_name) 161 | fp.write(seq[0]+"\n") 162 | fp.close() 163 | 164 | -------------------------------------------------------------------------------- /scripts/codon_bert_pytorch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | import torch.nn.functional as FC 5 | from torch import nn, einsum 6 | from einops.layers.torch import Rearrange, Reduce 7 | from einops import rearrange, repeat 8 | from .config_function import * 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def max_neg_value(t): 15 | return -torch.finfo(t.dtype).max 16 | 17 | class Residual(nn.Module): 18 | def __init__(self, fn): 19 | super().__init__() 20 | self.fn = fn 21 | def forward(self, x): 22 | return self.fn(x) + x 23 | 24 | class GlobalLinearSelfAttention(nn.Module): 25 | def __init__( 
26 | self, 27 | *, 28 | dim, 29 | dim_head, 30 | heads 31 | ): 32 | super().__init__() 33 | inner_dim = dim_head * heads 34 | self.heads = heads 35 | self.scale = dim_head ** -0.5 36 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 37 | self.to_out = nn.Linear(inner_dim, dim) 38 | 39 | def forward(self, feats, mask = None): 40 | h = self.heads 41 | q, k, v = self.to_qkv(feats).chunk(3, dim = -1) 42 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 43 | 44 | if exists(mask): 45 | mask = rearrange(mask, 'b n -> b () n ()') 46 | k = k.masked_fill(~mask, -torch.finfo(k.dtype).max) 47 | 48 | q = q.softmax(dim = -1) 49 | k = k.softmax(dim = -2) 50 | q = q * self.scale 51 | 52 | if exists(mask): 53 | v = v.masked_fill(~mask, 0.) 54 | 55 | context = einsum('b h n d, b h n e -> b h d e', k, v) 56 | out = einsum('b h d e, b h n d -> b h n e', context, q) 57 | out = rearrange(out, 'b h n d -> b n (h d)') 58 | return self.to_out(out) 59 | 60 | class CrossAttention(nn.Module): 61 | def __init__( 62 | self, 63 | *, 64 | dim, 65 | dim_keys, 66 | dim_out, 67 | heads, 68 | dim_head = 64, 69 | qk_activation = nn.Tanh() 70 | ): 71 | super().__init__() 72 | self.heads = heads 73 | self.scale = dim_head ** -0.5 74 | inner_dim = dim_head * heads 75 | 76 | self.qk_activation = qk_activation 77 | 78 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 79 | self.to_kv = nn.Linear(dim_keys, inner_dim * 2, bias = False) 80 | self.to_out = nn.Linear(inner_dim, dim_out) 81 | 82 | self.null_key = nn.Parameter(torch.randn(dim_head)) 83 | self.null_value = nn.Parameter(torch.randn(dim_head)) 84 | 85 | def forward(self, x, context, mask = None, context_mask = None): 86 | b, h, device = x.shape[0], self.heads, x.device 87 | 88 | q = self.to_q(x) 89 | k, v = self.to_kv(context).chunk(2, dim = -1) 90 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 91 | 92 | null_k, null_v = map(lambda t: repeat(t, 'd -> b h () d', b = b, h = h), (self.null_key, self.null_value)) 93 | k = torch.cat((null_k, k), dim = -2) 94 | v = torch.cat((null_v, v), dim = -2) 95 | 96 | q, k = map(lambda t: self.qk_activation(t), (q, k)) 97 | 98 | sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 99 | 100 | if exists(mask) or exists(context_mask): 101 | i, j = sim.shape[-2:] 102 | 103 | if not exists(mask): 104 | mask = torch.ones(b, i, dtype = torch.bool, device = device) 105 | 106 | if exists(context_mask): 107 | context_mask = FC.pad(context_mask, (1, 0), value = True) 108 | else: 109 | context_mask = torch.ones(b, j, dtype = torch.bool, device = device) 110 | 111 | mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j') 112 | sim.masked_fill_(~mask, max_neg_value(sim)) 113 | 114 | attn = sim.softmax(dim = -1) 115 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 116 | out = rearrange(out, 'b h n d -> b n (h d)') 117 | return self.to_out(out) 118 | 119 | class Layer(nn.Module): 120 | def __init__( 121 | self, 122 | *, 123 | dim, 124 | dim_global, 125 | narrow_conv_kernel = 9, 126 | wide_conv_kernel = 9, 127 | wide_conv_dilation = 5, 128 | attn_heads = 8, 129 | attn_dim_head = 64, 130 | attn_qk_activation = nn.Tanh(), 131 | local_to_global_attn = False, 132 | local_self_attn = False, 133 | glu_conv = False 134 | ): 135 | super().__init__() 136 | 137 | self.seq_self_attn = GlobalLinearSelfAttention(dim = dim, dim_head = attn_dim_head, heads = attn_heads) if local_self_attn else None 138 | 139 | conv_mult = 2 if glu_conv else 1 
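        # when glu_conv is enabled, each convolution below emits 2x channels and
        # nn.GLU(dim = 1) gates one half with the other, halving the channels
        # back to `dim`; hence conv_mult = 2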

        self.narrow_conv = nn.Sequential(
            nn.Conv1d(dim, dim * conv_mult, narrow_conv_kernel, padding = narrow_conv_kernel // 2),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)
        )

        wide_conv_padding = (wide_conv_kernel + (wide_conv_kernel - 1) * (wide_conv_dilation - 1)) // 2

        self.wide_conv = nn.Sequential(
            nn.Conv1d(dim, dim * conv_mult, wide_conv_kernel, dilation = wide_conv_dilation, padding = wide_conv_padding),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)
        )

        self.global_narrow_conv = nn.Sequential(
            nn.Conv1d(dim_global, dim_global * conv_mult, narrow_conv_kernel, padding = narrow_conv_kernel // 2),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)  # GLU gates along the channel dimension
        )

        global_wide_conv_padding = (wide_conv_kernel + (wide_conv_kernel - 1) * (wide_conv_dilation - 1)) // 2

        self.global_wide_conv = nn.Sequential(
            nn.Conv1d(dim_global, dim_global * conv_mult, wide_conv_kernel, dilation = wide_conv_dilation, padding = global_wide_conv_padding),
            nn.GELU() if not glu_conv else nn.GLU(dim = 1)  # GLU gates along the channel dimension
        )

        self.local_to_global_attn = local_to_global_attn

        if local_to_global_attn:
            self.extract_global_info = CrossAttention(
                dim = dim,
                dim_keys = dim_global,
                dim_out = dim,
                heads = attn_heads,
                dim_head = attn_dim_head
            )
        else:
            self.extract_global_info = nn.Sequential(
                Reduce('b n d -> b d', 'mean'),
                nn.Linear(dim_global, dim),
                nn.GELU(),
                Rearrange('b d -> b () d')
            )

        self.local_norm = nn.LayerNorm(dim)

        self.local_feedforward = nn.Sequential(
            Residual(nn.Sequential(
                nn.Linear(dim, dim),
                nn.GELU(),
            )),
            nn.LayerNorm(dim)
        )

        self.global_attend_local = CrossAttention(dim = dim_global, dim_out = dim_global, dim_keys = dim, heads = attn_heads, dim_head = attn_dim_head, qk_activation = attn_qk_activation)

        self.global_dense = nn.Sequential(
            nn.Linear(dim_global, dim_global),
            nn.GELU()
        )

        self.global_norm = nn.LayerNorm(dim_global)

        self.global_feedforward = nn.Sequential(
            Residual(nn.Sequential(
                nn.Linear(dim_global, dim_global),
                nn.GELU()
            )),
            nn.LayerNorm(dim_global),
        )
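    # forward() below: the token track (amino-acid sequence) is updated with two
    # parallel convolutions, global context and optional linear self-attention;
    # the annotation track (codon labels) is then rebuilt by cross-attention,
    # with the updated tokens as queries and the codon annotations as keys/values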
    def forward(self, tokens, annotation, mask = None):
        if self.local_to_global_attn:
            global_info = self.extract_global_info(tokens, annotation, mask = mask)
        else:
            global_info = self.extract_global_info(annotation)

        global_linear_attn = self.seq_self_attn(tokens) if exists(self.seq_self_attn) else 0

        conv_input = rearrange(tokens, 'b n d -> b d n')

        if exists(mask):
            conv_input_mask = rearrange(mask, 'b n -> b () n')
            conv_input = conv_input.masked_fill(~conv_input_mask, 0.)

        narrow_out = self.narrow_conv(conv_input)
        narrow_out = rearrange(narrow_out, 'b d n -> b n d')
        wide_out = self.wide_conv(conv_input)
        wide_out = rearrange(wide_out, 'b d n -> b n d')

        tokens = tokens + narrow_out + wide_out + global_info + global_linear_attn
        tokens = self.local_norm(tokens)

        tokens = self.local_feedforward(tokens)


        annotation = self.global_attend_local(tokens, annotation, context_mask = mask)
        global_conv_input = rearrange(annotation, 'b n d -> b d n')

        if exists(mask):
            conv_input_mask = rearrange(mask, 'b n -> b () n')
            global_conv_input = global_conv_input.masked_fill(~conv_input_mask, 0.)

        global_narrow_out = self.global_narrow_conv(global_conv_input)
        global_narrow_out = rearrange(global_narrow_out, 'b d n -> b n d')
        global_wide_out = self.global_wide_conv(global_conv_input)  # dilated wide convolution over the annotation track
        global_wide_out = rearrange(global_wide_out, 'b d n -> b n d')
        b2_out = global_narrow_out + global_wide_out
        annotation = self.global_norm(b2_out)
        annotation = self.global_feedforward(annotation)

        return tokens, annotation


class CodonBERT(nn.Module):
    def __init__(
        self,
        *,
        num_tokens = 24,
        num_annotation = 8943,
        num_annotation_class = 65,
        dim = 128,
        dim_global = 128,
        depth = 6,
        narrow_conv_kernel = 9,
        wide_conv_kernel = 9,
        wide_conv_dilation = 5,
        attn_heads = 8,
        attn_dim_head = 64,
        attn_qk_activation = nn.Tanh(),
        local_to_global_attn = False,
        local_self_attn = False,
        num_global_tokens = 1,
        glu_conv = False
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, dim)

        self.num_annotation_class = num_annotation_class
        self.to_global_emb = nn.Embedding(num_annotation_class, dim_global)

        self.layers = nn.ModuleList([Layer(dim = dim, dim_global = dim_global, narrow_conv_kernel = narrow_conv_kernel, wide_conv_dilation = wide_conv_dilation, wide_conv_kernel = wide_conv_kernel, attn_qk_activation = attn_qk_activation, local_to_global_attn = local_to_global_attn, local_self_attn = local_self_attn, glu_conv = glu_conv) for layer in range(depth)])

        self.to_token_logits = nn.Linear(dim, num_tokens)
        self.to_annotation_logits = nn.Linear(dim_global, num_annotation_class)

    def forward(self, seq, annotation, mask = None):
        tokens = self.token_emb(seq)
        annotation = self.to_global_emb(annotation)
        for layer in self.layers:
            tokens, annotation = layer(tokens, annotation, mask = mask)
        tokens = self.to_token_logits(tokens)
        annotation = self.to_annotation_logits(annotation)
        return tokens, annotation

def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()
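# Illustration (an added example, not called by the pipeline): for
#   mask = torch.tensor([[True, True, True, True, False, False]])
# get_mask_subset_with_prob(mask, 0.5) marks roughly half of the four valid
# positions (here 2) as True and never selects the masked-out tail.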

class PretrainingWrapper(nn.Module):
    def __init__(
        self,
        model,
        seq_length = 512,
        random_replace_token_prob = 0.1,
        remove_annotation_prob = 0.25,
        add_annotation_prob = 0.01,
        remove_all_annotations_prob = 0.5,
        seq_loss_weight = 1.,
        annotation_loss_weight = 1.,
        exclude_token_ids = (0, 1, 2),
        RNA_exclude_token_ids = (0, 1, 2)
    ):
        super().__init__()
        assert isinstance(model, CodonBERT), 'model must be an instance of CodonBERT'

        self.model = model

        self.seq_length = seq_length

        self.random_replace_token_prob = random_replace_token_prob
        self.remove_annotation_prob = remove_annotation_prob
        self.add_annotation_prob = add_annotation_prob
        self.remove_all_annotations_prob = remove_all_annotations_prob

        self.seq_loss_weight = seq_loss_weight
        self.annotation_loss_weight = annotation_loss_weight

        self.exclude_token_ids = exclude_token_ids
        self.RNA_exclude_token_ids = RNA_exclude_token_ids

    def forward(self, seq, annotation, epoch, mask = None):
        batch_size, device = seq.shape[0], seq.device

        seq_labels = seq
        annotation_labels = annotation

        if not exists(mask):
            mask = torch.ones_like(seq).bool()

        # exclude special tokens (ids 0, 1, 2: pad/start/end) from the
        # positions eligible for random replacement
        AA_excluded_tokens_mask = mask
        for token_id in self.exclude_token_ids:
            AA_excluded_tokens_mask = AA_excluded_tokens_mask & (seq != token_id)
        RNA_excluded_tokens_mask = mask
        for token_id in self.RNA_exclude_token_ids:
            RNA_excluded_tokens_mask = RNA_excluded_tokens_mask & (annotation != token_id)

        random_replace_token_prob_mask = get_mask_subset_with_prob(AA_excluded_tokens_mask, self.random_replace_token_prob)
        random_tokens = torch.zeros_like(seq)  # replaced positions are set to token id 0
        noised_seq = torch.where(random_replace_token_prob_mask, random_tokens, seq)
        RNA_random_replace_token_prob_mask = get_mask_subset_with_prob(RNA_excluded_tokens_mask, self.random_replace_token_prob)
        RNA_random_tokens = torch.zeros_like(annotation)
        noised_annotation = torch.where(RNA_random_replace_token_prob_mask, RNA_random_tokens, annotation)

        # curriculum masking: from epoch 15 on, blank a growing fraction of
        # whole annotation rows, until all rows are blanked after epoch 300
        if epoch < 15:
            pass
        elif epoch < 300:
            mask_ratio = min((epoch // 15) * 0.05, 0.95)
            index_list = random.sample(range(0, noised_annotation.shape[0] - 1), int(noised_annotation.shape[0] * mask_ratio))
            noised_annotation[index_list, :] = 0
        else:
            noised_annotation[:, :] = 0

        seq_logits, annotation_logits = self.model(noised_seq, noised_annotation, mask = mask)
        seq_logits = seq_logits[mask]
        seq_labels = seq_labels[mask]
        annotation_logits = annotation_logits[mask]
        annotation_labels = annotation_labels[mask]

        seq_loss = FC.cross_entropy(seq_logits, seq_labels, reduction = 'mean')
        annotation_loss = FC.cross_entropy(annotation_logits, annotation_labels, reduction = 'mean')

        return seq_loss * self.seq_loss_weight + annotation_loss * self.annotation_loss_weight, seq_loss, annotation_loss
--------------------------------------------------------------------------------
/scripts/config_function.py:
--------------------------------------------------------------------------------
import os
import warnings
import numpy as np
from tensorboardX import SummaryWriter
import torch.nn.functional as FC
import torch

warnings.filterwarnings("ignore")

dict_raw_int = {'A': 3, 'B': 4, 'C': 5, 'D': 6, 'E': 7,
'F': 8, 'G': 9, 'H': 10, 'I': 11, 'J': 12, 'K': 13, 'L': 14, 'M': 15, 'N': 16, 'O': 17, 'P': 18, 'Q': 19, 'R': 20, 'S': 21, 'T': 22, 'U': 23, 'V': 24, 'W': 25, 'X': 26, 'Y': 27, 'Z': 28, 'a': 29, 'b': 30, 'c': 31, 'd': 32, 'e': 33, 'f': 34, 'g': 35, 'h': 36, 'i': 37, 'j': 38, 'k': 39, 'l': 40, 'm': 41, 'n': 42, 'o': 43, 'p': 44, 'q': 45, 'r': 46, 's': 47, 't': 48, 'u': 49, 'v': 50, 'w': 51, 'x': 52, 'y': 53, 'z': 54, '0': 55, '1': 56, '2': 57, '3': 58, '4': 59, '5': 60, '6': 61, '7': 62, '8': 63, '9': 64, ':': 65, ';': 66} 11 | AA_dict_raw_int = {'A': 3, 'C': 4, 'D': 5, 'E': 6, 'F': 7, 'G': 8, 'H': 9, 'I': 10, 'K': 11, 'L': 12, 'M': 13, 'N': 14, 'P': 15, 'Q': 16, 'R': 17, 'S': 18, 'T': 19, 'U': 20, 'V': 21, 'W': 22, 'X': 23, 'Y': 24} 12 | codon_int = {'GGT':'a','GGC':'b','GGA':'c','GGG':'d','GCT':'e','GCC':'f','GCA':'g','GCG':'h','GTT':'i','GTC':'j','GTA':'k','GTG':'m','CTT':'l','CTC':'n','CTA':'o','CTG':'p','TTA':'q','TTG':'r','ATT':'s','ATC':'t','ATA':'u', 13 | 'CCT':'v','CCA':'w','CCG':'x','CCC':'y','TTT':'z','TTC':'A','TAT':'B','TAC':'C','TGG':'D','TCT':'E','TCA':'F','TCC':'G','TCG':'H','AGT':'I','AGC':'J', 14 | 'ACT':'K','ACC':'M','ACG':'L','ACA':'N','ATG':'O','TGT':'P','TGC':'Q','AAT':'R','AAC':'S','CAA':'T','CAG':'U','GAT':'V','GAC':'W', 15 | 'GAA':'X','GAG':'Y','AAA':'Z','AAG':'1','CGT':'2','CGC':'3','CGG':'4','CGA':'5','AGA':'6','AGG':'7','CAT':'8','CAC':'9','TAA':'0','TAG':':','TGA':';'} 16 | int_codon = dict((value, cod) for cod, value in codon_int.items()) 17 | G = ['GGT','GGC','GGA','GGG'] 18 | A = ['GCT','GCC','GCA','GCG'] 19 | V = ['GTT','GTC','GTA','GTG'] 20 | L = ['CTT','CTC','CTA','CTG','TTA','TTG'] 21 | I = ['ATT','ATC','ATA'] 22 | P = ['CCT','CCA','CCG','CCC'] 23 | F = ['TTT','TTC'] 24 | Y = ['TAT','TAC'] 25 | W = ['TGG'] 26 | S = ['TCT','TCA','TCC','TCG','AGT','AGC'] 27 | T = ['ACT','ACC','ACG','ACA'] 28 | M = ['ATG'] 29 | C = ['TGT','TGC'] 30 | N = ['AAT','AAC'] 31 | Q = ['CAA','CAG'] 32 | D = ['GAT','GAC'] 33 | E = ['GAA','GAG'] 34 | K = ['AAA','AAG'] 35 | R = ['CGT','CGC','CGG','CGA','AGA','AGG'] 36 | H = ['CAT','CAC'] 37 | X = ['TAA','TAG','TGA'] 38 | homonym_codon = {'G':G,'A':A,'V':V,'L':L,'I':I,'P':P,'F':F,'Y':Y,'W':W,'S':S,'T':T,'M':M,'C':C,'N':N,'Q':Q,'D':D,'E':E,'K':K,'R':R,'H':H,'X':X} 39 | fix_AA_codon = {'G':'GGC','A':'GCC','V':'GTC','L':'CTG','I':'ATC','P':'CCG','F':'TTC','Y':'TAC','W':'TGG','S':'AGC','T':'ACG','M':'ATG','C':'TGC','N':'AAC','Q':'CAG','D':'GAC','E':'GAG','K':'AAG','R':'CGC','H':'CAC','X':'TAG'} 40 | 41 | 42 | seq_len = 2048 43 | N_EPOCHS = 320 44 | lr_init = 3e-4 45 | batch_size = 12 46 | is_binary = False 47 | CodonBERT_path = '.' 
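# NOTE: CodonBERT_path + 'logs/...' below resolves to '.logs/...', the hidden
# log directory listed in .gitignore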
writer = SummaryWriter(CodonBERT_path+'logs/2023_0726/train_epoch320-batch12')
valid_writer = SummaryWriter(CodonBERT_path+'logs/2023_0726/valid_epoch320-batch12')
ACC_writer = SummaryWriter(CodonBERT_path+'logs/2023_0726/acc_epoch320-batch1')
tensorboard_ind = 0
valid_tensorboard_ind = 0
ACC_tensorboard_ind = 0


global device
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def convert_list_to_dict(dict_raw_int, value):
    return [k for k, v in dict_raw_int.items() if v == (value)]

def DNA_to_AA(DNA_seq):
    AA_list = ""
    start = 0
    end = 3
    DNA_seq = DNA_seq.replace('U','T')
    while(end<=len(DNA_seq)+1):
        codon = DNA_seq[start:end]
        start+=3
        end+=3
        for AA,codons in homonym_codon.items():
            if codon in codons:
                AA_list += AA
    return AA_list

def seq_pre_to_AA(pre_list):
    AA_result = ''
    for i in range(len(pre_list)):
        AA_result = AA_result + str(convert_list_to_dict(AA_dict_raw_int, pre_list[i])).replace('[','').replace(']','').replace('\'','')
    return AA_result

def annotation_pre_to_AA(pre_list):
    raw_label = ''
    for i in range(len(pre_list)):
        raw_label = raw_label + str(convert_list_to_dict(dict_raw_int, pre_list[i])).replace('[','').replace(']','').replace('\'','')
    DNA_label = ''
    for i in range(len(raw_label)):
        DNA_label = DNA_label + str(convert_list_to_dict(codon_int, raw_label[i])).replace('[','').replace(']','').replace('\'','')
    AA_result = DNA_to_AA(DNA_label)
    return DNA_label, AA_result

def encode_AA_to_AA(pre_list):
    raw_label = ''
    for i in range(len(pre_list)):
        raw_label = raw_label + str(convert_list_to_dict(AA_dict_raw_int, pre_list[i])).replace('[','').replace(']','').replace('\'','')
    return raw_label

def AA_to_encode_AA(AA_list):
    encode_AA_label = []
    for i in range(len(AA_list)):
        encode_AA_label.append(AA_dict_raw_int[AA_list[i]])
    return encode_AA_label

def GC_con(seq):
    length = len(seq)
    G_num = seq.count('G')
    C_num = seq.count('C')
    GC_content = (G_num+C_num)/length
    GC_content = round(GC_content, 4)
    return GC_content

ALL_AAS = 'ACDEFGHIKLMNPQRSTUVWXY'
ADDITIONAL_TOKENS = ['<PAD>', '<START>', '<END>']
n_aas = len(ALL_AAS)
aa_to_token_index = {aa: i+3 for i, aa in enumerate(ALL_AAS)}
additional_token_to_index = {token: i for i, token in enumerate(ADDITIONAL_TOKENS)}
unique_labels_list = [chr(i) for i in range(65, 91)]
unique_labels_list = unique_labels_list + [chr(i) for i in range(97, 123)]
unique_labels_list = unique_labels_list + [chr(i) for i in range(48, 60)]
label_to_index = {str(label): i+3 for i, label in enumerate(unique_labels_list)}

def tokenize_seq(seq, max_len, aa_to_token_index, additional_token_to_index):
    # unknown residues fall back to 'X'
    converted_seq = [additional_token_to_index['<START>']] + [aa_to_token_index.get(aa, aa_to_token_index['X']) for aa in seq] + [additional_token_to_index['<END>']]
    output_seq = [additional_token_to_index['<PAD>'] for i in range(len(converted_seq), max_len)]
    return converted_seq + output_seq
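# Example (an added illustration): tokenize_seq('MKV', 8, aa_to_token_index, additional_token_to_index)
# returns [1, 13, 11, 21, 2, 0, 0, 0], i.e. <START>, M, K, V, <END>, then <PAD> up to max_len.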
def encode_seq_Y_68(seqs, seq_len, is_binary):
    unique_labels_list = [chr(i) for i in range(65, 91)]
    unique_labels_list = unique_labels_list + [chr(i) for i in range(97, 123)]
    unique_labels_list = unique_labels_list + [chr(i) for i in range(48, 60)]
    label_to_index = {str(label): i+3 for i, label in enumerate(unique_labels_list)}

    Y = np.zeros((len(seqs), seq_len), dtype = int)
    sample_weigths = np.zeros((len(seqs), seq_len))

    for i, seq in enumerate(seqs):
        Y[i, 0] = 1                 # <START>
        Y[i, len(seq)+1] = 2        # <END>
        for j, label in enumerate(seq):
            Y[i, j + 1] = label_to_index[label]
        for r in range(len(seq)+2, seq_len):
            Y[i, r] = 0             # <PAD>
        sample_weigths[i, 1:(len(seq) + 1)] = 1
    if is_binary:
        Y = np.expand_dims(Y, axis = -1)
        sample_weigths = np.expand_dims(sample_weigths, axis = -1)
    return Y, sample_weigths


def sample_DNA_to_int(DNA_seq, max_length):
    start = 0
    end = 3
    integer_encoded = []
    while(end<=len(DNA_seq)+1):
        codon = DNA_seq[start:end]
        start+=3
        end+=3
        integer_encoded.append(codon_int[codon])
    while(len(integer_encoded)<max_length):
        integer_encoded.append(0)  # pad with token id 0
    return integer_encoded


def create_mask(seqs, max_length):
    # True over <START> + residues + <END>; False over padding
    mask = np.zeros((len(seqs), max_length), dtype = bool)
    for i, seq in enumerate(seqs):
        mask[i, 0:len(seq)+2] = True
    return mask


def readFa(fa):
    '''
    @msg: Read a fasta file
    @param fa {str} Fasta file path
    @return: {generator} Yields each (sequence name, sequence) pair in the fasta file
    '''
    with open(fa, 'r') as FA:
        seqName, seq = '', ''
        while 1:
            line = FA.readline()
            line = line.strip('\n')
            if (line.startswith('>') or not line) and seqName:
                yield((seqName, seq))
            if line.startswith('>'):
                seqName = line[1:]
                seq = ''
            else:
                seq += line
            if not line:
                break


def self_evaluate(model, test_train_iter, valid_tensorboard_ind, valid_writer):
    for ind, (src, trg, mask) in enumerate(test_train_iter):
        src = src.to(device)
        trg = trg.to(device)
        mask = mask.to(device)
        zero_trg = trg.clone()
        zero_trg[:,:] = 0
        seq_logits, annotation_logits = model(src, zero_trg, mask = mask)
        seq_logits = seq_logits[mask]
        seq_labels = src[mask]
        annotation_logits = annotation_logits[mask]
        annotation_labels = trg[mask]
        valid_seq_loss = FC.cross_entropy(seq_logits, seq_labels, reduction = 'mean')
        valid_annotation_loss = FC.cross_entropy(annotation_logits, annotation_labels, reduction = 'mean')
        valid_loss = valid_seq_loss + valid_annotation_loss

        valid_writer.add_scalar('valid_loss', valid_loss.item(), valid_tensorboard_ind)
        valid_writer.add_scalar('valid_seq_loss', valid_seq_loss.item(), valid_tensorboard_ind)
        valid_writer.add_scalar('valid_annotation_loss', valid_annotation_loss.item(), valid_tensorboard_ind)
        valid_tensorboard_ind += 1
    return valid_tensorboard_ind, valid_writer


def AA_acc(model, test_train_iter, ACC_tensorboard_ind, ACC_writer):
    AA_acc_all = 0
    for ind, (src, trg, mask) in enumerate(test_train_iter):
        src = src.to(device)
        trg = trg.to(device)
        mask = mask.to(device)
        zero_trg = trg.clone()
        zero_trg[:,:] = 0
        seq_logits, annotation_logits = model(src, zero_trg, mask = mask)
        seq_logits = seq_logits[mask]
        seq_labels = src[mask]
        annotation_logits = annotation_logits[mask]

        annotation_logits_array = np.array(annotation_logits.cpu().detach().numpy())
        result_annotation = np.argmax(annotation_logits_array, axis=1)
        DNA_annotation_result, AA_annotation_result = annotation_pre_to_AA(result_annotation)
        AA_encode_annotation_result = AA_to_encode_AA(AA_annotation_result)

        same_number = 0
        single_acc = 0
        for idx in range(len(AA_encode_annotation_result)):
            if seq_labels[idx] == AA_encode_annotation_result[idx]:
                same_number = same_number + 1
        single_acc = same_number/len(AA_encode_annotation_result)
        AA_acc_all = AA_acc_all + single_acc

        ACC_writer.add_scalar('AA_acc', single_acc, ACC_tensorboard_ind)
        ACC_tensorboard_ind += 1

    return ACC_tensorboard_ind, ACC_writer

-------------------------------------------------------------------------------- /scripts/data_formatting.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FPPGroup/CodonBERT/cd0a88c8655579b1f01d11d3e7de7777b8bbf0c4/scripts/data_formatting.py -------------------------------------------------------------------------------- /scripts/data_processing.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | from Bio.Seq import Seq 3 | from Bio import SeqIO 4 | import pandas as pd 5 | import re 6 | import seaborn as sns 7 | from collections import defaultdict 8 | import random 9 | import numpy as np 10 | from numpy import array 11 | from numpy import argmax 12 | from sklearn.preprocessing import LabelEncoder 13 | from sklearn.preprocessing import OneHotEncoder 14 | import os 15 | from tqdm import tqdm 16 | from matplotlib import pyplot as plt 17 | import sys 18 | import os 19 | from get_metrics import * 20 | import argparse 21 | 22 | # Create a homonym codon subtable 23 | G = ['GGT', 'GGC', 'GGA', 'GGG'] 24 | A = ['GCT', 'GCC', 'GCA', 'GCG'] 25 | V = ['GTT', 'GTC', 'GTA', 'GTG'] 26 | L = ['CTT', 'CTC', 'CTA', 'CTG', 'TTA', 'TTG'] 27 | I = ['ATT', 'ATC', 'ATA'] 28 | P = ['CCT', 'CCA', 'CCG', 'CCC'] 29 | F = ['TTT', 'TTC'] 30 | Y = ['TAT', 'TAC'] 31 | W = ['TGG'] 32 | S = ['TCT', 'TCA', 'TCC', 'TCG', 'AGT', 'AGC'] 33 | T = ['ACT', 'ACC', 'ACG', 'ACA'] 34 | M = ['ATG'] 35 | C = ['TGT', 'TGC'] 36 | N = ['AAT', 'AAC'] 37 | Q = ['CAA', 'CAG'] 38 | D = ['GAT', 'GAC'] 39 | E = ['GAA', 'GAG'] 40 | K = ['AAA', 'AAG'] 41 | R = ['CGT', 'CGC', 'CGG', 'CGA', 'AGA', 'AGG'] 42 | H = ['CAT', 'CAC'] 43 | X = ['TAA', 'TAG', 'TGA'] 44 | homonym_codon = {'G': G, 'A': A, 'V': V, 'L': L, 'I': I, 'P': P, 'F': F, 'Y': Y, 'W': W, 'S': S, 'T': T, 'M': M, 'C': C, 45 | 'N': N, 'Q': Q, 'D': D, 'E': E, 'K': K, 'R': R, 'H': H, 'X': X} 46 | 47 | 48 | def DNA_to_AA(DNA_seq): 49 | AA_list = "" 50 | start = 0 51 | end = 3 52 | DNA_seq = DNA_seq.replace('U', 'T') 53 | while (end <= len(DNA_seq) + 1): 54 | codon = DNA_seq[start:end] 55 | start += 3 56 | end += 3 57 | for AA, codons in homonym_codon.items(): 58 | if codon in codons: 59 | AA_list += AA 60 | return AA_list 61 | 62 | 63 | def readFa(fa): 64 | ''' 65 | @msg: Read a fasta file 66 | @param fa {str} Fasta file path 67 | @return: {generator} Returns a generator that can iterate over each sequence name and sequence in the fasta file 68 | ''' 69 | with open(fa, 'r') as FA: 70 | seqName, seq = '', '' 71 | while 1: 72 | line = FA.readline() 73 | line = line.strip('\n') 74 | if (line.startswith('>') or not line) and seqName: 75 | yield ((seqName, seq)) 76 | if line.startswith('>'): 77 | seqName = line[1:] 78 | seq = '' 79 | else: 80 | seq += line 81 | if not line: 82 | break 83 | 84 | 85 | def dna_to_amino_acid(dna_sequence): 86 | coding_dna = Seq(dna_sequence) 87 | return str(coding_dna.translate()) 88 | 89 | 90 | def delete_files_in_directory(directory): 91 | for filename in os.listdir(directory): 92 | file_path = os.path.join(directory, filename) 93 | if os.path.isfile(file_path): 94 | os.remove(file_path) 95 | 96 | 97 | def set_args(): 98 | parser = argparse.ArgumentParser(description='CodonBert:model train') 99 | parser.add_argument('-t', '--tsv_path', help='transcript_rna_tissue.tsv file <.tsv>', required=True, type=str) 100 | parser.add_argument('-l', '--lations', help='gencode.v43.pc_translations.fa.gz file <.gz>', required=True, type=str) 101 | 
parser.add_argument('-c', '--scripts', help='gencode.v43.pc_transcripts.fa.gz file <.gz>', required=True, type=str) 102 | parser.add_argument('-o', '--output_path', help='result save path', required=True, type=str) 103 | args = parser.parse_args() 104 | return args 105 | 106 | 107 | args = set_args() 108 | tsv_path = args.tsv_path 109 | AA_fa_gz_path = args.lations 110 | DNA_fa_gz_path = args.scripts 111 | result_save_path = args.output_path 112 | tissue_TPM = pd.read_csv(tsv_path, sep='\t') 113 | directory_path = './data/train_data' 114 | if (os.path.exists(directory_path) == False): 115 | os.makedirs(directory_path) 116 | 117 | # Retrieve the enstid with high TPM values for a specific tissue 118 | tpm_val = tissue_TPM.loc[:, tissue_TPM.columns.str.contains('kidney')] 119 | tpm_val['TPM_mean'] = tpm_val.apply(lambda x: np.median(x), axis=1) 120 | merged = pd.concat([tissue_TPM.loc[:, ['ensgid', 'enstid']], tpm_val], axis=1) 121 | kidney_beyong5_id = merged.loc[merged.TPM_mean > 5, ['ensgid', 'enstid']] 122 | kidney_beyong5_id = kidney_beyong5_id.reset_index(drop=True) 123 | kidney_beyong5_id_part = kidney_beyong5_id.loc[0:10, :] 124 | 125 | # Read compressed FASTA file 126 | AA_fa_dict = {} 127 | DNA_fa_dict = {} 128 | 129 | with gzip.open(AA_fa_gz_path, "rt") as handle: 130 | for record in SeqIO.parse(handle, "fasta"): 131 | ENSG_result = re.findall(r'\bENSG\d+\.\b', record.id)[0].split('.')[0] 132 | ENST_result = re.findall(r'\bENST\d+\.\b', record.id)[0].split('.')[0] 133 | 134 | AA_fa_dict[ENSG_result + '_' + ENST_result] = record.seq 135 | 136 | with gzip.open(DNA_fa_gz_path, "rt") as handle: 137 | for record in SeqIO.parse(handle, "fasta"): 138 | ENSG_result = re.findall(r'\bENSG\d+\.\b', record.id)[0].split('.')[0] 139 | ENST_result = re.findall(r'\bENST\d+\.\b', record.id)[0].split('.')[0] 140 | 141 | match = re.search(r"CDS:(\d+)-(\d+)", record.id) 142 | start_id = int(match.group(1)) 143 | end_id = int(match.group(2)) 144 | cds_DNA = str(record.seq)[start_id - 1:end_id] 145 | 146 | DNA_fa_dict[ENSG_result + '_' + ENST_result] = cds_DNA 147 | 148 | # Single test start 149 | kidney_beyong5 = pd.DataFrame(columns=['ensgid', 'enstid', 'AA_seq', 'DNA_seq']) 150 | for index, row in tqdm(kidney_beyong5_id.iterrows()): 151 | if row['ensgid'] + '_' + row['enstid'] in AA_fa_dict.keys() and row['ensgid'] + '_' + row[ 152 | 'enstid'] in DNA_fa_dict.keys(): 153 | AA_fa_id = row['ensgid'] 154 | AA_fa_seq = AA_fa_dict[row['ensgid'] + '_' + row['enstid']] 155 | fa_id = row['enstid'] 156 | cds_DNA = DNA_fa_dict[row['ensgid'] + '_' + row['enstid']] 157 | else: 158 | fa_id = None 159 | fa_seq = None 160 | 161 | # Extract complete DNA IDs and sequences 162 | if fa_id != None and cds_DNA != None: 163 | if len(cds_DNA) % 3 == 0: 164 | AA_cds_DNA = dna_to_amino_acid(cds_DNA) 165 | if len(AA_cds_DNA) == len(AA_fa_seq): 166 | if AA_cds_DNA == AA_fa_seq: 167 | AA_DNA_match = 1 168 | else: 169 | AA_DNA_match = 0 170 | else: 171 | if len(AA_cds_DNA) - 1 == len(AA_fa_seq): 172 | if AA_cds_DNA[0:len(AA_cds_DNA) - 1] == AA_fa_seq: 173 | AA_DNA_match = 1 174 | else: 175 | AA_DNA_match = 0 176 | 177 | if AA_DNA_match == 1: 178 | kidney_beyong5.loc[len(kidney_beyong5)] = [row['ensgid'], row['enstid'], 179 | str(AA_cds_DNA).strip('[').strip(']').replace("\'", ""), 180 | str(cds_DNA).strip('[').strip(']').replace("\'", "")] 181 | 182 | kidney_TPM_beyong5_all = kidney_beyong5 183 | cds_kidney_TPM_beyong5_all = kidney_beyong5['DNA_seq'] 184 | 185 | high_TPM_DNA_file_fasta = 
186 | fp = open(high_TPM_DNA_file_fasta, 'w')
187 | # Keep CDSs of 200-2000 nt whose length is a multiple of 3
188 | for index, row in kidney_TPM_beyong5_all.iterrows():
189 |     if 200 <= len(row['DNA_seq']) <= 2000 and len(row['DNA_seq']) % 3 == 0:
190 |         seq_name = ">" + str(row['ensgid']) + '_' + str(row['enstid']) + "\n"
191 |         seq = row['DNA_seq']
192 |         fp.write(seq_name)
193 |         fp.write(seq + "\n")
194 | fp.close()
195 | 
196 | cds_kidney_TPM_beyong5_select = [x for x in cds_kidney_TPM_beyong5_all if 200 <= len(x) <= 2000 and len(x) % 3 == 0]
197 | 
198 | high_TPM_DNA_file_fasta = "./data/train_data/csd_kidney_high_TPM_200_2000.fasta"
199 | fasta_path = high_TPM_DNA_file_fasta
200 | 
201 | # temp file for the external tools' output
202 | temp_file = "temp.txt"
203 | 
204 | command_line = "RNAfold -p --MEA < %s > %s" % (fasta_path, temp_file)  # minimum free energy via ViennaRNA
205 | os.system(command_line)
206 | mfe_dict = extract_MFE(temp_file)
207 | os.remove(temp_file)
208 | 
209 | for name in os.listdir(sys.path[0].split('/draw_pictures')[0]):  # RNAfold leaves *.ps structure plots behind; remove them
210 |     if name.endswith('.ps'):
211 |         os.remove(sys.path[0].split('/draw_pictures')[0] + '/' + name)
212 | for name in os.listdir(sys.path[0]):
213 |     if name.endswith('.ps'):
214 |         os.remove(sys.path[0] + '/' + name)
215 | 
216 | # Calculate CAI with EMBOSS cai and the human codon usage table
217 | command_line = "_cai -seqall %s -cfile Ehuman.cut -outfile %s" % (fasta_path, temp_file)
218 | os.system(command_line)
219 | cai_dict = extract_cai(temp_file)
220 | os.remove(temp_file)
221 | # Integration of the calculated results
222 | metrics_result = merge_func(['MFE', 'CAI'], mfe_dict, cai_dict)
223 | kidney_CAI_MFE = metrics_result
224 | DNA_fasta_path = "./data/train_data/csd_kidney_high_TPM_200_2000.fasta"
225 | MFE200_names = kidney_CAI_MFE.loc[kidney_CAI_MFE['MFE'] < -200, 'name'].tolist()
226 | CAI7_names = kidney_CAI_MFE.loc[kidney_CAI_MFE['CAI'] > 0.7, 'name'].tolist()
227 | get_names = list(set(MFE200_names) & set(CAI7_names))  # keep sequences with MFE < -200 and CAI > 0.7
228 | 
229 | # Read FASTA file and extract sequences based on the selected names
230 | sequences_names = []
231 | sequences = []
232 | for record in SeqIO.parse(DNA_fasta_path, "fasta"):
233 |     if record.id in get_names:
234 |         sequences_names.append(record.id)
235 |         sequences.append(record.seq)
236 | 
237 | np.save('./data/train_data/csd_kidney_high_TPM_200_2000_MFE200_CAI7.npy', sequences)
238 | 
239 | # Save the selected sequences as a FASTA file
240 | fasta_path = "./data/train_data/csd_kidney_high_TPM_200_2000_MFE200_CAI7.fasta"
241 | fp = open(fasta_path, 'w')
242 | for i in range(len(sequences_names)):
243 |     seq_name = ">" + str(sequences_names[i]) + "\n"
244 |     seq = str(sequences[i])
245 |     fp.write(seq_name)
246 |     fp.write(seq + "\n")
247 | fp.close()
248 | DNA_seq_dataset = sequences
249 | 
250 | # Load the unoptimized fasta file
251 | fasta_file = "./data/train_data/csd_kidney_high_TPM_200_2000_MFE200_CAI7.fasta"
252 | Seq_name = []
253 | Seq = []
254 | seq_num = 0
255 | for seqName, seq in readFa(fasta_file):
256 |     Seq_name.append(seqName)
257 |     Seq.append(seq)
258 | AA_list = []
259 | for idx in range(len(DNA_seq_dataset)):
260 |     DNA_seq_dataset[idx] = DNA_seq_dataset[idx].replace('U', 'T')
261 |     AA_seq = DNA_to_AA(DNA_seq_dataset[idx])
262 |     AA_list.append(AA_seq)
263 | 
264 | # Generate a dictionary based on the JCat concept: each amino acid maps to one fixed high-usage codon
265 | keys = 'GAVLIPFWYCMSTNQDEKRHX'
266 | values = 'GGCGCCGTGCTGATCCCCTTCTGGTACTGCATGAGCACCAACCAGGACGAGAAGCGCCACTAG'
267 | Jcat_dict = {keys[i]: values[i * 3:i * 3 + 3] for i in range(len(keys))}
268 | Jcat_opt_result = []
269 | for AA_seq in AA_list:
270 |     AA_to_mRNA = ''.join([Jcat_dict.get(char, char) for char in AA_seq])
271 |     Jcat_opt_result.append(AA_to_mRNA)
272 | 
273 | # Store the entire optimized sequence set as FASTA
274 | fasta_path = "./data/train_data/Jcat_kidney_high_TPM_200_2000.fasta"
275 | fp = open(fasta_path, 'w')
276 | for i in range(len(Seq_name)):
277 |     seq_name = ">" + str(Seq_name[i]) + "\n"
278 |     seq = str(Jcat_opt_result[i])
279 |     fp.write(seq_name)
280 |     fp.write(seq + "\n")
281 | fp.close()
282 | 
283 | high_TPM_DNA_file_fasta = "./data/train_data/Jcat_kidney_high_TPM_200_2000.fasta"
284 | fasta_path = high_TPM_DNA_file_fasta
285 | # temp file for the external tools' output
286 | temp_file = "temp.txt"
287 | command_line = "RNAfold -p --MEA < %s > %s" % (fasta_path, temp_file)
288 | os.system(command_line)
289 | mfe_dict = extract_MFE(temp_file)
290 | os.remove(temp_file)
291 | 
292 | for name in os.listdir(sys.path[0].split('/draw_pictures')[0]):  # clean up RNAfold's *.ps plots
293 |     if name.endswith('.ps'):
294 |         os.remove(sys.path[0].split('/draw_pictures')[0] + '/' + name)
295 | for name in os.listdir(sys.path[0]):
296 |     if name.endswith('.ps'):
297 |         os.remove(sys.path[0] + '/' + name)
298 | 
299 | # Calculate CAI
300 | command_line = "_cai -seqall %s -cfile Ehuman.cut -outfile %s" % (fasta_path, temp_file)
301 | os.system(command_line)
302 | cai_dict = extract_cai(temp_file)
303 | os.remove(temp_file)
304 | metrics_result = merge_func(['MFE', 'CAI'], mfe_dict, cai_dict)
305 | kidney_CAI_MFE = metrics_result
306 | DNA_fasta_path = "./data/train_data/Jcat_kidney_high_TPM_200_2000.fasta"
307 | 
308 | MFE200_names = kidney_CAI_MFE.loc[kidney_CAI_MFE['MFE'] < -200, 'name'].tolist()
309 | CAI7_names = kidney_CAI_MFE.loc[kidney_CAI_MFE['CAI'] > 0.7, 'name'].tolist()
310 | get_names = list(set(MFE200_names) & set(CAI7_names))
311 | 
312 | # Read FASTA file and extract sequences based on the selected names
313 | Jcat_sequences_names = []
314 | Jcat_sequences = []
315 | for record in SeqIO.parse(DNA_fasta_path, "fasta"):
316 |     if record.id in get_names:
317 |         Jcat_sequences_names.append(record.id)
318 |         Jcat_sequences.append(record.seq)
319 | 
320 | np.save('./data/train_data/Jcat_kidney_high_TPM_200_2000_MFE200_CAI7.npy', Jcat_sequences)
321 | DNA_file = "./data/train_data/csd_kidney_high_TPM_200_2000_MFE200_CAI7.npy"
322 | Jcat_DNA_file = "./data/train_data/Jcat_kidney_high_TPM_200_2000_MFE200_CAI7.npy"
323 | DNA_seq_dataset = np.load(DNA_file, allow_pickle=True)
324 | Jcat_DNA_dataset = np.load(Jcat_DNA_file, allow_pickle=True)
325 | 
326 | real_seqs = DNA_seq_dataset[0:10000]  # natural high-TPM sequences used for training
327 | real_seqs_len = len(real_seqs)
328 | test_real_seqs = DNA_seq_dataset[real_seqs_len:(DNA_seq_dataset.shape[0])]  # remainder held out for testing
329 | Jcat_DNA_dataset_02 = Jcat_DNA_dataset[0:int(real_seqs_len * 0.2)]  # JCat-optimized sequences for a 1:0.2 mix
330 | Jcat_DNA_dataset_05 = Jcat_DNA_dataset[0:int(real_seqs_len * 0.5)]  # ... 1:0.5 mix
331 | Jcat_DNA_dataset_1 = Jcat_DNA_dataset[0:int(real_seqs_len)]  # ... 1:1 mix
332 | 
333 | real_Jcat_02 = list(real_seqs) + list(Jcat_DNA_dataset_02)
334 | real_Jcat_05 = list(real_seqs) + list(Jcat_DNA_dataset_05)
335 | real_Jcat_1 = list(real_seqs) + list(Jcat_DNA_dataset_1)
336 | 
337 | directory_path = './data/train_data'
338 | delete_files_in_directory(directory_path)  # clear the intermediate files before writing the final arrays
339 | 
340 | if not os.path.exists(result_save_path):
341 |     os.makedirs(result_save_path)
342 | 
343 | np.save(result_save_path + '/kidney_train_DNA_real_seqs.npy', real_seqs)
344 | np.save(result_save_path + '/kidney_train_DNA_real_Jcat_02.npy', real_Jcat_02)
345 | np.save(result_save_path + '/kidney_train_DNA_real_Jcat_05.npy', real_Jcat_05)
346 | np.save(result_save_path + '/kidney_train_DNA_real_Jcat_1.npy', real_Jcat_1)
347 | np.save(result_save_path + '/kidney_test_DNA_real_seqs.npy', test_real_seqs)
--------------------------------------------------------------------------------
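Taken together, `data_processing.py` keeps kidney transcripts with median TPM > 5, extracts CDSs whose translation matches the GENCODE protein, filters to 200-2000 nt sequences with MFE < -200 and CAI > 0.7, and mixes in JCat-style optimized counterparts at 1:0.2, 1:0.5, and 1:1 ratios before saving the training arrays. A sketch of a typical invocation, assuming the three raw files listed in `Data/raw_data.md` have been downloaded into the working directory and that `RNAfold`, the EMBOSS `_cai` binary, and an `Ehuman.cut` codon-usage table are available:

```bash
python scripts/data_processing.py \
    -t transcript_rna_tissue.tsv \
    -l gencode.v43.pc_translations.fa.gz \
    -c gencode.v43.pc_transcripts.fa.gz \
    -o Data/train_data
```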
/scripts/fasta_to_npy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | 
4 | def readFa(fa):
5 |     with open(fa, 'r') as FA:
6 |         seqName, seq = '', ''
7 |         while True:
8 |             line = FA.readline()
9 |             line = line.strip('\n')
10 |             if (line.startswith('>') or not line) and seqName:
11 |                 yield (seqName, seq)
12 |             if line.startswith('>'):
13 |                 seqName = line[1:]
14 |                 seq = ''
15 |             else:
16 |                 seq += line
17 |             if not line:
18 |                 break
19 | 
20 | def set_args():
21 |     parser = argparse.ArgumentParser(description='fasta to npy')
22 |     parser.add_argument('-f', '--fasta_path', help='the seq fasta', required=True, type=str)
23 |     parser.add_argument('-n', '--npy_path', help='the save path', required=True, type=str)
24 |     args = parser.parse_args()
25 |     return args
26 | 
27 | 
28 | if __name__ == '__main__':
29 |     args = set_args()
30 |     fasta_path = args.fasta_path
31 |     npy_path = args.npy_path
32 | 
33 |     Seq = []
34 |     for seqName, seq in readFa(fasta_path):
35 |         Seq.append(seq)
36 | 
37 |     seq_dataset = np.array(Seq, dtype=object)  # object dtype keeps variable-length sequences intact
38 |     np.save(npy_path, seq_dataset)
--------------------------------------------------------------------------------
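`fasta_to_npy.py` converts a FASTA file into the object-dtype `.npy` array format consumed elsewhere in the pipeline. A minimal example using the files shipped in `data/example_data/`:

```bash
python scripts/fasta_to_npy.py -f data/example_data/codon_seq.fasta -n data/example_data/codon_seq.npy
```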
/scripts/get_metrics.py:
--------------------------------------------------------------------------------
1 | import re, os
2 | import pandas as pd
3 | import argparse
4 | import sys
5 | 
6 | def set_args():
7 |     parser = argparse.ArgumentParser(description='calculate GC/MFE/CAI/ENC metrics')
8 |     parser.add_argument('-f', '--fasta', help='mRNA fasta', required=True, type=str)
9 |     parser.add_argument('-o', '--output_path', help='metrics result path', required=True, type=str)
10 |     args = parser.parse_args()
11 |     return args
12 | 
13 | def fasta_reader(fasta_path):
14 |     fasta = {}
15 |     with open(fasta_path, 'r') as file:
16 |         for line in file:
17 |             striped = line.strip()
18 |             if striped.startswith('>'):
19 |                 seq_name = striped[1:]
20 |                 fasta[seq_name] = ""
21 |             else:
22 |                 fasta[seq_name] += striped
23 |     return fasta
24 | 
25 | def GC_con(fasta_dict):
26 |     gc_dict = {}
27 |     for name, seq in fasta_dict.items():
28 |         length = len(seq)
29 |         G_num = seq.count('G')
30 |         C_num = seq.count('C')
31 |         GC_content = (G_num + C_num) / length
32 |         gc_dict[name] = round(GC_content, 4)
33 |     return gc_dict
34 | 
35 | def extract_MFE(rnafold_res):
36 |     mfe_dict = {}
37 |     with open(rnafold_res, 'r') as f:
38 |         for ind, line in enumerate(f):
39 |             striped = line.strip()
40 |             if striped.startswith('>'):
41 |                 seq_name = striped[1:]
42 | 
43 |             if ind % 7 == 2:  # RNAfold -p --MEA writes 7 lines per record; the MFE structure is on the third
44 |                 mfe_score = float(re.findall(r'[+-]\d+\.\d+', line)[0])
45 |                 mfe_dict[seq_name] = mfe_score
46 | 
47 |     return mfe_dict
48 | 
49 | def extract_cai(cai_res):
50 |     cai_dict = {}
51 |     with open(cai_res) as f:
52 |         for line in f:
53 |             seq_name = re.findall(r'Sequence: (.*) CAI:', line)[0]
54 |             cai_score = float(re.findall(r'CAI: (\d+\.\d+)', line)[0])
55 |             cai_dict[seq_name] = cai_score
56 |     return cai_dict
57 | 
58 | def extract_enc(enc_res):
59 |     enc_dict = {}
60 |     with open(enc_res) as f:
61 |         for line in f:
62 |             seq_name = re.findall(r'([^\s]+)\s', line)[0]
63 |             enc_score = float(re.findall(r'= (\d+\.\d+)', line)[0])
64 |             enc_dict[seq_name] = enc_score
65 |     return enc_dict
66 | 
67 | def merge_func(col_names, *dicts):
68 |     col_names = ['name'] + col_names
69 |     first_dict = dicts[0]
70 |     output = []
71 |     for ind, (key, val) in enumerate(first_dict.items()):
72 |         output.append([key, val] + [dct[key] for dct in dicts[1:]])
73 |     output = pd.DataFrame(output, columns=col_names)
74 |     return output
75 | 
76 | 
77 | if __name__ == '__main__':
78 |     args = set_args()
79 |     fasta_path = args.fasta
80 |     temp_file = "temp.txt"
81 |     fasta_dict = fasta_reader(fasta_path)
82 |     gc_dict = GC_con(fasta_dict)
83 |     command_line = "RNAfold -p --MEA < %s > %s" % (fasta_path, temp_file)
84 |     os.system(command_line)
85 |     mfe_dict = extract_MFE(temp_file)
86 |     os.remove(temp_file)
87 |     if any(name.endswith('.ps') for name in os.listdir(sys.path[0])):  # clean up RNAfold's *.ps plots
88 |         r = os.system('/usr/bin/find ./ -type f -name "*.ps" | xargs /usr/bin/rm')
89 | 
90 |     command_line = "_cai -seqall %s -cfile Ehuman.cut -outfile %s" % (fasta_path, temp_file)
91 |     os.system(command_line)
92 |     cai_dict = extract_cai(temp_file)
93 |     os.remove(temp_file)
94 | 
95 |     command_line = "_chips -seqall %s -outfile %s -nosum" % (fasta_path, temp_file)
96 |     os.system(command_line)
97 |     enc_dict = extract_enc(temp_file)
98 |     os.remove(temp_file)
99 | 
100 |     metrics_result = merge_func(['GC', 'MFE', 'CAI', 'ENC'], gc_dict, mfe_dict, cai_dict, enc_dict)
101 |     seq_name = list(gc_dict.keys())[0].split('_')[0]  # (unused)
102 |     save_path = args.output_path
103 |     metrics_result.to_csv(save_path, index=False, header=True)
--------------------------------------------------------------------------------
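`get_metrics.py` computes GC content, MFE (via ViennaRNA `RNAfold`), CAI (EMBOSS `_cai`), and ENC (EMBOSS `_chips`) for every sequence in an mRNA FASTA and merges them into one CSV. A sketch of an invocation (the output filename is illustrative; the external binaries and `Ehuman.cut` must be on hand):

```bash
python scripts/get_metrics.py -f data/example_data/codon_seq.fasta -o metrics_result.csv
```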
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import warnings
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.utils.data as data
8 | from tqdm import tqdm
9 | from scripts.codon_bert_pytorch import CodonBERT, PretrainingWrapper
10 | from scripts.config_function import *  # supplies seq_len, batch_size, lr_init, N_EPOCHS, tokenizer helpers, writers, readFa, etc.
11 | import argparse
12 | 
13 | def set_args():
14 |     parser = argparse.ArgumentParser(description='CodonBert: model train')
15 |     parser.add_argument('-t', '--train', help='the train mRNA seq file <.fasta>', required=True, type=str)
16 |     parser.add_argument('-v', '--validation', help='the validation mRNA seq file <.fasta>', required=True, type=str)
17 |     parser.add_argument('-m', '--model', help='the model path', required=True, type=str)
18 |     args = parser.parse_args()
19 |     return args
20 | 
21 | def DNA_encode_input(DNA_seq_dataset):
22 |     AA_list = []  # amino-acid view of every codon sequence
23 |     for idx in range(len(DNA_seq_dataset)):
24 |         DNA_seq_dataset[idx] = DNA_seq_dataset[idx].replace('U', 'T')
25 |         AA_seq = DNA_to_AA(DNA_seq_dataset[idx])
26 |         AA_list.append(AA_seq)
27 |     AA_seq_dataset = np.array(AA_list, dtype=object)
28 |     int_list = []
29 |     for idx in range(len(DNA_seq_dataset)):
30 |         DNA_seq_dataset[idx] = DNA_seq_dataset[idx].replace('U', 'T')
31 |         int_seq = sample_DNA_to_int(DNA_seq_dataset[idx], 0)
32 |         int_seq = str(int_seq).replace('[', '').replace(']', '').replace(',', '').replace(' ', '').replace('\'', '')
33 |         int_list.append(int_seq)
34 |     DNA_int_result, sample_weights = encode_seq_Y_68(int_list, seq_len, is_binary)
35 |     AA_int_list = []
36 |     for idx in range(len(AA_seq_dataset)):
37 |         AA_seq = tokenize_seq(AA_seq_dataset[idx], seq_len, aa_to_token_index, additional_token_to_index)
38 |         AA_int_list.append(AA_seq)
39 |     return AA_int_list, DNA_int_result, sample_weights
40 | 
41 | if __name__ == '__main__':
42 | 
43 |     if torch.cuda.is_available():
44 |         device = 'cuda'
45 |     else:
46 |         device = 'cpu'
47 | 
48 | 
49 |     warnings.filterwarnings("ignore")
50 |     model = CodonBERT(
51 |         num_tokens = 25,
52 |         num_annotation_class = 67,
53 |         dim = 128,
54 |         dim_global = 128,
55 |         depth = 6,
56 |         narrow_conv_kernel = 9,
57 |         wide_conv_kernel = 9,
58 |         wide_conv_dilation = 5,
59 |         attn_heads = 8,
60 |         attn_dim_head = 64,
61 |         local_to_global_attn = False,
62 |         local_self_attn = True,
63 |         num_global_tokens = 2,
64 |         glu_conv = False
65 |     )
66 |     model = model.to(device)
67 | 
68 |     learner = PretrainingWrapper(
69 |         model,
70 |         seq_length = 1024,
71 |         random_replace_token_prob = 0.1,
72 |         remove_annotation_prob = 0.25,
73 |         add_annotation_prob = 0.01,
74 |         remove_all_annotations_prob = 0.5,
75 |         seq_loss_weight = 1.,
76 |         annotation_loss_weight = 1.,
77 |         exclude_token_ids = (0, 1, 2),
78 |         RNA_exclude_token_ids = (0, 1, 2)
79 |     )
80 | 
81 |     args = set_args()
82 |     train_DNA_file = args.train
83 |     test_DNA_file = args.validation
84 |     model_save_path = args.model
85 | 
86 |     train_Seq = []
87 |     for seqName, seq in readFa(train_DNA_file):
88 |         train_Seq.append(seq)
89 |     train_DNA_seq_dataset = np.array(train_Seq, dtype=object)
90 |     test_Seq = []
91 |     for seqName, seq in readFa(test_DNA_file):
92 |         test_Seq.append(seq)
93 |     test_DNA_seq_dataset = np.array(test_Seq, dtype=object)
94 | 
95 |     train_AA_int_list, train_DNA_int_result, train_sample_weights = DNA_encode_input(train_DNA_seq_dataset)
96 |     test_AA_int_list, test_DNA_int_result, test_sample_weights = DNA_encode_input(test_DNA_seq_dataset)
97 | 
98 |     AA_int = np.array(train_AA_int_list).astype(int)
99 |     DNA_int = np.array(train_DNA_int_result).astype(int)
100 |     test_AA_int = np.array(test_AA_int_list).astype(int)
101 |     test_DNA_int = np.array(test_DNA_int_result).astype(int)
102 |     train_mask_int = train_sample_weights.astype(int)
103 |     test_mask_int = test_sample_weights.astype(int)
104 | 
105 |     AA_float_tensor = torch.tensor(AA_int, dtype=torch.int64)
106 |     DNA_float_tensor = torch.tensor(DNA_int, dtype=torch.int64)
107 |     test_AA_float_tensor = torch.tensor(test_AA_int, dtype=torch.int64)
108 |     test_DNA_float_tensor = torch.tensor(test_DNA_int, dtype=torch.int64)
109 |     train_mask_bool_tensor = torch.tensor(train_mask_int, dtype=torch.bool)
110 |     test_mask_bool_tensor = torch.tensor(test_mask_int, dtype=torch.bool)
111 | 
112 |     torch_dataset = data.TensorDataset(AA_float_tensor, DNA_float_tensor, train_mask_bool_tensor)
113 |     test_torch_dataset = data.TensorDataset(test_AA_float_tensor, test_DNA_float_tensor, test_mask_bool_tensor)
114 |     train_iter = data.DataLoader(dataset=torch_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
115 |     test_train_iter = data.DataLoader(dataset=test_torch_dataset, batch_size=1, shuffle=True, num_workers=2)
116 | 
117 |     criterion = nn.CrossEntropyLoss()
118 |     optimizer = torch.optim.Adam(model.parameters(), lr=lr_init)
119 |     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
120 | 
121 | 
122 |     for epoch in tqdm(range(N_EPOCHS)):
123 |         start_time = time.time()
124 |         for ind, (src, trg, mask) in enumerate(train_iter):
125 |             optimizer.zero_grad()
126 |             src = src.to(device)
127 |             trg = trg.to(device)
128 |             mask = mask.to(device)
129 |             loss, seq_loss, annotation_loss = learner(src, trg, epoch, mask=mask)
130 | 
131 |             writer.add_scalar('loss', loss.item(), tensorboard_ind)
132 |             writer.add_scalar('seq_loss', seq_loss.item(), tensorboard_ind)
133 |             writer.add_scalar('annotation_loss', annotation_loss.item(), tensorboard_ind)
134 |             tensorboard_ind += 1
135 |             loss.backward()
136 |             optimizer.step()
137 |         scheduler.step(loss)  # step the plateau scheduler on the latest training loss
138 |         valid_tensorboard_ind, valid_writer = self_evaluate(model, test_train_iter, valid_tensorboard_ind, valid_writer)
139 |         ACC_tensorboard_ind, ACC_writer = AA_acc(model, test_train_iter, ACC_tensorboard_ind, ACC_writer)
140 |         end_time = time.time()
141 |         if (epoch + 1) == 320:  # checkpoint matching the released pretrained models
142 |             torch.save(model.state_dict(), model_save_path + '_' + str(epoch + 1) + '_model_param.pt')
143 |         elif (epoch + 1) == N_EPOCHS:  # always keep the final epoch as well
144 |             torch.save(model.state_dict(), model_save_path + '_' + str(epoch + 1) + '_model_param.pt')
145 | 
--------------------------------------------------------------------------------
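For developers, `train.py` reads the training and validation mRNA FASTA files, encodes each sequence into paired amino-acid and codon token tensors via `DNA_encode_input`, and pretrains `CodonBERT` through the `PretrainingWrapper`, checkpointing at epoch 320 and at the final epoch as `<model>_<epoch>_model_param.pt`. A sketch of an invocation (the file names are placeholders; hyperparameters such as `batch_size`, `lr_init`, and `N_EPOCHS` are set in `scripts/config_function.py`):

```bash
python train.py \
    -t data/train_data/train_mRNA.fasta \
    -v data/train_data/valid_mRNA.fasta \
    -m models/my_custom_model
```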