├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── configs ├── mol_prob_transformer.yml ├── mol_transformer.yml ├── rna_prob_transformer.yml ├── rna_transformer.yml ├── ssd_prob_transformer.yml └── ssd_transformer.yml ├── data ├── .keep ├── TS0.plk └── rna_data.plk.xz ├── environment.yml ├── infer_rna_folding.py ├── prob_transformer ├── __init__.py ├── data │ ├── __init__.py │ ├── dummy_handler.py │ ├── iterator.py │ ├── mol_handler.py │ ├── rna_handler.py │ └── ssd_handler.py ├── evaluation │ ├── __init__.py │ ├── cnn_head │ │ ├── __init__.py │ │ ├── infere_transformer.py │ │ └── train_cnn_from_inference.py │ ├── eval_transformer.py │ ├── metrics │ │ ├── __init__.py │ │ ├── fpscores.pkl.gz │ │ ├── sascore.py │ │ └── toy_task_survey.py │ └── statistics_center.py ├── model │ ├── __init__.py │ └── probtransformer.py ├── module │ ├── __init__.py │ ├── attention.py │ ├── embedding.py │ ├── feed_forward.py │ ├── geco_criterion.py │ ├── mat_head.py │ ├── optim_builder.py │ ├── probabilistic_forward.py │ ├── probformer_block.py │ └── probformer_stack.py ├── routine │ ├── __init__.py │ ├── evaluation.py │ └── training.py ├── train_transformer.py └── utils │ ├── __init__.py │ ├── config_init.py │ ├── handler │ ├── __init__.py │ ├── base_handler.py │ ├── checkpoint.py │ ├── config.py │ └── folder.py │ ├── logger.py │ ├── summary.py │ ├── supporter.py │ └── torch_utils.py ├── probtransformer.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Experiment Folders 2 | thirdparty_algorithm/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | venv/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | .venv 91 | .myenv 92 | venv/ 93 | myenv/ 94 | ENV/ 95 | venv* 96 | myenv* 97 | .venv* 98 | .myenv* 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | 113 | # default data folder 114 | data_babi/ 115 | data_cnn/ 116 | data_tmp/ 117 | 118 | # pycharm 119 | .idea/ 120 | .run/ 121 | 122 | # folder 123 | .experiments/* 124 | .experiment/* 125 | .tmp/* 126 | .data/* 127 | data/* 128 | 129 | # pickle 130 | *.tgz 131 | *.de 132 | /.data/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2022 Joerg K.H. Franke 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Probabilistic Transformer
2 | ### *Modelling Ambiguities and Distributions for RNA Folding and Molecule Design*
3 | ____
4 | 
5 | This repository contains the source code for the NeurIPS 2022 paper
6 | *Probabilistic Transformer: Modelling Ambiguities and Distributions for RNA Folding and Molecule Design*.
7 | 
8 | [Paper on arXiv](https://arxiv.org/abs/2205.13927)
9 | 
10 | ## Structure of the repository
11 | 
12 | ##### *configs*
13 | Contains the configuration files for the experiments on the Synthetic Sequential Distribution, RNA
14 | folding, and molecule design tasks reported in the paper.
15 | 
16 | ##### *data*
17 | Contains training, validation, and test data for the RNA folding and molecule design tasks.
18 | We use the processed Guacamol dataset from https://github.com/devalab/molgpt and created the RNA folding data based on the description in the paper.
19 | 
20 | ##### *prob_transformer*
21 | Contains the source code of the ProbTransformer, the data handlers, and the training script `train_transformer.py`.
22 | The training script runs out of the box on a downscaled config and creates an *experiments* folder in the base directory.
23 | 
24 | 
25 | ## Install conda environment
26 | 
27 | Please adjust the CUDA toolkit version in the `environment.yml` file to fit your setup.
28 | ```
29 | conda env create -n ptenv -f environment.yml
30 | conda activate ptenv
31 | pip install -e .
32 | ```
33 | 
34 | ## Prepare data
35 | 
36 | #### RNA data
37 | ```
38 | tar -xf data/rna_data.plk.xz -C data/
39 | ```
40 | 
41 | #### Molecule data
42 | 
43 | Download the [Guacamol dataset](https://drive.google.com/file/d/1gOSoKyGoYVdxtvy5cH2GNVDpLibk0lkS/view?usp=sharing) and extract it into `data`.
44 | ```
45 | unzip data/guacamol2.csv.zip -d data/
46 | ```
47 | 
48 | ## Model Checkpoints
49 | 
50 | Checkpoints for the ProbTransformer and the CNN head for RNA folding are available at:
51 | ```
52 | https://ml.informatik.uni-freiburg.de/research-artifacts/probtransformer/prob_transformer_final.pth
53 | https://ml.informatik.uni-freiburg.de/research-artifacts/probtransformer/cnn_head_final.pth
54 | ```
55 | 
56 | 
57 | ## Use the ProbTransformer for RNA folding
58 | 
59 | Please use the **infer_rna_folding.py** script to fold a sequence of nucleotides (ACGU).
60 | ```
61 | python infer_rna_folding.py -s ACGUCCUGUGCGAGCAUGCAUGC
62 | ```
63 | 
64 | To evaluate the uploaded model checkpoints on the test set TS0, use the **evaluate** flag.
65 | ``` 66 | python infer_rna_folding.py -e 67 | ``` 68 | 69 | ## Train a ProbTransformer/Transformer model 70 | ##### on the Synthetic Sequential Distribution Task 71 | ``` 72 | python prob_transformer/train_transformer.py -c configs/ssd_prob_transformer.yml 73 | python prob_transformer/train_transformer.py -c configs/ssd_transformer.yml 74 | ``` 75 | ##### on the RNA folding Task 76 | ``` 77 | python prob_transformer/train_transformer.py -c configs/rna_prob_transformer.yml 78 | python prob_transformer/train_transformer.py -c configs/rna_transformer.yml 79 | ``` 80 | ##### on the Molecule Design Task 81 | ``` 82 | python prob_transformer/train_transformer.py -c configs/mol_prob_transformer.yml 83 | python prob_transformer/train_transformer.py -c configs/mol_transformer.yml 84 | ``` 85 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/__init__.py -------------------------------------------------------------------------------- /configs/mol_prob_transformer.yml: -------------------------------------------------------------------------------- 1 | expt: 2 | experiment_name: prob_transformer 3 | save_model: true 4 | data: 5 | batch_size: 8000 6 | mol: 7 | block_size: 100 8 | data_dir: data/guacamol2.csv 9 | gen_size: 10000 10 | max_length: 100 11 | min_length: 10 12 | props: 13 | - tpsa 14 | - logp 15 | - sas 16 | seed: 7752 17 | seed: 7752 18 | type: mol 19 | geco_criterion: 20 | kappa: 0.1 21 | kappa_adaption: true 22 | lagmul_rate: 0.01 23 | ma_decay: 0.95 24 | model: 25 | dropout: 0.1 26 | ff_factor: 4 27 | max_len: 100 28 | model_dim: 256 29 | model_type: prob_decoder 30 | n_layers: 8 31 | num_head: 8 32 | prob_layer: middle 33 | z_factor: 0.25 34 | zero_init: true 35 | optim: 36 | beta1: 0.9 37 | beta2: 0.98 38 | clip_grad: 1000 39 | lr_high: 0.0005 40 | lr_low: 5.0e-05 41 | optimizer: adamW 42 | scheduler: cosine 43 | warmup_epochs: 1 44 | weight_decay: 0.01 45 | train: 46 | amp: true 47 | epochs: 60 48 | grad_scale: 65536.0 49 | iter_per_epoch: 5000 50 | n_sampling: 10 51 | save_freq: 10 52 | eval_freq: 10 53 | seed: 7752 54 | -------------------------------------------------------------------------------- /configs/mol_transformer.yml: -------------------------------------------------------------------------------- 1 | expt: 2 | experiment_name: transformer 3 | save_model: true 4 | data: 5 | batch_size: 8000 6 | mol: 7 | block_size: 100 8 | data_dir: data/guacamol2.csv 9 | gen_size: 10000 10 | max_length: 100 11 | min_length: 10 12 | props: 13 | - tpsa 14 | - logp 15 | - sas 16 | seed: 7752 17 | seed: 7752 18 | type: mol 19 | geco_criterion: 20 | kappa: 1 21 | kappa_adaption: false 22 | lagmul_rate: 0.01 23 | ma_decay: 0.95 24 | model: 25 | dropout: 0.1 26 | ff_factor: 4 27 | max_len: 100 28 | model_dim: 256 29 | model_type: decoder 30 | n_layers: 8 31 | num_head: 8 32 | prob_layer: middle 33 | z_factor: 0 34 | zero_init: true 35 | optim: 36 | beta1: 0.9 37 | beta2: 0.98 38 | clip_grad: 1000 39 | lr_high: 0.0005 40 | lr_low: 5.0e-05 41 | optimizer: adamW 42 | scheduler: cosine 43 | warmup_epochs: 1 44 | weight_decay: 0.01 45 | train: 46 | amp: true 47 | epochs: 60 48 | grad_scale: 65536.0 49 | iter_per_epoch: 5000 50 | n_sampling: 10 51 | save_freq: 10 52 | eval_freq: 10 53 | seed: 7752 54 | 
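The configs above are plain YAML with the sections `expt`, `data`, `model`, `optim`, `train`, and `geco_criterion`. The repository loads them through its own `ConfigHandler` (see `prob_transformer/utils/handler/config.py`, used in `infer_rna_folding.py`); as a hedged sketch only, the same attribute-style access (e.g. `cfg.model.model_dim`) can be approximated with PyYAML, which is pinned in `environment.yml`. The `load_config` helper below is illustrative and not part of the repository.

```
from types import SimpleNamespace
import yaml  # pinned as pyyaml==5.4.1 in environment.yml

def load_config(path):
    """Parse a YAML experiment config into nested SimpleNamespace objects."""
    with open(path) as f:
        raw = yaml.safe_load(f)

    def to_namespace(node):
        if isinstance(node, dict):
            return SimpleNamespace(**{k: to_namespace(v) for k, v in node.items()})
        return node  # scalars and lists (e.g. data.mol.props) stay as-is

    return to_namespace(raw)

# Usage, assuming the config file above is on disk:
# cfg = load_config("configs/mol_transformer.yml")
# cfg.model.model_type  -> "decoder"
# cfg.optim.lr_high     -> 0.0005
```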
-------------------------------------------------------------------------------- /configs/rna_prob_transformer.yml: -------------------------------------------------------------------------------- 1 | expt: 2 | experiment_name: prob_transformer 3 | save_model: true 4 | data: 5 | batch_size: 4000 6 | rna: 7 | df_path: data/rna_data.plk 8 | df_set_name: train 9 | max_length: 500 10 | min_length: 20 11 | similarity: 80 12 | seed: 5874 13 | type: rna 14 | geco_criterion: 15 | kappa: 0.1 16 | kappa_adaption: true 17 | lagmul_rate: 0.1 18 | ma_decay: 0.95 19 | model: 20 | dropout: 0.1 21 | ff_factor: 4 22 | max_len: 500 23 | model_dim: 512 24 | model_type: prob_encoder 25 | n_layers: 6 26 | num_head: 8 27 | prob_layer: middle 28 | z_factor: 1.0 29 | zero_init: true 30 | optim: 31 | beta1: 0.9 32 | beta2: 0.98 33 | clip_grad: 100 34 | lr_high: 0.0005 35 | lr_low: 5.0e-05 36 | optimizer: adamW 37 | scheduler: cosine 38 | warmup_epochs: 1 39 | weight_decay: 0.01 40 | train: 41 | amp: false 42 | epochs: 100 43 | grad_scale: 65536.0 44 | iter_per_epoch: 10000 45 | n_sampling: 10 46 | save_freq: 10 47 | eval_freq: 10 48 | seed: 5874 49 | -------------------------------------------------------------------------------- /configs/rna_transformer.yml: -------------------------------------------------------------------------------- 1 | expt: 2 | experiment_name: transformer 3 | save_model: true 4 | data: 5 | batch_size: 4000 6 | rna: 7 | df_path: data/rna_data.plk 8 | df_set_name: train 9 | max_length: 500 10 | min_length: 20 11 | similarity: 80 12 | seed: 5874 13 | type: rna 14 | geco_criterion: 15 | kappa: 0.1 16 | kappa_adaption: true 17 | lagmul_rate: 0.1 18 | ma_decay: 0.95 19 | model: 20 | dropout: 0.1 21 | ff_factor: 4 22 | max_len: 500 23 | model_dim: 512 24 | model_type: encoder 25 | n_layers: 6 26 | num_head: 8 27 | prob_layer: middle 28 | z_factor: 1.0 29 | zero_init: true 30 | optim: 31 | beta1: 0.9 32 | beta2: 0.98 33 | clip_grad: 100 34 | lr_high: 0.0005 35 | lr_low: 5.0e-05 36 | optimizer: adamW 37 | scheduler: cosine 38 | warmup_epochs: 1 39 | weight_decay: 0.01 40 | train: 41 | amp: false 42 | epochs: 100 43 | grad_scale: 65536.0 44 | iter_per_epoch: 10000 45 | n_sampling: 10 46 | save_freq: 10 47 | eval_freq: 10 48 | seed: 5874 49 | -------------------------------------------------------------------------------- /configs/ssd_prob_transformer.yml: -------------------------------------------------------------------------------- 1 | expt: 2 | experiment_name: prob_transformer 3 | save_model: true 4 | data: 5 | batch_size: 6000 6 | seed: 5193 7 | ssd: 8 | max_len: 90 9 | min_len: 15 10 | n_eval: 50 11 | n_sentence: 1000 12 | sample_amount: 100000 13 | seed: 100 14 | sentence_len: 3 15 | sentence_variations: 10 16 | src_vocab_size: 500 17 | trg_vocab_size: 500 18 | type: ssd 19 | geco_criterion: 20 | kappa: 0.1 21 | kappa_adaption: true 22 | lagmul_rate: 0.1 23 | ma_decay: 0.98 24 | model: 25 | dropout: 0.1 26 | ff_factor: 4 27 | max_len: 200 28 | model_dim: 256 29 | model_type: prob_encoder 30 | n_layers: 4 31 | num_head: 4 32 | prob_layer: all 33 | z_factor: 1.0 34 | zero_init: true 35 | optim: 36 | beta1: 0.9 37 | beta2: 0.98 38 | clip_grad: 100 39 | lr_high: 0.001 40 | lr_low: 0.0001 41 | optimizer: adamW 42 | scheduler: cosine 43 | warmup_epochs: 1 44 | weight_decay: 0.01 45 | train: 46 | amp: true 47 | epochs: 50 48 | grad_scale: 65536.0 49 | iter_per_epoch: 2000 50 | n_sampling: 10 51 | save_freq: 10 52 | eval_freq: 1 53 | seed: 5193 54 | 
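The `geco_criterion` section in these configs (`kappa`, `kappa_adaption`, `lagmul_rate`, `ma_decay`) parameterizes a GECO-style constrained objective implemented in `prob_transformer/module/geco_criterion.py`. The sketch below only illustrates the general GECO mechanism under stated assumptions — minimize the KL term subject to a reconstruction constraint with threshold `kappa`, track the constraint with a moving average decayed by `ma_decay`, and adapt a Lagrange multiplier with rate `lagmul_rate`. The repository's actual implementation (including the `kappa_adaption` option) may differ.

```
import torch

class GECOStyleCriterion:
    """Illustrative only: minimize KL subject to E[reconstruction] <= kappa."""

    def __init__(self, kappa=0.1, lagmul_rate=0.1, ma_decay=0.95):
        self.kappa = kappa               # constraint threshold on the reconstruction loss
        self.lagmul_rate = lagmul_rate   # step size of the multiplier update
        self.ma_decay = ma_decay         # decay of the constraint moving average
        self.lagmul = torch.tensor(1.0)  # Lagrange multiplier, kept positive
        self.constraint_ma = None

    def __call__(self, rec_loss, kl_loss):
        # constraint C = reconstruction - kappa, tracked as a moving average
        constraint = rec_loss.detach() - self.kappa
        if self.constraint_ma is None:
            self.constraint_ma = constraint
        else:
            self.constraint_ma = (self.ma_decay * self.constraint_ma
                                  + (1.0 - self.ma_decay) * constraint)
        # multiplicative update: lambda grows while the constraint is violated
        self.lagmul = self.lagmul * torch.exp(self.lagmul_rate * self.constraint_ma)
        # Lagrangian of the constrained problem
        return kl_loss + self.lagmul * (rec_loss - self.kappa)
```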
-------------------------------------------------------------------------------- /configs/ssd_transformer.yml: -------------------------------------------------------------------------------- 1 | expt: 2 | experiment_name: transformer 3 | save_model: true 4 | data: 5 | batch_size: 6000 6 | seed: 5193 7 | ssd: 8 | max_len: 90 9 | min_len: 15 10 | n_eval: 50 11 | n_sentence: 1000 12 | sample_amount: 100000 13 | seed: 100 14 | sentence_len: 3 15 | sentence_variations: 10 16 | src_vocab_size: 500 17 | trg_vocab_size: 500 18 | type: ssd 19 | geco_criterion: 20 | kappa: 0.1 21 | kappa_adaption: true 22 | lagmul_rate: 0.1 23 | ma_decay: 0.95 24 | model: 25 | dropout: 0.1 26 | ff_factor: 4 27 | max_len: 200 28 | model_dim: 256 29 | model_type: encoder 30 | n_layers: 4 31 | num_head: 4 32 | prob_layer: all 33 | z_factor: 1.0 34 | zero_init: true 35 | optim: 36 | beta1: 0.9 37 | beta2: 0.98 38 | clip_grad: 100 39 | lr_high: 0.001 40 | lr_low: 0.0001 41 | optimizer: adamW 42 | scheduler: cosine 43 | warmup_epochs: 1 44 | weight_decay: 0.01 45 | train: 46 | amp: true 47 | epochs: 50 48 | grad_scale: 65536.0 49 | iter_per_epoch: 2000 50 | n_sampling: 10 51 | save_freq: 10 52 | eval_freq: 10 53 | seed: 5193 54 | -------------------------------------------------------------------------------- /data/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/data/.keep -------------------------------------------------------------------------------- /data/TS0.plk: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/data/TS0.plk -------------------------------------------------------------------------------- /data/rna_data.plk.xz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/data/rna_data.plk.xz -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - rdkit 4 | - conda-forge 5 | dependencies: 6 | - cudatoolkit=11.6 7 | - python=3.8 8 | - pytorch=1.12.1 9 | - molsets=0.3.1 10 | - rdkit=2022.03.5 11 | - pandas=1.5.0 12 | - pip: 13 | - distance==0.1.3 14 | - pyyaml==5.4.1 15 | - tqdm==4.62.3 16 | - wget==3.2 17 | 18 | -------------------------------------------------------------------------------- /infer_rna_folding.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | import wget 3 | from tqdm import tqdm 4 | from collections import defaultdict 5 | import numpy as np 6 | import torch 7 | import distance 8 | 9 | from prob_transformer.utils.config_init import cinit 10 | from prob_transformer.utils.handler.config import ConfigHandler 11 | from prob_transformer.model.probtransformer import ProbTransformer 12 | from prob_transformer.data.rna_handler import RNAHandler 13 | from prob_transformer.data.iterator import MyIterator 14 | from prob_transformer.evaluation.statistics_center import StatisticsCenter 15 | from prob_transformer.routine.evaluation import is_valid_structure,correct_invalid_structure, struct_to_mat 16 | 17 | 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser(description='Using the 
ProbTransformer to fold an RNA sequence.')
21 |     parser.add_argument('-s', '--sequence', type=str, help='An RNA sequence as ACGU-string')
22 |     parser.add_argument('-m', '--model', default="checkpoints/prob_transformer_final.pth", type=str,
23 |                         help='A checkpoint file for the model to use')
24 |     parser.add_argument('-c', '--cnn_head', default="checkpoints/cnn_head_final.pth", type=str,
25 |                         help='A checkpoint file for the CNN head to use')
26 |     parser.add_argument('-e', '--evaluate', action='store_true', help='Evaluates the model on the test set TS0')
27 |     parser.add_argument('-d', '--rna_data', default="data/rna_data.plk", type=str, help='Path to the RNA dataframe')
28 |     parser.add_argument('-t', '--test_data', default="data/TS0.plk", type=str, help='Path to the test dataframe TS0')
29 |     parser.add_argument('-r', '--rank', default="cuda", type=str, help='Device to run inference on, cuda or cpu')
30 |     args = parser.parse_args()
31 | 
32 |     if args.cnn_head == "checkpoints/cnn_head_final.pth" and not os.path.exists("checkpoints/cnn_head_final.pth"):
33 |         os.makedirs("checkpoints", exist_ok=True)
34 |         print("Download CNN head checkpoint")
35 |         wget.download("https://ml.informatik.uni-freiburg.de/research-artifacts/probtransformer/cnn_head_final.pth", "checkpoints/cnn_head_final.pth")
36 | 
37 |     if args.model == "checkpoints/prob_transformer_final.pth" and not os.path.exists("checkpoints/prob_transformer_final.pth"):
38 |         os.makedirs("checkpoints", exist_ok=True)
39 |         print("Download prob transformer checkpoint")
40 |         wget.download("https://ml.informatik.uni-freiburg.de/research-artifacts/probtransformer/prob_transformer_final.pth", "checkpoints/prob_transformer_final.pth")
41 | 
42 | 
43 |     transformer_checkpoint = torch.load(args.model, map_location=torch.device(args.rank))
44 | 
45 |     cfg = ConfigHandler(config_dict=transformer_checkpoint['config'])
46 | 
47 |     rna_data = cinit(RNAHandler, cfg.data.rna.dict, df_path=args.rna_data, sub_set='valid', prob_training=True,
48 |                      device=args.rank, seed=cfg.data.seed, ignore_index=-1, similarity='80', exclude=[], max_length=500)
49 | 
50 |     seq_vocab_size = rna_data.seq_vocab_size
51 |     trg_vocab_size = rna_data.struct_vocab_size
52 | 
53 |     model = cinit(ProbTransformer, cfg.model.dict, seq_vocab_size=seq_vocab_size, trg_vocab_size=trg_vocab_size,
54 |                   mat_config=None, mat_head=False, mat_input=False, prob_ff=False,
55 |                   scaffold=False, props=False).to(args.rank)
56 |     model.load_state_dict(transformer_checkpoint['state_dict'], strict=False)
57 |     model.eval()
58 | 
59 |     cnn_head = torch.load(args.cnn_head, map_location=torch.device(args.rank))
60 |     cnn_head.eval()
61 | 
62 |     if args.sequence is not None:
63 |         print(f"Fold input sequence {args.sequence}")
64 |         if not set(args.sequence).issubset({'A', 'C', 'G', 'U'}):  # reject symbols outside the ACGU alphabet
65 |             raise UserWarning(f"unknown symbols in sequence: {set(args.sequence).difference('A', 'C', 'G', 'U')}. 
Please only use ACGU") 66 | 67 | src_seq = torch.LongTensor([[rna_data.seq_stoi[s] for s in args.sequence]]).to(args.rank) 68 | src_len = torch.LongTensor([src_seq.shape[1]]).to(args.rank) 69 | 70 | raw_output, raw_latent = model(src_seq, src_len, infer_mean=True, output_latent=True) 71 | 72 | pred_dist = torch.nn.functional.one_hot(torch.argmax(raw_output, dim=-1), 73 | num_classes=raw_output.shape[-1]).to(torch.float).detach() 74 | pred_token = torch.argmax(raw_output, dim=-1).detach() 75 | 76 | b_pred_mat, mask = cnn_head(latent=raw_latent, src=src_seq, pred=pred_token, src_len=src_len) 77 | 78 | pred_dist = pred_dist[0, :, :].detach().cpu() 79 | pred_argmax = torch.argmax(pred_dist, keepdim=False, dim=-1).numpy().tolist() 80 | pred_struct = [rna_data.struct_itos[i] for i in pred_argmax] 81 | print("Predicted structure without CNN head:", pred_struct) 82 | if not is_valid_structure(pred_struct): 83 | pred_struct = correct_invalid_structure(pred_struct, pred_dist, rna_data.struct_stoi, src_seq.shape[1]) 84 | print("correction pred_struct", pred_struct) 85 | 86 | pred_mat = torch.sigmoid(b_pred_mat[0, :, :, 1]) 87 | pred_mat = torch.triu(pred_mat, diagonal=1).t() + torch.triu(pred_mat, diagonal=1) 88 | bindings_idx = np.where(pred_mat.cpu().detach().numpy() > 0.5) 89 | print("Predicted binding from CNN head, open :", bindings_idx[0].tolist()) 90 | print("Predicted binding from CNN head, close:", bindings_idx[1].tolist()) 91 | 92 | if args.evaluate: 93 | print("Evaluate on TS0") 94 | test_data = cinit(RNAHandler, cfg.data.rna.dict, df_path=args.test_data, sub_set='test', prob_training=True, 95 | device=args.rank, seed=cfg.data.seed, ignore_index=-1, similarity='80', exclude=[], max_length=500) 96 | 97 | data_iter = MyIterator(data_handler=test_data, batch_size=cfg.data.batch_size, repeat=False, shuffle=False, 98 | batching=False, pre_sort_samples=False, 99 | device=args.rank, seed=cfg.data.seed, ignore_index=-1) 100 | 101 | samples_count = 0 102 | metrics = defaultdict(list) 103 | evaluations = [] 104 | 105 | with torch.inference_mode(): 106 | for i, batch in tqdm(enumerate(data_iter)): 107 | raw_output, raw_latent = model(batch.src_seq, batch.src_len, infer_mean=True, output_latent=True) 108 | 109 | pred_dist = torch.nn.functional.one_hot(torch.argmax(raw_output, dim=-1), 110 | num_classes=raw_output.shape[-1]).to(torch.float).detach() 111 | pred_token = torch.argmax(raw_output, dim=-1).detach() 112 | 113 | b_pred_mat, mask = cnn_head(latent=raw_latent, src=batch.src_seq, pred=pred_token, 114 | src_len=batch.src_len) 115 | 116 | for b, length in enumerate(batch.src_len.detach().cpu().numpy()): 117 | 118 | samples_count += 1 119 | sequence = [test_data.seq_itos[i] for i in 120 | batch.src_seq[b, :length].detach().cpu().numpy()] 121 | 122 | pred_struct = pred_dist[b, :length, :].detach().cpu() 123 | true_struct = batch.trg_seq[b, :length].detach().cpu().numpy() 124 | 125 | pred_argmax = torch.argmax(pred_struct, keepdim=False, dim=-1).numpy() 126 | np_remove_pos = np.where(true_struct == -1)[0] 127 | np_pred_struct = np.delete(pred_argmax, np_remove_pos).tolist() 128 | np_true_struct = np.delete(true_struct, np_remove_pos).tolist() 129 | hamming = distance.hamming(np_pred_struct, np_true_struct) 130 | 131 | word_errors = [r != h for (r, h) in zip(pred_argmax, true_struct)] 132 | word_error_rate = sum(word_errors) / len(true_struct) 133 | 134 | metrics[f"hamming_distance"].append(hamming) 135 | metrics[f"word_error_rate"].append(word_error_rate) 136 | 137 | true_mat = batch.trg_pair_mat[b] 
138 | true_mat = true_mat.detach().cpu().numpy() 139 | true_mat = true_mat[:length, :length] 140 | 141 | vocab = {k: i for i, k in enumerate(test_data.struct_vocab)} 142 | db_struct = [test_data.struct_itos[i] for i in pred_argmax] 143 | if is_valid_structure(db_struct): 144 | seq_pred_mat, pairs = struct_to_mat(db_struct) 145 | else: 146 | correct_struct = correct_invalid_structure(db_struct, pred_dist[b, :length, :], 147 | test_data.struct_stoi, length) 148 | if is_valid_structure(correct_struct): 149 | seq_pred_mat, pairs = struct_to_mat(correct_struct) 150 | else: 151 | seq_pred_mat = np.zeros_like(true_mat) 152 | 153 | metrics[f"seq_solved"].append(np.all(np.equal(true_mat, seq_pred_mat > 0.5)).astype(int)) 154 | pred_mat = torch.sigmoid(b_pred_mat[b, :length, :length, 1]) 155 | pred_mat = torch.triu(pred_mat, diagonal=1).t() + torch.triu(pred_mat, diagonal=1) 156 | sample = {"true": true_mat, "pred": pred_mat.detach().cpu().numpy(), "sequence": sequence, } 157 | evaluations.append(sample) 158 | 159 | metrics = {k: np.mean(v) for k, v in metrics.items()} 160 | stats = StatisticsCenter(evaluations, step_size=0.1, triangle_loss=False) 161 | metric = stats.eval_threshold(0.5) 162 | metrics.update({k: v for k, v in metric.items()}) 163 | metrics['samples'] = samples_count 164 | 165 | for key, value in metrics.items(): 166 | print(f"Evaluate {args.test_data} {key:20}: {value}") 167 | 168 | 169 | -------------------------------------------------------------------------------- /prob_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/__init__.py -------------------------------------------------------------------------------- /prob_transformer/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/data/__init__.py -------------------------------------------------------------------------------- /prob_transformer/data/dummy_handler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DummyHandler(): 5 | def __init__(self, 6 | data_samples, 7 | max_length, 8 | max_hamming=None, 9 | device='cpu'): 10 | 11 | data_samples = [sample for sample in data_samples if sample["length"] < max_length] 12 | if max_hamming != None: 13 | data_samples = [sample for sample in data_samples if sample["hamming"] < max_hamming] 14 | 15 | self.samples = data_samples 16 | self.max_length = max_length 17 | self.max_hamming = max_hamming 18 | 19 | self.set_size = len(data_samples) 20 | 21 | self.device = device 22 | 23 | self.ignore_index = -1 24 | 25 | def get_sample_by_index(self, index_iter): 26 | for index in index_iter: 27 | sample = self.samples[index] 28 | 29 | sample = self.prepare_sample(sample, self.max_length) 30 | 31 | yield sample 32 | 33 | def batch_sort_key(self, sample): 34 | return sample['length'] 35 | 36 | def pool_sort_key(self, sample): 37 | return sample['length'] 38 | 39 | def prepare_sample(self, input_sample, max_length=None): 40 | 41 | if 'sequence' in input_sample: 42 | del input_sample['sequence'] 43 | if 'hamming' in input_sample: 44 | del input_sample['hamming'] 45 | input_sample['length'] = torch.tensor([input_sample['length']]) 46 | 47 | return input_sample 48 | 
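`DummyHandler` above implicitly defines the interface that `MyIterator` (next file) expects from every data handler: the iterator asserts a `batch_sort_key` method, a `pool_sort_key` method, a `get_sample_by_index` generator, and a `set_size` attribute. A minimal, hypothetical handler satisfying this contract might look as follows; `MinimalHandler` is an illustration, not part of the repository.

```
import torch

class MinimalHandler:
    """Hypothetical handler meeting the contract asserted by MyIterator."""

    def __init__(self, token_tensors):
        self.samples = [{"src_seq": t, "src_len": torch.tensor(t.shape[0])}
                        for t in token_tensors]
        self.set_size = len(self.samples)       # required attribute

    def get_sample_by_index(self, index_iter):  # required generator over sample dicts
        for index in index_iter:
            yield self.samples[index]

    def batch_sort_key(self, sample):           # size used for the token batch budget
        return sample["src_len"].item()

    def pool_sort_key(self, sample):            # sort key within a pre-sort pool
        return sample["src_len"].item()

# Usage with the iterator defined below:
# iterator = MyIterator(data_handler=MinimalHandler([torch.ones(12, dtype=torch.long)]),
#                       batch_size=32, repeat=False)
```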
-------------------------------------------------------------------------------- /prob_transformer/data/iterator.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple 2 | from types import SimpleNamespace 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class MyIterator(): 8 | def __init__(self, 9 | data_handler, 10 | batch_size, 11 | pool_size=20, 12 | pre_sort_samples=False, 13 | device='cpu', 14 | repeat=False, 15 | shuffle=True, 16 | batching=True, 17 | seed=1, 18 | ignore_index=-1, 19 | pad_index=0, 20 | ): 21 | 22 | self.repeat = repeat 23 | self.shuffle = shuffle 24 | self.batching = batching 25 | 26 | assert callable(getattr(data_handler, "batch_sort_key", None)), "data handler has no 'batch_sort_key' method" 27 | assert callable(getattr(data_handler, "pool_sort_key", None)), "data handler has no 'pool_sort_key' method" 28 | assert callable( 29 | getattr(data_handler, "get_sample_by_index", None)), "data handler has no 'get_sample_by_index' method" 30 | assert hasattr(data_handler, "set_size"), "data handler has no 'set_size' attribute" 31 | self.data_handler = data_handler 32 | 33 | self.batch_size = batch_size # batchsize in cumulated sequence length in batch 34 | self.pool_size = pool_size 35 | self.pre_sort_samples = pre_sort_samples 36 | self.device = device 37 | 38 | self.ignore_index = ignore_index 39 | self.pad_index = pad_index 40 | 41 | self.rng = np.random.default_rng(seed=seed) 42 | 43 | self.set_size = self.data_handler.set_size 44 | 45 | def get_index_list(self): 46 | index_list = [i for i in range(self.set_size)] 47 | if self.shuffle: 48 | self.rng.shuffle(index_list) 49 | return index_list 50 | 51 | def get_index_iter(self): 52 | while True: 53 | index_list = self.get_index_list() 54 | for i in index_list: 55 | yield i 56 | 57 | def cluster_index_iter(self, cluster_index_list): 58 | while True: 59 | for i in cluster_index_list: 60 | yield i 61 | 62 | def pool_and_sort(self, sample_iter): 63 | 64 | pool = [] 65 | 66 | for sample in sample_iter: 67 | if not self.pre_sort_samples: 68 | yield sample 69 | else: 70 | pool.append(sample) 71 | if len(pool) >= self.pool_size: 72 | pool.sort(key=self.data_handler.pool_sort_key) 73 | 74 | while len(pool) > 0: 75 | yield pool.pop() 76 | if len(pool) > 0: 77 | pool.sort(key=self.data_handler.pool_sort_key) 78 | while len(pool) > 0: 79 | yield pool.pop() 80 | 81 | def __iter__(self): 82 | 83 | minibatch, max_size_in_batch = [], 0 84 | 85 | while True: 86 | 87 | if self.repeat: 88 | index_iter = self.get_index_iter() 89 | else: 90 | index_iter = self.get_index_list() 91 | 92 | for sample in self.pool_and_sort(self.data_handler.get_sample_by_index(index_iter)): 93 | 94 | if self.batching: 95 | minibatch.append(sample) 96 | max_size_in_batch = max(max_size_in_batch, self.data_handler.batch_sort_key(sample)) 97 | size_so_far = len(minibatch) * max(max_size_in_batch, self.data_handler.batch_sort_key(sample)) 98 | if size_so_far == self.batch_size: 99 | yield self.batch_samples(minibatch) 100 | minibatch, max_size_in_batch = [], 0 101 | if size_so_far > self.batch_size: 102 | yield self.batch_samples(minibatch[:-1]) 103 | minibatch = minibatch[-1:] 104 | max_size_in_batch = self.data_handler.batch_sort_key(minibatch[0]) 105 | else: 106 | yield self.batch_samples([sample]) 107 | 108 | if not self.repeat: 109 | if self.batching and len(minibatch) > 0: 110 | yield self.batch_samples(minibatch) 111 | return 112 | 113 | def batch_samples(self, sample_dict_minibatch: 
List[Dict]): 114 | 115 | with torch.no_grad(): 116 | batch_dict = {k: [dic[k] for dic in sample_dict_minibatch] for k in sample_dict_minibatch[0]} 117 | 118 | for key, tensor_list in batch_dict.items(): 119 | 120 | max_shape = [list(i.shape) for i in tensor_list] 121 | if len(tensor_list[0].shape) == 0: 122 | max_shape = [len(tensor_list)] 123 | else: 124 | max_shape = [len(tensor_list)] + [max([s[l] for s in max_shape]) for l in range(len(max_shape[0]))] 125 | 126 | if tensor_list[0].dtype == torch.float64 or tensor_list[0].dtype == torch.float32 or tensor_list[ 127 | 0].dtype == torch.float16: 128 | max_tensor = torch.zeros(size=max_shape, dtype=tensor_list[0].dtype, device=self.device) 129 | 130 | elif tensor_list[0].dtype == torch.int64 or tensor_list[0].dtype == torch.int32 or tensor_list[ 131 | 0].dtype == torch.int16: 132 | if "trg_seq" == key or "trg_msa" == key: 133 | max_tensor = torch.ones(size=max_shape, dtype=tensor_list[0].dtype, 134 | device=self.device) * self.ignore_index 135 | else: 136 | max_tensor = torch.ones(size=max_shape, dtype=tensor_list[0].dtype, 137 | device=self.device) * self.pad_index 138 | else: 139 | raise UserWarning(f"key {key} has an unsupported dtype: {tensor_list[0].dtype}") 140 | 141 | for b, tensor in enumerate(tensor_list): 142 | ts = tensor.shape 143 | if len(tensor.shape) == 0: 144 | max_tensor[b] = tensor.to(self.device) 145 | elif len(tensor.shape) == 1: 146 | max_tensor[b, :ts[0]] = tensor.to(self.device) 147 | elif len(tensor.shape) == 2: 148 | max_tensor[b, :ts[0], :ts[1]] = tensor.to(self.device) 149 | elif len(tensor.shape) == 3: 150 | max_tensor[b, :ts[0], :ts[1], :ts[2]] = tensor.to(self.device) 151 | elif len(tensor.shape) == 4: 152 | max_tensor[b, :ts[0], :ts[1], :ts[2], :ts[3]] = tensor.to(self.device) 153 | else: 154 | raise UserWarning(f"key {key} has an unsupported dimension: {tensor_list[0].shape}") 155 | 156 | batch_dict[key] = max_tensor 157 | 158 | batch = SimpleNamespace(**batch_dict) 159 | 160 | return batch 161 | -------------------------------------------------------------------------------- /prob_transformer/data/mol_handler.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple 2 | import math 3 | from torch.utils.data import Dataset 4 | import re 5 | import pandas as pd 6 | 7 | import random 8 | import numpy as np 9 | import torch 10 | from torch.nn import functional as F 11 | from moses.utils import get_mol 12 | from rdkit import Chem 13 | 14 | 15 | def set_seed(seed): 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | 21 | 22 | def top_k_logits(logits, k): 23 | v, ix = torch.topk(logits, k) 24 | out = logits.clone() 25 | out[out < v[:, [-1]]] = -float('Inf') 26 | return out 27 | 28 | 29 | @torch.no_grad() 30 | def sample_batch(model, x, block_size, temperature=1.0, sample=False, top_k=None, props=None, ignore_index=0): 31 | """ 32 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 33 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 34 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 35 | of block_size, unlike an RNN that has an infinite context window. 
36 | """ 37 | model.eval() 38 | 39 | steps = block_size 40 | 41 | if model.probabilistic: 42 | sample = False 43 | 44 | for k in range(steps): 45 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 46 | 47 | trg_len = torch.tensor([x_cond.shape[1]] * x_cond.shape[0], device=x.device) 48 | 49 | logits = model(trg_shf_seq=x_cond, trg_len=trg_len, props=props) # for liggpt 50 | 51 | logits = logits[:, -1, :] / temperature 52 | # optionally crop probabilities to only the top k options 53 | if top_k is not None: 54 | logits = top_k_logits(logits, top_k) 55 | # apply softmax to convert to probabilities 56 | probs = F.softmax(logits, dim=-1) 57 | # sample from the distribution or take the most likely 58 | if sample: 59 | ix = torch.multinomial(probs, num_samples=1) 60 | else: 61 | _, ix = torch.topk(probs, k=1, dim=-1) 62 | # append to the sequence and continue 63 | x = torch.cat((x, ix), dim=1) 64 | 65 | return x 66 | 67 | 68 | def check_novelty(gen_smiles, train_smiles): # gen: say 788, train: 120803 69 | if len(gen_smiles) == 0: 70 | novel_ratio = 0. 71 | else: 72 | duplicates = [1 for mol in gen_smiles if mol in train_smiles] # [1]*45 73 | novel = len(gen_smiles) - sum(duplicates) # 788-45=743 74 | novel_ratio = novel * 100. / len(gen_smiles) # 743*100/788=94.289 75 | print("novelty: {:.3f}%".format(novel_ratio)) 76 | return novel_ratio 77 | 78 | 79 | def canonic_smiles(smiles_or_mol): 80 | mol = get_mol(smiles_or_mol) 81 | if mol is None: 82 | return None 83 | return Chem.MolToSmiles(mol) 84 | 85 | 86 | class SmilesEnumerator(object): 87 | """SMILES Enumerator, vectorizer and devectorizer 88 | 89 | #Arguments 90 | charset: string containing the characters for the vectorization 91 | can also be generated via the .fit() method 92 | pad: Length of the vectorization 93 | leftpad: Add spaces to the left of the SMILES 94 | isomericSmiles: Generate SMILES containing information about stereogenic centers 95 | enum: Enumerate the SMILES during transform 96 | canonical: use canonical SMILES during transform (overrides enum) 97 | """ 98 | 99 | def __init__(self, charset='@C)(=cOn1S2/H[N]\\', pad=120, leftpad=True, isomericSmiles=True, enum=True, 100 | canonical=False): 101 | self._charset = None 102 | self.charset = charset 103 | self.pad = pad 104 | self.leftpad = leftpad 105 | self.isomericSmiles = isomericSmiles 106 | self.enumerate = enum 107 | self.canonical = canonical 108 | 109 | @property 110 | def charset(self): 111 | return self._charset 112 | 113 | @charset.setter 114 | def charset(self, charset): 115 | self._charset = charset 116 | self._charlen = len(charset) 117 | self._char_to_int = dict((c, i) for i, c in enumerate(charset)) 118 | self._int_to_char = dict((i, c) for i, c in enumerate(charset)) 119 | 120 | def fit(self, smiles, extra_chars=[], extra_pad=5): 121 | """Performs extraction of the charset and length of a SMILES datasets and sets self.pad and self.charset 122 | 123 | #Arguments 124 | smiles: Numpy array or Pandas series containing smiles as strings 125 | extra_chars: List of extra chars to add to the charset (e.g. 
"\\\\" when "/" is present) 126 | extra_pad: Extra padding to add before or after the SMILES vectorization 127 | """ 128 | charset = set("".join(list(smiles))) 129 | self.charset = "".join(charset.union(set(extra_chars))) 130 | self.pad = max([len(smile) for smile in smiles]) + extra_pad 131 | 132 | def randomize_smiles(self, smiles): 133 | """Perform a randomization of a SMILES string 134 | must be RDKit sanitizable""" 135 | m = Chem.MolFromSmiles(smiles) 136 | ans = list(range(m.GetNumAtoms())) 137 | np.random.shuffle(ans) 138 | nm = Chem.RenumberAtoms(m, ans) 139 | return Chem.MolToSmiles(nm, canonical=self.canonical, isomericSmiles=self.isomericSmiles) 140 | 141 | def transform(self, smiles): 142 | """Perform an enumeration (randomization) and vectorization of a Numpy array of smiles strings 143 | #Arguments 144 | smiles: Numpy array or Pandas series containing smiles as strings 145 | """ 146 | one_hot = np.zeros((smiles.shape[0], self.pad, self._charlen), dtype=np.int8) 147 | 148 | if self.leftpad: 149 | for i, ss in enumerate(smiles): 150 | if self.enumerate: ss = self.randomize_smiles(ss) 151 | l = len(ss) 152 | diff = self.pad - l 153 | for j, c in enumerate(ss): 154 | one_hot[i, j + diff, self._char_to_int[c]] = 1 155 | return one_hot 156 | else: 157 | for i, ss in enumerate(smiles): 158 | if self.enumerate: ss = self.randomize_smiles(ss) 159 | for j, c in enumerate(ss): 160 | one_hot[i, j, self._char_to_int[c]] = 1 161 | return one_hot 162 | 163 | def reverse_transform(self, vect): 164 | """ Performs a conversion of a vectorized SMILES to a smiles strings 165 | charset must be the same as used for vectorization. 166 | #Arguments 167 | vect: Numpy array of vectorized SMILES. 168 | """ 169 | smiles = [] 170 | for v in vect: 171 | v = v[v.sum(axis=1) == 1] 172 | # Find one hot encoded index with argmax, translate to char and join to string 173 | smile = "".join(self._int_to_char[i] for i in v.argmax(axis=1)) 174 | smiles.append(smile) 175 | return np.array(smiles) 176 | 177 | 178 | class SmileDataset(Dataset): 179 | 180 | def __init__(self, data, content, block_size, aug_prob=0, prop=None, device='cpu'): 181 | chars = sorted(list(set(content))) 182 | data_size, vocab_size = len(data), len(chars) 183 | print('data has %d smiles, %d unique characters.' 
% (data_size, vocab_size)) 184 | 185 | self.stoi = {ch: i for i, ch in enumerate(chars)} 186 | self.itos = {i: ch for i, ch in enumerate(chars)} 187 | self.max_len = block_size 188 | self.vocab_size = vocab_size 189 | self.data = data 190 | self.prop = prop 191 | self.debug = False 192 | self.tfm = SmilesEnumerator() 193 | self.aug_prob = aug_prob 194 | self.device = device 195 | 196 | def __len__(self): 197 | if self.debug: 198 | return math.ceil(len(self.data) / (self.max_len + 1)) 199 | else: 200 | return len(self.data) 201 | 202 | def __getitem__(self, idx): 203 | smiles = self.data[idx] 204 | 205 | if self.prop: 206 | prop = self.prop[idx] 207 | 208 | smiles = smiles.strip() 209 | 210 | p = np.random.uniform() 211 | if p < self.aug_prob: 212 | smiles = self.tfm.randomize_smiles(smiles) 213 | 214 | pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 215 | regex = re.compile(pattern) 216 | 217 | smiles += str('<') * (self.max_len - len(regex.findall(smiles))) 218 | if len(regex.findall(smiles)) > self.max_len: 219 | smiles = smiles[:self.max_len] 220 | 221 | smiles = regex.findall(smiles) 222 | dix = [self.stoi[s] for s in smiles] 223 | 224 | x = torch.tensor(dix[:-1], dtype=torch.long) 225 | y = torch.tensor(dix[1:], dtype=torch.long) 226 | if self.prop: 227 | prop = torch.tensor([prop], dtype=torch.float) 228 | else: 229 | prop = torch.tensor([0], dtype=torch.float) 230 | torch_length = torch.LongTensor([x.shape[0]])[0] 231 | 232 | return { 233 | "trg_shf_seq": x.to(self.device), 234 | "trg_seq": y.to(self.device), 235 | "post_seq": y.to(self.device), 236 | "trg_len": torch_length.to(self.device), 237 | "props": prop.to(self.device), 238 | "src_len": torch_length.to(self.device), 239 | } 240 | 241 | 242 | class MolHandler(): 243 | def __init__( 244 | self, 245 | data_dir: str = "mol_data", 246 | split: str = "valid", 247 | props: List = [], 248 | min_length=10, 249 | max_length=100, 250 | seed: int = 1, 251 | device='cpu', 252 | ): 253 | self.rng = np.random.RandomState(seed) 254 | 255 | self.max_length = max_length 256 | self.props = props 257 | self.split = split 258 | self.device = device 259 | 260 | data = pd.read_csv(data_dir) 261 | 262 | data.columns = data.columns.str.lower() 263 | data = data[data['smiles'].apply(lambda x: min_length <= len(x) <= max_length)] 264 | data = data.dropna(axis=0).reset_index(drop=True) 265 | 266 | self.data = data 267 | 268 | if split == "generate": 269 | set_data = data[data['source'] != 'test'].reset_index(drop=True) 270 | elif split == "train": 271 | set_data = data[data['source'] == 'train'].reset_index(drop=True) 272 | elif split == "valid": 273 | set_data = data[data['source'] == 'val'].reset_index(drop=True) 274 | elif split == "test": 275 | set_data = data[data['source'] == 'test'].reset_index(drop=True) 276 | 277 | smiles = set_data['smiles'] 278 | 279 | if props: 280 | prop = set_data[props].values.tolist() 281 | self.num_props = len(props) 282 | else: 283 | prop = [] 284 | self.num_props = False 285 | 286 | self.content = ' '.join(smiles) 287 | self.context = "C" 288 | 289 | pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" 290 | regex = re.compile(pattern) 291 | 292 | lens = [len(regex.findall(i.strip())) for i in (list(smiles.values))] 293 | max_len = max(lens) 294 | 295 | whole_string = ['#', '%10', '%11', '%12', '(', ')', '-', '1', '2', '3', '4', '5', '6', '7', '8', '9', '<', '=', 296 | 'B', 297 | 
'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', '[B-]', '[BH-]', '[BH2-]', '[BH3-]', '[B]', 298 | '[C+]', 299 | '[C-]', '[CH+]', '[CH-]', '[CH2+]', '[CH2]', '[CH]', '[F+]', '[H]', '[I+]', '[IH2]', '[IH]', 300 | '[N+]', 301 | '[N-]', '[NH+]', '[NH-]', '[NH2+]', '[NH3+]', '[N]', '[O+]', '[O-]', '[OH+]', '[O]', '[P+]', 302 | '[PH+]', '[PH2+]', '[PH]', '[S+]', '[S-]', '[SH+]', '[SH]', '[Se+]', '[SeH+]', '[SeH]', '[Se]', 303 | '[Si-]', '[SiH-]', '[SiH2]', '[SiH]', '[Si]', '[b-]', '[bH-]', '[c+]', '[c-]', '[cH+]', '[cH-]', 304 | '[n+]', '[n-]', '[nH+]', '[nH]', '[o+]', '[s+]', '[sH+]', '[se+]', '[se]', 'b', 'c', 'n', 'o', 305 | 'p', 's'] 306 | 307 | self.stoi = {ch: i for i, ch in enumerate(whole_string)} 308 | self.itos = {i: ch for ch, i in self.stoi.items()} 309 | 310 | self.dataset = SmileDataset(smiles, whole_string, max_len, prop=prop, aug_prob=0, device=self.device) 311 | 312 | self.set_size = set_data.shape[0] 313 | self.ignore_index = self.stoi['<'] 314 | 315 | self.vocab_size = len(whole_string) 316 | self.source_vocab_size = len(whole_string) 317 | self.target_vocab_size = len(whole_string) 318 | 319 | def __len__(self): 320 | """Allows to call len(this_dataset).""" 321 | return self.set_size 322 | 323 | def __getitem__(self, idx): 324 | """Allows to access samples with bracket notation""" 325 | return self.dataset.__getitem__(idx) 326 | 327 | def get_sample_by_index(self, index_iter): 328 | for index in index_iter: 329 | sample = self.dataset.__getitem__(index) 330 | yield sample 331 | 332 | def batch_sort_key(self, sample): 333 | return sample['trg_len'].detach() 334 | 335 | def pool_sort_key(self, sample): 336 | return sample['trg_len'].item() 337 | -------------------------------------------------------------------------------- /prob_transformer/data/rna_handler.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import torch 4 | import pandas as pd 5 | 6 | from pathlib import Path 7 | 8 | 9 | class RNAHandler(): 10 | def __init__(self, 11 | df_path, 12 | sub_set, 13 | ignore_index, 14 | seed, 15 | min_length, 16 | max_length, 17 | similarity=80, 18 | device='cpu', 19 | ): 20 | 21 | assert sub_set in ['train', 'valid', 'test'] 22 | 23 | df_path = Path(df_path) 24 | 25 | if not df_path.is_file(): 26 | raise UserWarning(f"no dataframe found on: {df_path.resolve().__str__()}") 27 | 28 | df = pd.read_pickle(df_path) 29 | 30 | df = df[df[f"non_sim_{similarity}"]] 31 | df = df[df['set'].str.contains(sub_set)] 32 | 33 | df = df[df['structure'].apply(set).apply(len) > 1] # remove only '.' 
samples, should be removed already 34 | 35 | self.max_length = max_length 36 | 37 | df = df[df['sequence'].apply(lambda x: min_length <= len(x) <= max_length)] 38 | 39 | df = df.reset_index() 40 | 41 | self.datasettoint = {k: v for k, v in zip(df['dataset'].unique(), range(len(df['dataset'].unique())))} 42 | self.inttodataset = {v: k for k, v in self.datasettoint.items()} 43 | 44 | self.df = df 45 | 46 | self.set_size = self.df.shape[0] 47 | 48 | self.rng = np.random.default_rng(seed=seed) 49 | self.device = device 50 | 51 | self.ignore_index = ignore_index 52 | 53 | self.seq_vocab = ['A', 'C', 'G', 'U', 'N'] 54 | self.canonical_pairs = ['GC', 'CG', 'AU', 'UA', 'GU', 'UG'] 55 | 56 | nucs = { 57 | 'T': 'U', 58 | 'P': 'U', 59 | 'R': 'A', # or 'G' 60 | 'Y': 'C', # or 'T' 61 | 'M': 'C', # or 'A' 62 | 'K': 'U', # or 'G' 63 | 'S': 'C', # or 'G' 64 | 'W': 'U', # or 'A' 65 | 'H': 'C', # or 'A' or 'U' 66 | 'B': 'U', # or 'G' or 'C' 67 | 'V': 'C', # or 'G' or 'A' 68 | 'D': 'A', # or 'G' or 'U' 69 | } 70 | 71 | self.struct_vocab = ['.', '(0c', ')0c', '(1c', ')1c', '(2c', ')2c', '(0nc', ')0nc', '(1nc', ')1nc', '(2nc', 72 | ')2nc'] 73 | 74 | self.seq_stoi = dict(zip(self.seq_vocab, range(len(self.seq_vocab)))) 75 | self.seq_itos = dict((y, x) for x, y in self.seq_stoi.items()) 76 | 77 | for nuc, mapping in nucs.items(): 78 | self.seq_stoi[nuc] = self.seq_stoi[mapping] 79 | 80 | self.struct_itos = dict(zip(range(len(self.struct_vocab)), self.struct_vocab)) 81 | self.struct_stoi = dict((y, x) for x, y in self.struct_itos.items()) 82 | 83 | self.seq_vocab_size = len(self.seq_vocab) 84 | self.struct_vocab_size = len(self.struct_vocab) 85 | 86 | def get_sample_by_index(self, index_iter): 87 | for index in index_iter: 88 | sample = self.df.iloc[index] 89 | 90 | sample = self.prepare_sample(sample, self.max_length) 91 | 92 | yield sample 93 | 94 | def batch_sort_key(self, sample): 95 | return sample['src_len'].detach().tolist() 96 | 97 | def pool_sort_key(self, sample): 98 | return sample['src_len'].item() 99 | 100 | def sequence2index_matrix(self, sequence, mapping): 101 | 102 | int_sequence = list(map(mapping.get, sequence)) 103 | 104 | if self.device == 'cpu': 105 | tensor = torch.LongTensor(int_sequence) 106 | else: 107 | tensor = torch.cuda.LongTensor(int_sequence, device=self.device) 108 | return tensor 109 | 110 | def prepare_sample(self, input_sample, max_length=None): 111 | 112 | sequence = input_sample["sequence"] 113 | structure = input_sample["structure"] 114 | pos1id = input_sample["pos1id"] 115 | pos2id = input_sample["pos2id"] 116 | if 'is_pdb' in input_sample: 117 | pdb_sample = int(input_sample['is_pdb']) 118 | else: 119 | pdb_sample = 0 120 | dataset = self.datasettoint[input_sample['dataset']] 121 | 122 | length = len(sequence) 123 | 124 | with torch.no_grad(): 125 | pair_m, pair_mat = self.get_pair_matrices(pos1id, pos2id, length, pdb_sample) 126 | 127 | target_structure = self.encode_target_structure(pair_mat, structure, sequence) 128 | 129 | src_seq = self.sequence2index_matrix(sequence, self.seq_stoi) 130 | trg_seq = self.sequence2index_matrix(target_structure, self.struct_stoi) 131 | 132 | post_seq = trg_seq.clone() 133 | 134 | post_pair_m = pair_m.clone() 135 | post_pair_mat = pair_mat.clone() 136 | 137 | trg_pair_m = pair_m.clone() 138 | trg_pair_mat = pair_mat.clone() 139 | 140 | if pdb_sample == 0: 141 | trg_pair_m.fill_(self.ignore_index) 142 | 143 | if self.device == 'cpu': 144 | torch_length = torch.LongTensor([length])[0] 145 | torch_pdb_sample = 
torch.LongTensor([pdb_sample])[0] 146 | torch_dataset = torch.LongTensor([dataset])[0] 147 | else: 148 | torch_length = torch.cuda.LongTensor([length], device=self.device)[0] 149 | torch_pdb_sample = torch.cuda.LongTensor([pdb_sample], device=self.device)[0] 150 | torch_dataset = torch.cuda.LongTensor([dataset], device=self.device)[0] 151 | 152 | torch_sample = {} 153 | 154 | torch_sample['src_seq'] = src_seq 155 | 156 | torch_sample['src_len'] = torch_length 157 | torch_sample['trg_len'] = torch_length 158 | torch_sample['pdb_sample'] = torch_pdb_sample 159 | torch_sample['dataset'] = torch_dataset 160 | 161 | torch_sample['trg_seq'] = trg_seq 162 | torch_sample['trg_pair_m'] = trg_pair_m 163 | torch_sample['trg_pair_mat'] = trg_pair_mat 164 | 165 | torch_sample['post_seq'] = post_seq 166 | torch_sample['post_pair_m'] = post_pair_m 167 | torch_sample['post_pair_mat'] = post_pair_mat 168 | 169 | return torch_sample 170 | 171 | def encode_target_structure(self, pair_mat, raw_structure, sequence): 172 | 173 | pos1, pos2 = torch.where(pair_mat == 1) 174 | pos = torch.concat([pos1, pos2]).unique().cpu().numpy() 175 | 176 | pos_dict = {i1.item(): i2.item() for i1, i2 in zip(pos1, pos2)} 177 | 178 | structure = [] 179 | for s_idx, s in enumerate(raw_structure): 180 | if len(s) > 1: 181 | if s[1] != '0' and s[1] != '1': 182 | s = s[0] + '2' 183 | 184 | if s != '.': 185 | if s[0] == "(": 186 | counter_idx = pos_dict[s_idx] 187 | assert raw_structure[counter_idx][0] == ')' 188 | pair = sequence[s_idx] + sequence[counter_idx] 189 | if pair in self.canonical_pairs: 190 | s = s + "c" 191 | else: 192 | s = s + "nc" 193 | elif s[0] == ")": 194 | counter_idx = pos_dict[s_idx] 195 | assert raw_structure[counter_idx][0] == '(' 196 | pair = sequence[counter_idx] + sequence[s_idx] 197 | if pair in self.canonical_pairs: 198 | s = s + "c" 199 | else: 200 | s = s + "nc" 201 | else: 202 | raise UserWarning("unknown ()") 203 | structure.append(s) 204 | structure = np.asarray(structure) 205 | 206 | 207 | 208 | 209 | assert "." 
not in structure[pos] 210 | 211 | return structure.tolist() 212 | 213 | def get_pair_matrices(self, pos1id, pos2id, length, pdb_sample): 214 | 215 | assert len(pos1id) == len(pos2id) 216 | 217 | if self.device == 'cpu': 218 | multi_mat = torch.LongTensor(length, length).fill_(0) 219 | pair_mat = torch.LongTensor(length, length).fill_(0) 220 | else: 221 | multi_mat = torch.cuda.LongTensor(length, length, device=self.device).fill_(0) 222 | pair_mat = torch.cuda.LongTensor(length, length, device=self.device).fill_(0) 223 | 224 | if pdb_sample == 1: 225 | pos_count = collections.Counter(pos1id + pos2id) 226 | multiplets = [pos for pos, count in pos_count.items() if count > 1] 227 | else: 228 | multiplets = [] 229 | 230 | for p1, p2 in zip(pos1id, pos2id): 231 | 232 | pair_mat[p1, p2] = 1 233 | pair_mat[p2, p1] = 1 234 | 235 | if len(multiplets) > 0: 236 | if p1 in multiplets or p2 in multiplets: 237 | multi_mat[p1, p2] = 1 238 | multi_mat[p2, p1] = 1 239 | 240 | return multi_mat, pair_mat 241 | -------------------------------------------------------------------------------- /prob_transformer/data/ssd_handler.py: -------------------------------------------------------------------------------- 1 | import string, os 2 | from types import SimpleNamespace 3 | from collections import OrderedDict 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class SSDHandler(): 9 | 10 | def __init__(self, max_len, min_len, sample_amount, trg_vocab_size, src_vocab_size, sentence_len, n_sentence, 11 | sentence_variations, 12 | seed=123, device='cpu', pre_src_vocab=None, pre_trg_vocab=None, token_dict=None): 13 | 14 | super().__init__() 15 | self.rng = np.random.RandomState(seed=seed) 16 | 17 | self.set_size = sample_amount 18 | 19 | self.blank_word = '' 20 | 21 | self.sentence_len = sentence_len 22 | self.sentence_variations = sentence_variations 23 | self.min_len = min_len 24 | self.max_len = max_len 25 | 26 | self.device = device 27 | 28 | if pre_src_vocab == None: 29 | self.pre_src_vocab = self._make_vocab(src_vocab_size) 30 | self.pre_trg_vocab = self._make_vocab(trg_vocab_size, numbers=True) 31 | self.token_dict = self._make_token_dict(sentence_len, n_sentence, sentence_variations) 32 | else: 33 | self.pre_src_vocab = pre_src_vocab 34 | self.pre_trg_vocab = pre_trg_vocab 35 | self.token_dict = token_dict 36 | 37 | self.source_stoi = dict(zip(self.pre_src_vocab, range(len(self.pre_src_vocab)))) 38 | self.source_itos = dict((y, x) for x, y in self.source_stoi.items()) 39 | 40 | self.target_itos = dict(zip(range(len(self.pre_trg_vocab)), self.pre_trg_vocab)) 41 | self.target_stoi = dict((y, x) for x, y in self.target_itos.items()) 42 | 43 | self.source_vocab_size = len(self.source_stoi) 44 | self.target_vocab_size = len(self.target_stoi) 45 | 46 | os.makedirs("cache", exist_ok=True) 47 | file_name = f"cache/{min_len}_{max_len}_{sentence_variations}_{sentence_len}_{sample_amount}_{seed}.tlist" 48 | 49 | if not os.path.isfile(file_name): 50 | self.sample_list = self._generate_data(size=sample_amount) 51 | torch.save(self.sample_list, file_name) 52 | else: 53 | self.sample_list = torch.load(file_name) 54 | 55 | self.trg_vocab = SimpleNamespace(**{"id_to_token": lambda i: self.target_itos[i], 56 | "token_to_id": lambda i: self.target_stoi[i], 57 | "size": self.target_vocab_size}) 58 | 59 | def get_sample_by_index(self, index_iter): 60 | for index in index_iter: 61 | sample = self.sample_list[index] 62 | 63 | yield sample 64 | 65 | def batch_sort_key(self, sample): 66 | return sample['src_len'].detach().tolist() 67 
| 68 | def pool_sort_key(self, sample): 69 | return sample['src_len'].item() 70 | 71 | @staticmethod 72 | def _softmax(x): 73 | e_x = np.exp(x - np.max(x)) 74 | return e_x / e_x.sum() 75 | 76 | def _make_vocab(self, vocab_size, numbers=False): 77 | vocab = [] 78 | word_size = 5 79 | 80 | while len(vocab) < vocab_size: 81 | if numbers: 82 | word = "".join([str(s) for s in self.rng.randint(0, 10, word_size)]) 83 | else: 84 | word = "".join(self.rng.choice([s for s in string.ascii_letters], word_size, replace=True)) 85 | if word not in vocab: 86 | vocab.append(word) 87 | 88 | if numbers: 89 | pass 90 | else: 91 | vocab.append(self.blank_word) 92 | 93 | return vocab 94 | 95 | def _make_token_dict(self, sentence_len, n_sentence, sentence_variations): 96 | 97 | token_dict = OrderedDict() 98 | for _ in range(int(min([n_sentence, len(self.pre_src_vocab) ** sentence_len, 99 | len(self.pre_trg_vocab) ** sentence_len]))): 100 | 101 | # create new src token 102 | src_list = self.rng.choice(self.pre_src_vocab, sentence_len, replace=True).tolist() 103 | src_token = "-".join(src_list) 104 | while src_token in token_dict.keys(): 105 | src_list = self.rng.choice(self.pre_src_vocab, sentence_len, replace=True).tolist() 106 | src_token = "-".join(src_list) 107 | 108 | # create trg 109 | trg_options_list = [] 110 | trg_dist_list = [] 111 | 112 | sample_sentence_len = sentence_len 113 | trg_choice = self.pre_trg_vocab.__len__() 114 | 115 | for sub in range(sample_sentence_len): 116 | rand_options = self.rng.randint(1, 1 + sentence_variations) 117 | trg_options = self.rng.choice(trg_choice, rand_options, replace=False).tolist() 118 | trg_dist = self._softmax(self.rng.uniform(0, 2, rand_options)) 119 | trg_options_list.append(trg_options) 120 | trg_dist_list.append(trg_dist) 121 | 122 | token_dict[src_token] = {"src_list": src_list, "trg_options_list": trg_options_list, 123 | "trg_dist_list": trg_dist_list} 124 | return token_dict 125 | 126 | def _generate_data(self, size): 127 | 128 | data_set = [] 129 | for idx in range(int(size)): 130 | if self.min_len == self.max_len: 131 | length = self.min_len 132 | else: 133 | length = self.rng.randint(self.min_len, self.max_len + 1) 134 | length = (length // self.sentence_len) * self.sentence_len 135 | source, target = self.make_sample(length) 136 | 137 | if self.device == 'cpu': 138 | torch_length = torch.LongTensor([length])[0] 139 | else: 140 | torch_length = torch.cuda.LongTensor([length], device=self.device)[0] 141 | 142 | torch_sample = {} 143 | src_seq = self.sequence2index_matrix(source, self.source_stoi) 144 | torch_sample['src_seq'] = src_seq 145 | torch_sample['src_len'] = torch_length 146 | 147 | trg_seq = self.sequence2index_matrix(target, self.target_stoi) 148 | 149 | torch_sample['trg_seq'] = trg_seq 150 | torch_sample['trg_len'] = torch_length 151 | torch_sample['post_seq'] = trg_seq 152 | 153 | data_set.append(torch_sample) 154 | 155 | return data_set 156 | 157 | def sequence2index_matrix(self, sequence, mapping): 158 | 159 | int_sequence = list(map(mapping.get, sequence)) 160 | 161 | if self.device == 'cpu': 162 | tensor = torch.LongTensor(int_sequence) 163 | else: 164 | tensor = torch.cuda.LongTensor(int_sequence, device=self.device) 165 | return tensor 166 | 167 | def make_sample(self, n_steps): 168 | 169 | n_sub_token = int(n_steps / self.sentence_len) 170 | source, target = [], [] 171 | 172 | for _ in range(n_sub_token): 173 | src_sub_token = self.rng.choice(list(self.token_dict.keys())) 174 | 175 | sub_token = self.token_dict[src_sub_token] 176 | 
source.extend(sub_token['src_list']) 177 | 178 | for trg_options, trg_dist in zip(sub_token['trg_options_list'], sub_token['trg_dist_list']): 179 | trg_idx = self.rng.choice(trg_options, 1, p=trg_dist)[0] 180 | trg_symbol = self.pre_trg_vocab[trg_idx] 181 | target.append(trg_symbol) 182 | 183 | return source, target 184 | 185 | def get_valid_data(self): 186 | 187 | source = [] 188 | for sub_token in self.token_dict.values(): 189 | source.extend(sub_token['src_list']) 190 | 191 | target_dist = self.get_sample_dist(source) 192 | 193 | return source, target_dist 194 | 195 | def get_sample_dist(self, src): 196 | 197 | if isinstance(src[0], int) or isinstance(src[0], np.int32) or isinstance(src[0], np.int64): 198 | src = [self.source_itos[s] for s in src] 199 | 200 | elif isinstance(src[0], torch.Tensor): 201 | src = src.detach().cpu().tolist() 202 | src = [self.source_itos[s] for s in src] 203 | else: 204 | raise UserWarning(f"unknown source type: {type(src[0])}") 205 | 206 | target_dist = [] 207 | for idx in range(len(src) // self.sentence_len): 208 | sub = src[idx * self.sentence_len: idx * self.sentence_len + self.sentence_len] 209 | sub_name = '-'.join(sub) 210 | if sub_name != '-'.join([self.blank_word] * self.sentence_len): 211 | sub_token = self.token_dict[sub_name] 212 | for tidx, (trg_options, trg_dist) in enumerate( 213 | zip(sub_token['trg_options_list'], sub_token['trg_dist_list'])): 214 | t_dist = np.zeros([self.target_vocab_size]) 215 | for trg_pos, t_prob in zip(trg_options, trg_dist): 216 | trg_symb = self.pre_trg_vocab[trg_pos] 217 | trg_pos = self.target_stoi[trg_symb] 218 | t_dist[trg_pos] = t_prob 219 | target_dist.append(t_dist) 220 | target_dist = np.stack(target_dist, axis=0) 221 | 222 | return target_dist 223 | 224 | def get_batch_dist(self, src_batch): 225 | trg_dist_list = [] 226 | for b in range(src_batch.size()[0]): 227 | src = src_batch[b, :] 228 | trg_dist = self.get_sample_dist(src) 229 | trg_dist_list.append(trg_dist) 230 | 231 | trg_dist_len = [d.shape[0] for d in trg_dist_list] 232 | if np.unique(trg_dist_len).shape[0] == 1: 233 | trg_dist = np.stack(trg_dist_list, axis=0) 234 | else: 235 | trg_dist = np.zeros([len(trg_dist_list), max(trg_dist_len), self.target_vocab_size]) 236 | for b, trg_d in enumerate(trg_dist_list): 237 | trg_dist[b, :trg_d.shape[0], :] = trg_d 238 | 239 | trg_dist = torch.FloatTensor(trg_dist).to(src_batch.device) 240 | return trg_dist 241 | -------------------------------------------------------------------------------- /prob_transformer/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/evaluation/__init__.py -------------------------------------------------------------------------------- /prob_transformer/evaluation/cnn_head/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/evaluation/cnn_head/__init__.py -------------------------------------------------------------------------------- /prob_transformer/evaluation/cnn_head/infere_transformer.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import torch 4 | import distance 5 | 6 | from prob_transformer.utils.config_init import cinit 7 | from 
prob_transformer.utils.handler.config import ConfigHandler 8 | from prob_transformer.model.probtransformer import ProbTransformer 9 | from prob_transformer.data.rna_handler import RNAHandler 10 | from prob_transformer.data.iterator import MyIterator 11 | 12 | 13 | def infer_rna_transformer(checkpoint): 14 | rank = 0 if torch.cuda.is_available() else "cpu" 15 | cfg = ConfigHandler(config_dict=checkpoint['config']) 16 | 17 | ignore_index = -1 18 | train_data = cinit(RNAHandler, cfg.data.rna.dict, df_path='data/rna_data.plk', 19 | sub_set='train', prob_training=True, device=rank, seed=cfg.data.seed, ignore_index=ignore_index, 20 | similarity='80', 21 | exclude=[], max_length=500) 22 | valid_data = cinit(RNAHandler, cfg.data.rna.dict, df_path='data/rna_data.plk', 23 | sub_set='valid', prob_training=True, device=rank, seed=cfg.data.seed, ignore_index=ignore_index, 24 | similarity='80', 25 | exclude=[], max_length=500) 26 | test_data = cinit(RNAHandler, cfg.data.rna.dict, df_path='data/rna_data.plk', 27 | sub_set='test', prob_training=True, device=rank, seed=cfg.data.seed, ignore_index=ignore_index, 28 | similarity='80', 29 | exclude=[], max_length=500) 30 | 31 | seq_vocab_size = valid_data.seq_vocab_size 32 | trg_vocab_size = valid_data.struct_vocab_size 33 | 34 | train_iter = MyIterator(data_handler=train_data, batch_size=cfg.data.batch_size, repeat=False, shuffle=True, 35 | batching=True, pre_sort_samples=True, 36 | device=rank, seed=cfg.data.seed, ignore_index=ignore_index) 37 | 38 | valid_iter = MyIterator(data_handler=valid_data, batch_size=cfg.data.batch_size, repeat=False, shuffle=False, 39 | batching=True, pre_sort_samples=False, 40 | device=rank, seed=cfg.data.seed, ignore_index=ignore_index) 41 | 42 | test_iter = MyIterator(data_handler=test_data, batch_size=cfg.data.batch_size, repeat=False, shuffle=False, 43 | batching=False, pre_sort_samples=False, 44 | device=rank, seed=cfg.data.seed, ignore_index=ignore_index) 45 | 46 | model = cinit(ProbTransformer, cfg.model.dict, seq_vocab_size=seq_vocab_size, trg_vocab_size=trg_vocab_size, 47 | mat_config=None, mat_head=False, mat_input=False, prob_ff=False, 48 | scaffold=False, props=False).to(rank) 49 | 50 | model.load_state_dict(checkpoint['state_dict'], strict=False) 51 | 52 | model.eval() 53 | 54 | data_iter = {'train': train_iter, 'valid': valid_iter, 'test': test_iter} 55 | 56 | for d_set, d_iter in data_iter.items(): 57 | raw_output_dump = [] 58 | with torch.inference_mode(): 59 | for i, batch in tqdm(enumerate(d_iter)): 60 | 61 | model.eval() 62 | 63 | with torch.no_grad(): 64 | raw_output, raw_latent = model(batch.src_seq, batch.src_len, infer_mean=True, output_latent=True) 65 | 66 | trg_dist = torch.nn.functional.one_hot(torch.argmax(raw_output, dim=-1), 67 | num_classes=raw_output.shape[-1]).to(torch.float).detach() 68 | trg_token = torch.argmax(raw_output, dim=-1).detach() 69 | 70 | for b, length in enumerate(batch.src_len.detach().cpu().numpy()): 71 | sequence = [d_iter.data_handler.seq_itos[i] for i in 72 | batch.src_seq[b, :length].detach().cpu().numpy()] 73 | 74 | pred_struct = trg_dist[b, :length, :].detach().cpu() 75 | true_struct = batch.trg_seq[b, :length].detach().cpu() 76 | 77 | pred_argmax = torch.argmax(pred_struct, keepdim=False, dim=-1) 78 | np_remove_pos = np.where(true_struct.numpy() == -1)[0] 79 | np_pred_struct = np.delete(pred_argmax.numpy(), np_remove_pos).tolist() 80 | np_true_struct = np.delete(true_struct.numpy(), np_remove_pos).tolist() 81 | hamming = distance.hamming(np_pred_struct, np_true_struct) 82 | 
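# Note on the metric: positions whose ground-truth label equals the ignore
# index (-1) are deleted from both the prediction and the target before the
# comparison, so the hamming value above counts mismatching structure tokens
# over valid positions only; a value of 0 means the predicted structure
# sequence matches the ground truth exactly.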
83 | raw_output_dump.append({ 84 | "raw_output": raw_output[b, :length, :].detach().cpu(), 85 | "raw_latent": raw_latent[b, :length, :].detach().cpu(), 86 | "pred_struct": trg_token[b, :length].detach().cpu(), 87 | "true_struct": true_struct, 88 | "true_mat": batch.trg_pair_mat[b, :length, :length].detach().cpu(), 89 | "src_seq": batch.src_seq[b, :length].detach().cpu(), 90 | "trg_token": trg_token[b, :length].detach().cpu(), 91 | "sequence": sequence, 92 | "hamming": hamming, 93 | "length": length 94 | }) 95 | 96 | torch.save(raw_output_dump, cfg.expt.experiment_dir / f"model_inference_{d_set}.pth") 97 | 98 | 99 | if __name__ == "__main__": 100 | import argparse 101 | 102 | parser = argparse.ArgumentParser(description='Evaluate the model given a checkpoint file.') 103 | parser.add_argument('-c', '--checkpoint', type=str, help='a checkpoint file') 104 | 105 | args = parser.parse_args() 106 | 107 | checkpoint = torch.load(args.checkpoint, 108 | map_location=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')) 109 | 110 | infer_rna_transformer(checkpoint=checkpoint) 111 | -------------------------------------------------------------------------------- /prob_transformer/evaluation/cnn_head/train_cnn_from_inference.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | 7 | from prob_transformer.utils.torch_utils import count_parameters 8 | from prob_transformer.utils.config_init import cinit 9 | from prob_transformer.utils.handler.config import ConfigHandler 10 | 11 | from prob_transformer.module.mat_head import SimpleMatrixHead 12 | from prob_transformer.data.rna_handler import RNAHandler 13 | from prob_transformer.data.dummy_handler import DummyHandler 14 | from prob_transformer.data.iterator import MyIterator 15 | from prob_transformer.evaluation.statistics_center import StatisticsCenter 16 | from prob_transformer.module.optim_builder import OptiMaster 17 | 18 | from prob_transformer.utils.supporter import Supporter 19 | 20 | 21 | def run_cnn_training(config, expt_dir): 22 | model_dir = pathlib.Path(config['model_dir']) 23 | model_file = config['model_file'] 24 | 25 | sup = Supporter(experiments_dir=expt_dir, config_dict=config, count_expt=True) 26 | cfg = sup.get_config() 27 | log = sup.get_logger() 28 | 29 | ckp = sup.ckp 30 | 31 | np.random.seed(cfg.seed) 32 | torch.manual_seed(cfg.seed) 33 | 34 | rank = 0 if torch.cuda.is_available() else "cpu" 35 | 36 | checkpoint = torch.load(model_dir / model_file, 37 | map_location=torch.device('cuda', rank) if rank == 0 else torch.device('cpu')) 38 | 39 | cfg_ckp = ConfigHandler(config_dict=checkpoint['config']) 40 | 41 | ignore_index = -1 42 | 43 | train_data = torch.load(model_dir / f"model_inference_train.pth") 44 | valid_data = torch.load(model_dir / f"model_inference_valid.pth") 45 | test_data = torch.load(model_dir / f"model_inference_test.pth") 46 | 47 | train_data = DummyHandler(train_data, max_length=cfg.max_train_len, device=0, max_hamming=5) 48 | valid_data = DummyHandler(valid_data, max_length=500, device=0) 49 | test_data = DummyHandler(test_data, max_length=500, device=0) 50 | 51 | rna_data = cinit(RNAHandler, cfg_ckp.data.rna.dict, df_path='data/rna_data.plk', 52 | sub_set='valid', prob_training=True, device='cpu', seed=cfg_ckp.data.seed, 53 | ignore_index=ignore_index, max_length=500, similarity='80', exclude=[]) 54 | 55 | train_iter = MyIterator(data_handler=train_data, 
batch_size=cfg.batch_size, repeat=True, shuffle=True, 56 | batching=True, pre_sort_samples=False, 57 | device=rank, seed=cfg_ckp.data.seed, ignore_index=ignore_index) 58 | 59 | valid_iter = MyIterator(data_handler=valid_data, batch_size=1000, repeat=False, shuffle=False, 60 | batching=True, pre_sort_samples=False, 61 | device=rank, seed=cfg_ckp.data.seed, ignore_index=ignore_index) 62 | 63 | test_iter = MyIterator(data_handler=test_data, batch_size=1000, repeat=False, shuffle=False, 64 | batching=True, pre_sort_samples=False, 65 | device=rank, seed=cfg_ckp.data.seed, ignore_index=ignore_index) 66 | 67 | mat_model = cinit(SimpleMatrixHead, cfg.model) 68 | mat_model = mat_model.to(rank) 69 | 70 | criterion = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=-1, label_smoothing=False).to(rank) 71 | 72 | epochs = 20 73 | 74 | opti = cinit(OptiMaster, cfg.opti, model=mat_model, 75 | epochs=epochs, iter_per_epoch=cfg.iter_per_epoch) 76 | 77 | log.log("train_set_size", train_iter.set_size) 78 | log.log("valid_set_size", valid_iter.set_size) 79 | log.log("test_set_size", test_iter.set_size) 80 | log.log("model_parameters", count_parameters(mat_model.parameters())) 81 | 82 | for e in range(epochs): 83 | 84 | log(f"Start training epoch: ", e) 85 | 86 | for i, batch in tqdm(enumerate(train_iter)): 87 | 88 | if i == cfg.iter_per_epoch: 89 | break 90 | 91 | pred_mat, mask = mat_model(latent=batch.raw_latent, src=batch.src_seq, pred=batch.pred_struct, 92 | src_len=batch.length[:, 0]) 93 | 94 | mask = mask[:, 0, :, :] 95 | trg_mat = batch.true_mat 96 | 97 | pred_mat = pred_mat * mask[:, :, :, None] 98 | trg_mat = trg_mat * mask 99 | 100 | loss_1 = criterion(pred_mat.permute(0, 3, 1, 2), trg_mat.to(torch.long)) 101 | 102 | loss = loss_1 * mask 103 | 104 | def top_k_masking(loss, src_len, k_percent): 105 | with torch.no_grad(): 106 | mask = torch.zeros_like(loss) 107 | for b in range(loss.shape[0]): 108 | k = max(2, int(src_len[b] ** 2 * (k_percent / 100))) 109 | idx = torch.topk(loss[b].view(-1), k=k)[1] 110 | mask[b].view(-1)[idx] = 1 111 | return loss * mask, mask 112 | 113 | loss, mask = top_k_masking(loss, batch.length[:, 0], k_percent=cfg.k_percent) 114 | 115 | loss = torch.sum(loss, dim=(1, 2)) / torch.sum(mask, dim=(1, 2)) 116 | loss = loss.mean() 117 | 118 | if i % 1000 == 0: 119 | log(f"loss", loss) 120 | 121 | loss.backward() 122 | 123 | opti.optimizer.step() 124 | opti.train_step() 125 | opti.optimizer.zero_grad() 126 | 127 | torch.save(mat_model, ckp.dir / f"checkpoint_{e}.pth") 128 | 129 | mat_model.eval() 130 | 131 | evaluations_mat = [] 132 | for batch in tqdm(valid_iter): 133 | 134 | with torch.no_grad(): 135 | b_pred_mat, mask = mat_model(latent=batch.raw_latent, src=batch.src_seq, pred=batch.pred_struct, 136 | src_len=batch.length[:, 0]) 137 | 138 | for b, l in enumerate(batch.length[:, 0]): 139 | pred_mat = torch.sigmoid(b_pred_mat[b, :l, :l, 1]) 140 | true_mat = batch.true_mat[b, :l, :l] 141 | 142 | pred_mat = torch.triu(pred_mat, diagonal=1).t() + torch.triu(pred_mat, diagonal=1) 143 | 144 | sequence = [rna_data.seq_itos[i] for i in 145 | batch.src_seq[b, :l].detach().cpu().numpy()] 146 | 147 | sample = {"true": true_mat.detach().cpu().numpy(), "pred": pred_mat.detach().cpu().numpy(), 148 | "sequence": sequence, } 149 | evaluations_mat.append(sample) 150 | 151 | new_stats = StatisticsCenter(evaluations_mat, step_size=0.2, triangle_loss=False) 152 | metrics, threshold = new_stats.find_best_threshold() 153 | for key, value in metrics.items(): 154 | log(f"valid_{key}", value) 155 
| 156 | log(f"### test epoch ", e) 157 | mat_model.eval() 158 | 159 | evaluations_mat = [] 160 | for batch in tqdm(test_iter): 161 | 162 | with torch.no_grad(): 163 | b_pred_mat, mask = mat_model(latent=batch.raw_latent, src=batch.src_seq, pred=batch.pred_struct, 164 | src_len=batch.length[:, 0]) 165 | 166 | for b, l in enumerate(batch.length[:, 0]): 167 | pred_mat = torch.sigmoid(b_pred_mat[b, :l, :l, 1]) 168 | true_mat = batch.true_mat[b, :l, :l] 169 | 170 | pred_mat = torch.triu(pred_mat, diagonal=1).t() + torch.triu(pred_mat, diagonal=1) 171 | 172 | sequence = [rna_data.seq_itos[i] for i in 173 | batch.src_seq[b, :l].detach().cpu().numpy()] 174 | 175 | sample = {"true": true_mat.detach().cpu().numpy(), "pred": pred_mat.detach().cpu().numpy(), 176 | "sequence": sequence, } 177 | evaluations_mat.append(sample) 178 | 179 | new_stats = StatisticsCenter(evaluations_mat, step_size=0.2, triangle_loss=False) 180 | metrics, threshold = new_stats.find_best_threshold() 181 | for key, value in metrics.items(): 182 | log(f"test_{key}", value) 183 | 184 | 185 | if __name__ == "__main__": 186 | import argparse 187 | 188 | parser = argparse.ArgumentParser(description='Evaluate the model given a checkpoint file.') 189 | parser.add_argument('-d', '--model_dir', type=str, help='the directory of a model') 190 | parser.add_argument('-f', '--model_file', type=str, help='the checkpoint file') 191 | 192 | args = parser.parse_args() 193 | 194 | model_dir = args.model_dir 195 | model_file = args.model_file 196 | 197 | config = {} 198 | 199 | config['model_dir'] = model_dir 200 | config['model_file'] = model_file 201 | config['expt'] = { 202 | "project_name": "tmp_test", 203 | "session_name": "cnn_head", 204 | "experiment_name": "test_model", 205 | "job_name": "local_run", 206 | "save_model": False, 207 | "resume_training": False, 208 | } 209 | 210 | config['seed'] = 5786 211 | config['batch_size'] = 1000 212 | config['max_train_len'] = 200 213 | config['k_percent'] = 0.4 214 | config['iter_per_epoch'] = 1000 215 | config['data'] = {'type':'rna'} 216 | 217 | config['opti'] = {"optimizer": 'adamW', # 'adam', 218 | "scheduler": 'cosine', 219 | "warmup_epochs": 1, # 1, 220 | "lr_low": 0.0001, 221 | "lr_high": 0.001, 222 | "beta1": 0.9, 223 | "beta2": 0.98, # 0.98, 224 | "weight_decay": 1e-10, # 1e-8, 225 | "factor": 1, 226 | "swa": False, 227 | "swa_start_epoch": 0, 228 | "swa_lr": 0, 229 | "plateua_metric": None} 230 | 231 | config['model'] = {"src_vocab_size": 5, 232 | "latent_dim": 512, 233 | "dropout": 0.1, 234 | "model_dim": 64, 235 | "out_channels": 2, 236 | "res_layer": 3, 237 | "kernel": 5, 238 | "max_len": 500, } 239 | 240 | run_cnn_training(config, "experiments") 241 | -------------------------------------------------------------------------------- /prob_transformer/evaluation/eval_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import socket 3 | import numpy as np 4 | import torch 5 | import random 6 | from prob_transformer.utils.handler.config import ConfigHandler 7 | from prob_transformer.utils.logger import Logger 8 | from prob_transformer.utils.summary import SummaryDict 9 | from prob_transformer.utils.config_init import cinit 10 | from prob_transformer.utils.torch_utils import count_parameters 11 | 12 | from prob_transformer.model.probtransformer import ProbTransformer 13 | from prob_transformer.data.iterator import MyIterator 14 | from prob_transformer.data.rna_handler import RNAHandler 15 | from 
prob_transformer.data.ssd_handler import SSDHandler 16 | from prob_transformer.data.mol_handler import MolHandler 17 | 18 | from prob_transformer.routine.evaluation import run_evaluation 19 | 20 | 21 | def eval_transformer(checkpoint): 22 | cfg = ConfigHandler(config_dict=checkpoint['config']) 23 | 24 | infer_seed = cfg.train.seed 25 | torch.manual_seed(infer_seed) 26 | random.seed(infer_seed) 27 | np.random.seed(infer_seed) 28 | 29 | log = Logger("experiment", file_name="eval_log_file.txt") 30 | 31 | torch.backends.cudnn.deterministic = False 32 | torch.backends.cudnn.benchmark = True 33 | rank = 0 if torch.cuda.is_available() else "cpu" 34 | 35 | log.log(f"### START EVALUATION ### at {socket.gethostname()}") 36 | 37 | ############################################################ 38 | ####### DATA ITERATOR ######## 39 | ############################################################ 40 | log.log(f"### load data", rank=rank) 41 | 42 | num_props = False 43 | 44 | if cfg.data.type == "rna": 45 | 46 | ignore_index = -1 47 | pad_index = 0 48 | 49 | train_data = cinit(RNAHandler, cfg.data.rna, df_path='data/rna_data.plk', sub_set='train', prob_training="prob" in cfg.model.model_type, 50 | device=rank, seed=cfg.data.seed, ignore_index=ignore_index) 51 | 52 | valid_data = cinit(RNAHandler, cfg.data.rna, df_path='data/rna_data.plk', sub_set='valid', prob_training=False, device=rank, 53 | seed=cfg.data.seed, ignore_index=ignore_index) 54 | 55 | test_data = cinit(RNAHandler, cfg.data.rna, df_path='data/rna_data.plk', sub_set='test', prob_training=False, device=rank, 56 | seed=cfg.data.seed, ignore_index=ignore_index) 57 | 58 | seq_vocab_size = train_data.seq_vocab_size 59 | trg_vocab_size = train_data.struct_vocab_size 60 | 61 | elif cfg.data.type == "ssd": 62 | 63 | ignore_index = -1 64 | pad_index = 0 65 | 66 | train_data = cinit(SSDHandler, cfg.data.ssd, sample_amount=cfg.data.ssd.sample_amount, device=rank, 67 | pre_src_vocab=None, pre_trg_vocab=None, token_dict=None) 68 | valid_data = cinit(SSDHandler, cfg.data.ssd, sample_amount=cfg.data.ssd.sample_amount // 10, device=rank, 69 | pre_src_vocab=train_data.pre_src_vocab, pre_trg_vocab=train_data.pre_trg_vocab, 70 | token_dict=train_data.token_dict) 71 | test_data = cinit(SSDHandler, cfg.data.ssd, sample_amount=cfg.data.ssd.sample_amount // 10, device=rank, 72 | pre_src_vocab=train_data.pre_src_vocab, pre_trg_vocab=train_data.pre_trg_vocab, 73 | token_dict=train_data.token_dict) 74 | 75 | seq_vocab_size = train_data.source_vocab_size 76 | trg_vocab_size = train_data.target_vocab_size 77 | 78 | 79 | elif cfg.data.type == "mol": 80 | 81 | train_data = cinit(MolHandler, cfg.data.mol, split="train", device=rank) 82 | valid_data = cinit(MolHandler, cfg.data.mol, split="valid", device=rank) 83 | test_data = cinit(MolHandler, cfg.data.mol, split="test", device=rank) 84 | 85 | ignore_index = -1 86 | pad_index = train_data.ignore_index 87 | 88 | if isinstance(cfg.data.mol.props, List): 89 | num_props = len(cfg.data.mol.props) 90 | 91 | if "decoder" in cfg.model.model_type: 92 | seq_vocab_size = train_data.target_vocab_size 93 | else: 94 | seq_vocab_size = 1 95 | trg_vocab_size = train_data.target_vocab_size 96 | log(f"trg_vocab_size: {trg_vocab_size}") 97 | 98 | else: 99 | raise UserWarning(f"data type unknown: {cfg.data.type}") 100 | 101 | log.log(f"### load iterator", rank=rank) 102 | valid_iter = MyIterator(data_handler=valid_data, batch_size=cfg.data.batch_size, repeat=False, shuffle=False, 103 | batching=True, pre_sort_samples=False, 104 | device=rank, 
seed=cfg.data.seed + rank, ignore_index=ignore_index, pad_index=pad_index) 105 | 106 | test_iter = MyIterator(data_handler=test_data, batch_size=cfg.data.batch_size, repeat=False, shuffle=False, 107 | batching=False, pre_sort_samples=False, 108 | device=rank, seed=cfg.data.seed + rank, ignore_index=ignore_index, pad_index=pad_index) 109 | 110 | log.log("valid_set_size", valid_iter.set_size, rank=rank) 111 | log.log("test_set_size", test_iter.set_size, rank=rank) 112 | 113 | log("src_vocab_len", seq_vocab_size) 114 | log("tgt_vocab_len", trg_vocab_size) 115 | 116 | ############################################################ 117 | ####### BUILD MODEL ######## 118 | ############################################################ 119 | model = cinit(ProbTransformer, cfg.model, seq_vocab_size=seq_vocab_size, trg_vocab_size=trg_vocab_size, 120 | props=num_props) 121 | model.load_state_dict(checkpoint['state_dict'], strict=False) 122 | 123 | log.log("model_parameters", count_parameters(model.parameters()), rank=rank) 124 | 125 | model = model.to(rank) 126 | 127 | ############################################################ 128 | ####### START EVALUATION ######## 129 | ############################################################ 130 | eval_summary = SummaryDict() 131 | log.start_timer(f"eval", rank=rank) 132 | log("## Start valid evaluation") 133 | score_dict_valid = run_evaluation(cfg, valid_iter, model) 134 | for name, score in score_dict_valid.items(): 135 | log(f"{name}_valid", score, rank=rank) 136 | eval_summary[f"{name}_valid"] = score 137 | log.timer(f"eval", rank=rank) 138 | 139 | ########################################################### 140 | ###### TEST MODEL ######## 141 | ########################################################### 142 | log.start_timer(f"test") 143 | score_dict_test = run_evaluation(cfg, test_iter, model, threshold=score_dict_valid['threshold']) 144 | for name, score in score_dict_test.items(): 145 | log(f"{name}_test", score) 146 | eval_summary[f"{name}_test"] = score 147 | log.timer(f"test") 148 | log.save_to_json(rank=rank) 149 | eval_summary.save(cfg.expt.experiment_dir / "evaluation.npy") 150 | 151 | 152 | if __name__ == "__main__": 153 | import argparse 154 | 155 | parser = argparse.ArgumentParser(description='Evaluate the model given a checkpoint file.') 156 | parser.add_argument('-c', '--checkpoint', type=str, help='a checkpoint file') 157 | 158 | args = parser.parse_args() 159 | 160 | checkpoint = torch.load(args.checkpoint, 161 | map_location=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')) 162 | 163 | eval_transformer(checkpoint=checkpoint) 164 | -------------------------------------------------------------------------------- /prob_transformer/evaluation/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/evaluation/metrics/__init__.py -------------------------------------------------------------------------------- /prob_transformer/evaluation/metrics/fpscores.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/evaluation/metrics/fpscores.pkl.gz -------------------------------------------------------------------------------- /prob_transformer/evaluation/metrics/sascore.py: 
-------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included 10 | # particularly slightly different formula for marocyclic penalty 11 | # and taking into account also molecule symmetry (fingerprint density) 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | 19 | from rdkit import Chem 20 | from rdkit.Chem import rdMolDescriptors 21 | import pickle 22 | 23 | import math 24 | from collections import defaultdict 25 | 26 | import os.path as op 27 | 28 | _fscores = None 29 | 30 | 31 | def readFragmentScores(name='fpscores'): 32 | import gzip 33 | global _fscores 34 | # generate the full path filename: 35 | if name == "fpscores": 36 | name = op.join(op.dirname(__file__), name) 37 | data = pickle.load(gzip.open('%s.pkl.gz' % name)) 38 | outDict = {} 39 | for i in data: 40 | for j in range(1, len(i)): 41 | outDict[i[j]] = float(i[0]) 42 | _fscores = outDict 43 | 44 | 45 | def numBridgeheadsAndSpiro(mol, ri=None): 46 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 47 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 48 | return nBridgehead, nSpiro 49 | 50 | 51 | def calculateScore(m): 52 | if _fscores is None: 53 | readFragmentScores() 54 | 55 | # fragment score 56 | fp = rdMolDescriptors.GetMorganFingerprint(m, 57 | 2) # <- 2 is the *radius* of the circular fingerprint 58 | fps = fp.GetNonzeroElements() 59 | score1 = 0. 60 | nf = 0 61 | for bitId, v in fps.items(): 62 | nf += v 63 | sfp = bitId 64 | score1 += _fscores.get(sfp, -4) * v 65 | score1 /= nf 66 | 67 | # features score 68 | nAtoms = m.GetNumAtoms() 69 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 70 | ri = m.GetRingInfo() 71 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 72 | nMacrocycles = 0 73 | for x in ri.AtomRings(): 74 | if len(x) > 8: 75 | nMacrocycles += 1 76 | 77 | sizePenalty = nAtoms**1.005 - nAtoms 78 | stereoPenalty = math.log10(nChiralCenters + 1) 79 | spiroPenalty = math.log10(nSpiro + 1) 80 | bridgePenalty = math.log10(nBridgeheads + 1) 81 | macrocyclePenalty = 0. 82 | # --------------------------------------- 83 | # This differs from the paper, which defines: 84 | # macrocyclePenalty = math.log10(nMacrocycles+1) 85 | # This form generates better results when 2 or more macrocycles are present 86 | if nMacrocycles > 0: 87 | macrocyclePenalty = math.log10(2) 88 | 89 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 90 | 91 | # correction for the fingerprint density 92 | # not in the original publication, added in version 1.1 93 | # to make highly symmetrical molecules easier to synthetise 94 | score3 = 0. 95 | if nAtoms > len(fps): 96 | score3 = math.log(float(nAtoms) / len(fps)) * .5 97 | 98 | sascore = score1 + score2 + score3 99 | 100 | # need to transform "raw" value into scale between 1 and 10 101 | min = -4.0 102 | max = 2.5 103 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 
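# The line above is a linear rescaling of the raw fragment-based score
# (empirically spanning roughly min = -4.0 to max = 2.5) onto the published
# 1..10 scale, where 1 means easy and 10 means hard to synthesize:
# a raw value of -4.0 maps to 11 - (1 / 6.5) * 9 ~= 9.6, while a raw value
# of 2.5 maps to 11 - (7.5 / 6.5) * 9 ~= 0.6, which the smoothing and
# clipping below push back into [1, 10].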
104 | # smooth the 10-end 105 | if sascore > 8.: 106 | sascore = 8. + math.log(sascore + 1. - 9.) 107 | if sascore > 10.: 108 | sascore = 10.0 109 | elif sascore < 1.: 110 | sascore = 1.0 111 | 112 | return sascore 113 | 114 | 115 | def processMols(mols): 116 | print('smiles\tName\tsa_score') 117 | for i, m in enumerate(mols): 118 | if m is None: 119 | continue 120 | 121 | s = calculateScore(m) 122 | 123 | smiles = Chem.MolToSmiles(m) 124 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 125 | 126 | 127 | if __name__ == '__main__': 128 | import sys 129 | import time 130 | 131 | t1 = time.time() 132 | readFragmentScores("fpscores") 133 | t2 = time.time() 134 | 135 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 136 | t3 = time.time() 137 | processMols(suppl) 138 | t4 = time.time() 139 | 140 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 141 | file=sys.stderr) 142 | # 143 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 144 | # All rights reserved. 145 | # 146 | # Redistribution and use in source and binary forms, with or without 147 | # modification, are permitted provided that the following conditions are 148 | # met: 149 | # 150 | # * Redistributions of source code must retain the above copyright 151 | # notice, this list of conditions and the following disclaimer. 152 | # * Redistributions in binary form must reproduce the above 153 | # copyright notice, this list of conditions and the following 154 | # disclaimer in the documentation and/or other materials provided 155 | # with the distribution. 156 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 157 | # nor the names of its contributors may be used to endorse or promote 158 | # products derived from this software without specific prior written permission. 159 | # 160 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 161 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 162 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 163 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 164 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 165 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 166 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 167 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 168 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 169 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 170 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
171 | # -------------------------------------------------------------------------------- /prob_transformer/evaluation/metrics/toy_task_survey.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import distance 4 | 5 | def eval_toy_sample(n_pred_dist, true_dist, length): 6 | true_dist = true_dist.to(n_pred_dist[0].device) 7 | true_binary = true_dist > 0 8 | n_pred_binary = torch.zeros_like(n_pred_dist[0]).to(n_pred_dist[0].device) 9 | 10 | seq_solved_list, correct_symbols_list = [], [] 11 | 12 | for pred_dist in n_pred_dist: 13 | 14 | eos_idx = length 15 | 16 | sample_pred_dist = pred_dist[:eos_idx, :] 17 | sample_pred_max_idx = torch.max(sample_pred_dist, dim=1, keepdim=True)[0] 18 | sample_pred_binary = sample_pred_dist.ge(sample_pred_max_idx) 19 | 20 | n_pred_binary[:eos_idx, :] += sample_pred_binary[:eos_idx, :] 21 | 22 | sample_true_binary = true_binary[:eos_idx, :] 23 | 24 | if eos_idx > sample_true_binary.shape[0]: 25 | correction_true = torch.zeros_like(sample_pred_binary).to(n_pred_dist[0].device) 26 | correction_true[:eos_idx, :] = sample_true_binary 27 | sample_true_binary = correction_true 28 | elif eos_idx < sample_true_binary.shape[0]: 29 | correction_true = torch.zeros_like(sample_true_binary).to(n_pred_dist[0].device) 30 | correction_true[:eos_idx, :] = sample_pred_binary 31 | sample_pred_binary = correction_true 32 | 33 | correct_symbols = sample_pred_binary.logical_and(sample_true_binary) 34 | correct_symbols = torch.sum(correct_symbols, dim=1, dtype=torch.float) 35 | mean_correct_symbols = torch.mean(correct_symbols) 36 | seq_solved = int(mean_correct_symbols.ge(1).cpu().detach().numpy()) 37 | 38 | correct_symbols = mean_correct_symbols.cpu().detach().numpy() 39 | 40 | seq_solved_list.append(seq_solved) 41 | correct_symbols_list.append(correct_symbols) 42 | 43 | avg_pred_binary = n_pred_binary / n_pred_dist.shape[0] 44 | 45 | samplewise_kl = torch.sum(true_dist * torch.log(true_dist / (avg_pred_binary + 1e-32) + 1e-32), dim=1, keepdim=True) 46 | samplewise_kl = torch.mean(samplewise_kl).cpu().detach().numpy() 47 | 48 | mean_div = torch.mean(0.5 * torch.sum(torch.abs(true_dist - avg_pred_binary), dim=1, keepdim=True)).cpu().detach().numpy() 49 | 50 | levenshtein = [] 51 | for t, p in zip(true_dist, avg_pred_binary): 52 | levenshtein.append(distance.levenshtein(torch.where(t>0)[0].cpu().detach().tolist(), torch.where(p>0)[0].cpu().detach().tolist())) 53 | 54 | options_per_symbol = torch.sum(true_binary, dim=1, keepdim=True) 55 | choices_per_symbol = torch.sum(avg_pred_binary.ge(0.01), dim=1, keepdim=True, dtype=torch.float) 56 | diversity = (choices_per_symbol * options_per_symbol) / ( 57 | options_per_symbol * options_per_symbol + 1e-9) # additional multiplication to zero out blanks 58 | diversity = torch.mean(diversity).cpu().detach().numpy() 59 | 60 | return {"seq_solved": np.mean(seq_solved_list), "correct_symbols": np.mean(correct_symbols_list), 61 | "samplewise_kl": samplewise_kl, "diversity": diversity, "total_variation":mean_div, 62 | "levenshtein":np.mean(levenshtein) } 63 | 64 | 65 | def eval_toy_task(batch_pred_dist, batch_src, trg_length, data_set): 66 | metrics_list = [] 67 | batch_true_dist = data_set.get_batch_dist(batch_src) 68 | 69 | for n_pred_dist, true_dist, length in zip(batch_pred_dist, batch_true_dist, trg_length): 70 | metrics_list.append(eval_toy_sample(n_pred_dist, true_dist, length)) 71 | return metrics_list 72 | 
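# Minimal self-contained sanity check for eval_toy_sample (an illustrative
# sketch, not part of the original module; it relies only on the torch,
# numpy, and distance imports at the top of this file): three identical
# one-hot prediction samples that agree with the target distribution should
# give seq_solved == 1 and near-zero samplewise_kl / total_variation.
if __name__ == "__main__":
    torch.manual_seed(0)
    length, vocab = 5, 4
    # one-hot target distribution over a toy vocabulary
    true_dist = torch.zeros(length, vocab)
    true_dist[torch.arange(length), torch.randint(vocab, (length,))] = 1.0
    # three "sampled" predictions, all equal to the target: [n_samples, len, vocab]
    n_pred_dist = true_dist.unsqueeze(0).repeat(3, 1, 1)
    print(eval_toy_sample(n_pred_dist, true_dist, length))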
-------------------------------------------------------------------------------- /prob_transformer/evaluation/statistics_center.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | from itertools import zip_longest, product 4 | 5 | 6 | def symmetric_matrix(mat): 7 | return np.maximum(mat, mat.T) 8 | 9 | 10 | def sigmoid(x): 11 | return 1 / (1 + np.exp(-x)) 12 | 13 | 14 | class StatisticsCenter(): 15 | def __init__(self, 16 | pred_list, 17 | step_size=0.01, 18 | symmetric_matrices=True, 19 | full_eval=False, 20 | triangle_loss=False, 21 | ): 22 | self.symmetric_matrices = symmetric_matrices 23 | self.full_eval = full_eval 24 | self.triangle_loss = triangle_loss 25 | 26 | self.types2evaluate = ['wc', 'wobble', 'nc', 'canonical', 'all_pairs'] 27 | self.wc_pairs = ['GC', 'CG', 'AU', 'UA'] 28 | self.wobble_pairs = ['GU', 'UG'] 29 | 30 | self.pred_list = pred_list 31 | 32 | self.metrics = defaultdict(list) 33 | self.predictions = defaultdict(list) 34 | 35 | self.thresholds2evaluate = np.arange(0.0, 1.0, step_size) 36 | self.evaluated_thresholds = [] 37 | 38 | def eval_prediction(self, true_mat, pred_mat, mask_dict, threshold): 39 | metrics = {} 40 | 41 | if self.triangle_loss: 42 | idx1, idx2 = np.triu_indices(pred_mat.shape[1], 1) 43 | 44 | pred_pair_sec1 = pred_mat[idx1, idx2] 45 | pred_pair_sec2 = pred_mat.transpose(1, 0)[idx1, idx2] 46 | pred_pair_sec = (pred_pair_sec1 + pred_pair_sec2) / 2 47 | pred_pair_sec = sigmoid(pred_pair_sec) 48 | pred_pair_sec = pred_pair_sec > threshold 49 | 50 | pred = np.zeros(pred_mat.shape) 51 | pred[idx1, idx2] = pred_pair_sec 52 | if self.symmetric_matrices: 53 | pred[idx2, idx1] = pred_pair_sec 54 | else: 55 | 56 | pred_mat = sigmoid(pred_mat) 57 | pred = pred_mat > threshold 58 | if self.symmetric_matrices: 59 | pred = symmetric_matrix(pred) 60 | 61 | pred = pred.flatten().astype(int) 62 | true = true_mat.flatten() 63 | 64 | for key, mask_mat in mask_dict.items(): 65 | 66 | if key != 'all': 67 | mask = mask_mat.flatten() 68 | del_pos = np.where(mask == 0)[0] 69 | 70 | pred_del = np.delete(pred, del_pos) 71 | true_del = np.delete(true, del_pos) 72 | 73 | k_metrics = self.eval_array(pred_del, true_del) 74 | else: 75 | k_metrics = self.eval_array(pred, true) 76 | k_metrics = {f"{key}_{k}": v for k, v in k_metrics.items()} 77 | metrics.update(k_metrics) 78 | return metrics 79 | 80 | def eval_array(self, pred, true): 81 | 82 | solved = np.all(np.equal(true, pred)).astype(int) 83 | if solved == 1: 84 | f1_score = 1 85 | non_correct = 0 86 | precision = 1 87 | recall = 1 88 | specificity = 1 89 | else: 90 | tp = np.logical_and(pred, true).sum() 91 | non_correct = (tp == 0).astype(int) 92 | tn = np.logical_and(np.logical_not(pred), np.logical_not(true)).sum() 93 | fp = pred.sum() - tp 94 | fn = true.sum() - tp 95 | 96 | recall = tp / (tp + fn + 1e-8) 97 | precision = tp / (tp + fp + 1e-8) 98 | specificity = tn / (tn + fp + 1e-8) 99 | f1_score = 2 * tp / (2 * tp + fp + fn) 100 | 101 | metrics = {'f1_score': f1_score, 'solved': solved} 102 | if self.full_eval: 103 | metrics['non_correct'] = non_correct 104 | metrics['precision'] = precision 105 | metrics['recall'] = recall 106 | metrics['specificity'] = specificity 107 | 108 | return metrics 109 | 110 | def get_pair_type_masks(self, sequence): 111 | all_mask = np.ones((len(sequence), len(sequence))) 112 | wc = np.zeros((len(sequence), len(sequence))) 113 | wobble = np.zeros((len(sequence), len(sequence))) 114 | 115 | a 
= [i for i, sym in enumerate(sequence) if sym.upper() == 'A'] 116 | c = [i for i, sym in enumerate(sequence) if sym.upper() == 'C'] 117 | g = [i for i, sym in enumerate(sequence) if sym.upper() == 'G'] 118 | u = [i for i, sym in enumerate(sequence) if sym.upper() == 'U'] 119 | 120 | for wc1, wc2, wob in zip_longest(product(g, c), product(a, u), product(g, u), fillvalue=None): 121 | if wc1: 122 | wc[wc1[0], wc1[1]] = 1 123 | wc[wc1[1], wc1[0]] = 1 124 | if wc2: 125 | wc[wc2[0], wc2[1]] = 1 126 | wc[wc2[1], wc2[0]] = 1 127 | if wob: 128 | wobble[wob[0], wob[1]] = 1 129 | wobble[wob[1], wob[0]] = 1 130 | 131 | canonical = wc + wobble 132 | nc = all_mask - canonical 133 | 134 | return {'all': all_mask, 'wc': wc, 'canonical': canonical, 'wobble': wobble, 'nc': nc} 135 | 136 | def eval_pred(self, pred_sample, threshold): 137 | 138 | pred_mat = pred_sample['pred'] 139 | true_mat = pred_sample['true'].astype(int) 140 | 141 | mask_dict = self.get_pair_type_masks(pred_sample['sequence']) 142 | metrics = self.eval_prediction(true_mat, pred_mat, mask_dict, threshold) 143 | 144 | return metrics 145 | 146 | def eval_threshold(self, threshold): 147 | 148 | assert 0.0 <= threshold <= 1.0 149 | 150 | metrics_list = list(map(lambda x: self.eval_pred(x, threshold), self.pred_list)) 151 | metrics = {k: np.mean([dic[k] for dic in metrics_list]) for k in metrics_list[0]} 152 | 153 | return metrics 154 | 155 | def find_best_threshold(self): 156 | 157 | best_threshold = 0 158 | best_all_f1_score = 0 159 | best_metric = {} 160 | 161 | for threshold in self.thresholds2evaluate: 162 | metrics = self.eval_threshold(threshold) 163 | 164 | if metrics['all_f1_score'] > best_all_f1_score: 165 | best_all_f1_score = metrics['all_f1_score'] 166 | best_metric = metrics 167 | best_threshold = threshold 168 | 169 | return best_metric, best_threshold 170 | -------------------------------------------------------------------------------- /prob_transformer/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/model/__init__.py -------------------------------------------------------------------------------- /prob_transformer/model/probtransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.cuda.amp as amp 4 | 5 | from prob_transformer.module.probformer_stack import ProbFormerStack 6 | from prob_transformer.module.embedding import PosEmbedding 7 | 8 | 9 | class ProbTransformer(nn.Module): 10 | 11 | def __init__(self, model_type, seq_vocab_size, trg_vocab_size, model_dim, max_len, n_layers, 12 | num_head, ff_factor, z_factor, dropout, prob_layer, props=False, zero_init=True): 13 | super().__init__() 14 | 15 | self.n_layers = n_layers 16 | self.max_len = max_len 17 | self.zero_init = zero_init 18 | 19 | self.props = props 20 | 21 | self.model_type = model_type 22 | assert self.model_type in ['encoder', 'prob_encoder', 23 | 'encoder_decoder', 'encoder_prob_decoder', 24 | 'decoder', 'prob_decoder'] 25 | 26 | self.probabilistic = "prob" in model_type 27 | self.encoder = "encoder" in model_type 28 | self.decoder = "decoder" in model_type 29 | 30 | if self.probabilistic: 31 | if isinstance(prob_layer, str): 32 | if prob_layer == 'all': 33 | self.prob_layer = 
list(range(n_layers)) 34 | elif prob_layer == 'middle': 35 | self.prob_layer = list(range(n_layers))[1:-1] 36 | elif prob_layer == 'first': 37 | self.prob_layer = [range(n_layers)[0]] 38 | elif prob_layer == 'last': 39 | self.prob_layer = [range(n_layers)[-1]] 40 | else: 41 | self.prob_layer = prob_layer 42 | else: 43 | self.prob_layer = [] 44 | 45 | if 'encoder' in model_type: 46 | self.encoder = ProbFormerStack(n_layers=n_layers, model_dim=model_dim, num_head=num_head, 47 | ff_factor=ff_factor, z_factor=z_factor, dropout=dropout, zero_init=zero_init, 48 | cross_attention=False, posterior=False, 49 | prob_layer=self.prob_layer if 'prob_encoder' in model_type else []) 50 | 51 | if 'prob_encoder' == model_type: 52 | self.post_encoder = ProbFormerStack(n_layers=n_layers, model_dim=model_dim, num_head=num_head, 53 | ff_factor=ff_factor, z_factor=z_factor, dropout=dropout, 54 | zero_init=zero_init, cross_attention=False, 55 | posterior=True, prob_layer=self.prob_layer) 56 | 57 | if 'encoder_decoder' == model_type or 'encoder_prob_decoder' == model_type: 58 | self.decoder = ProbFormerStack(n_layers=n_layers, model_dim=model_dim, num_head=num_head, 59 | ff_factor=ff_factor, z_factor=z_factor, dropout=dropout, 60 | zero_init=zero_init, cross_attention=True, 61 | posterior=False, 62 | prob_layer=self.prob_layer if 'prob_decoder' in model_type else []) 63 | 64 | if 'encoder_prob_decoder' == model_type: 65 | self.post_decoder = ProbFormerStack(n_layers=n_layers, model_dim=model_dim, num_head=num_head, 66 | ff_factor=ff_factor, z_factor=z_factor, dropout=dropout, 67 | zero_init=zero_init, cross_attention=False, 68 | posterior=True, prob_layer=self.prob_layer) 69 | 70 | if 'decoder' == model_type or 'prob_decoder' == model_type: 71 | self.decoder = ProbFormerStack(n_layers=n_layers, model_dim=model_dim, num_head=num_head, 72 | ff_factor=ff_factor, 73 | z_factor=z_factor, dropout=dropout, zero_init=zero_init, 74 | cross_attention=False, posterior=False, 75 | prob_layer=self.prob_layer if 'prob_decoder' in model_type else []) 76 | 77 | if 'prob_decoder' == model_type: 78 | self.post_decoder = ProbFormerStack(n_layers=n_layers, model_dim=model_dim, num_head=num_head, 79 | ff_factor=ff_factor, z_factor=z_factor, dropout=dropout, 80 | zero_init=zero_init, cross_attention=False, 81 | posterior=True, prob_layer=self.prob_layer) 82 | 83 | if 'encoder' in model_type: 84 | self.src_embed = PosEmbedding(seq_vocab_size, model_dim, max_len) 85 | 86 | if self.decoder: 87 | self.trg_embed = PosEmbedding(trg_vocab_size, model_dim, max_len) 88 | 89 | if self.props: 90 | self.type_embed = nn.Embedding(2, model_dim) 91 | 92 | if self.props: 93 | self.prop_embed = nn.Linear(self.props, model_dim) 94 | 95 | if self.probabilistic: 96 | if 'prob_encoder' in model_type: 97 | self.post_encoder_embed_seq = nn.Embedding(seq_vocab_size, model_dim) 98 | self.post_encoder_embed_trg = nn.Embedding(trg_vocab_size, model_dim) 99 | 100 | if 'prob_decoder' in model_type: 101 | self.post_decoder_embed_post = nn.Embedding(trg_vocab_size, model_dim) 102 | self.post_decoder_embed_trg = nn.Embedding(trg_vocab_size, model_dim) 103 | 104 | if self.props: 105 | self.post_prop_embed = nn.Linear(self.props, model_dim) 106 | 107 | self.output_ln = nn.LayerNorm(model_dim) 108 | self.output = nn.Linear(model_dim, trg_vocab_size) 109 | 110 | self.initialize() 111 | 112 | def initialize(self): 113 | 114 | # embedding initialization based on https://arxiv.org/abs/1711.09160 115 | if self.encoder: 116 | nn.init.normal_(self.src_embed.embed_seq.weight, 
mean=0.0, std=0.0001) 117 | 118 | if self.decoder: 119 | nn.init.normal_(self.trg_embed.embed_seq.weight, mean=0.0, std=0.0001) 120 | 121 | if self.props: 122 | nn.init.normal_(self.prop_embed.weight, mean=0.0, std=0.001) 123 | 124 | if self.probabilistic: 125 | if 'prob_encoder' in self.model_type: 126 | nn.init.normal_(self.post_encoder_embed_seq.weight, mean=0.0, std=0.0001) 127 | nn.init.normal_(self.post_encoder_embed_trg.weight, mean=0.0, std=0.0001) 128 | 129 | if 'prob_decoder' in self.model_type: 130 | nn.init.normal_(self.post_decoder_embed_post.weight, mean=0.0, std=0.0001) 131 | nn.init.normal_(self.post_decoder_embed_trg.weight, mean=0.0, std=0.0001) 132 | 133 | if self.props: 134 | nn.init.normal_(self.prop_embed.weight, mean=0.0, std=0.001) 135 | 136 | nn.init.xavier_uniform_(self.output.weight) 137 | nn.init.constant_(self.output.bias, 0.0) 138 | 139 | def make_src_mask(self, src, src_len): 140 | src_mask = torch.arange(src.shape[1], device=src.device).expand(src.shape[:2]) < src_len.unsqueeze(1) 141 | src_mask = src_mask.type(src.type()) 142 | return src_mask 143 | 144 | def make_trg_mask(self, trg_embed, trg_len): 145 | mask = torch.arange(trg_embed.size()[1], device=trg_embed.device).expand( 146 | trg_embed.shape[:2]) < trg_len.unsqueeze(1) 147 | mask = mask.unsqueeze(-1) 148 | sub_mask = torch.triu( 149 | torch.ones((1, trg_embed.size()[1], trg_embed.size()[1]), dtype=torch.bool, device=trg_embed.device), 150 | diagonal=1) 151 | sub_mask = sub_mask == 0 152 | trg_mask = mask & sub_mask 153 | trg_mask = trg_mask.type(trg_embed.type()) 154 | return trg_mask 155 | 156 | def forward(self, src_seq=None, src_len=None, post_trg_seq=None, trg_shf_seq=None, trg_len=None, 157 | props=None, infer_mean=False, output_latent=False): 158 | 159 | if self.encoder: 160 | src_mask = self.make_src_mask(src_seq, src_len) 161 | seq_embed = self.src_embed(src_seq) # * src_mask[:, :, None] 162 | if torch.is_autocast_enabled(): 163 | src_mask = src_mask.half() 164 | seq_embed = seq_embed.half() 165 | 166 | if props is not None: 167 | prop_embed = self.prop_embed(props) 168 | seq_embed = seq_embed + prop_embed 169 | 170 | if self.decoder: 171 | trg_shift_embed = self.trg_embed(trg_shf_seq) 172 | 173 | if props is not None: 174 | type_embed = self.type_embed(torch.zeros((trg_shf_seq.shape[0], 1), dtype=torch.long, 175 | device=trg_shf_seq.device)) 176 | prop_embed = self.prop_embed(props) + type_embed 177 | trg_shift_embed = torch.cat([prop_embed, trg_shift_embed], 1) 178 | trg_len = trg_len + 1 179 | 180 | trg_shift_mask = self.make_trg_mask(trg_shift_embed, trg_len) 181 | 182 | if torch.is_autocast_enabled(): 183 | trg_shift_mask = trg_shift_mask.half() 184 | trg_shift_embed = trg_shift_embed.half() 185 | 186 | if self.probabilistic and post_trg_seq != None: 187 | if 'prob_encoder' in self.model_type: 188 | post_seq_encoder_embed = self.post_encoder_embed_seq(src_seq) 189 | post_trg_encoder_embed = self.post_encoder_embed_trg(post_trg_seq) 190 | post_encoder_embed = post_seq_encoder_embed + post_trg_encoder_embed 191 | if torch.is_autocast_enabled(): 192 | post_encoder_embed = post_encoder_embed.half() 193 | 194 | if 'prob_decoder' in self.model_type: 195 | post_trg_decoder_embed = self.post_decoder_embed_trg(trg_shf_seq) 196 | post_post_decoder_embed = self.post_decoder_embed_post(post_trg_seq) 197 | post_decoder_embed = post_trg_decoder_embed + post_post_decoder_embed 198 | 199 | if props is not None: 200 | type_embed = self.type_embed(torch.zeros((trg_shf_seq.shape[0], 1), dtype=torch.long, 
201 | device=trg_shf_seq.device)) 202 | prop_embed = self.prop_embed(props) + type_embed 203 | post_decoder_embed = torch.cat([prop_embed, post_decoder_embed], 1) 204 | 205 | if torch.is_autocast_enabled(): 206 | post_decoder_embed = post_decoder_embed.half() 207 | 208 | # use transformer stacks 209 | if 'prob_encoder' in self.model_type: 210 | if post_trg_seq is not None: # training 211 | _, p_z_list, _ = self.post_encoder(post_encoder_embed, src_mask[:, None, :]) 212 | encoder_act, z_list, mask_encoder = self.encoder(seq_embed, src_mask[:, None, :], p_z_list=p_z_list) 213 | else: 214 | encoder_act, z_list, mask_encoder = self.encoder(seq_embed, src_mask[:, None, :], infer_mean=infer_mean) 215 | 216 | elif 'encoder' in self.model_type: 217 | encoder_act = self.encoder(seq_embed, src_mask[:, None, :]) 218 | 219 | if 'encoder_prob_decoder' == self.model_type: 220 | if post_trg_seq is not None: # training 221 | _, p_z_list, _ = self.post_decoder(post_decoder_embed, trg_shift_mask) 222 | decoder_act, z_list, mask_decoder = self.decoder(trg_shift_embed, trg_shift_mask, 223 | enc_act=encoder_act, enc_mask=src_mask[:, None, :], 224 | p_z_list=p_z_list) 225 | else: 226 | decoder_act, z_list, mask_decoder = self.decoder(trg_shift_embed, trg_shift_mask, 227 | enc_act=encoder_act, enc_mask=src_mask[:, None, :], 228 | infer_mean=infer_mean) 229 | 230 | elif 'encoder_decoder' == self.model_type: 231 | decoder_act = self.decoder(trg_shift_embed, trg_shift_mask, enc_act=encoder_act, 232 | enc_mask=src_mask[:, None, :]) 233 | 234 | elif 'prob_decoder' == self.model_type: 235 | if post_trg_seq is not None: # training 236 | _, p_z_list, _ = self.post_decoder(post_decoder_embed, trg_shift_mask) 237 | decoder_act, z_list, mask_decoder = self.decoder(trg_shift_embed, trg_shift_mask, p_z_list=p_z_list) 238 | else: 239 | decoder_act, z_list, mask_decoder = self.decoder(trg_shift_embed, trg_shift_mask, infer_mean=infer_mean) 240 | elif 'decoder' == self.model_type: 241 | decoder_act = self.decoder(trg_shift_embed, trg_shift_mask) 242 | 243 | if torch.is_autocast_enabled(): 244 | if self.encoder: 245 | assert encoder_act.dtype == torch.float16 246 | if self.decoder: 247 | assert decoder_act.dtype == torch.float16 248 | 249 | if self.decoder: 250 | output_act = decoder_act 251 | else: 252 | output_act = encoder_act 253 | 254 | if torch.is_autocast_enabled(): 255 | output_act = output_act.float() 256 | 257 | with amp.autocast(enabled=False): 258 | output_pred = self.output(self.output_ln(output_act)) 259 | 260 | if self.decoder and props is not None: 261 | output_pred = output_pred[:, -trg_shf_seq.shape[1]:, :] 262 | 263 | if self.probabilistic and post_trg_seq is not None: 264 | return output_pred, (z_list, p_z_list) 265 | elif output_latent and post_trg_seq is None: 266 | return output_pred, output_act 267 | else: 268 | return output_pred 269 | -------------------------------------------------------------------------------- /prob_transformer/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/module/__init__.py -------------------------------------------------------------------------------- /prob_transformer/module/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | 8 | def __init__(self, 
q_data_dim, m_data_dim, output_dim, num_head, zero_init=True, output_linear=True): 9 | super().__init__() 10 | assert q_data_dim % num_head == 0 11 | assert m_data_dim % num_head == 0 12 | self.key_dim = q_data_dim // num_head 13 | self.value_dim = m_data_dim // num_head 14 | 15 | self.key_dim_scaler = nn.Parameter(torch.FloatTensor([self.key_dim ** (-0.5)]), requires_grad=False) 16 | 17 | self.scale = nn.Parameter(torch.sqrt(torch.FloatTensor([self.key_dim])), requires_grad=False) 18 | 19 | self.zero_init = zero_init 20 | self.output_linear = output_linear 21 | 22 | self.num_head = num_head 23 | self.linear_q = nn.Linear(q_data_dim, q_data_dim, bias=False) 24 | self.linear_k = nn.Linear(m_data_dim, m_data_dim, bias=False) 25 | self.linear_v = nn.Linear(m_data_dim, m_data_dim, bias=False) 26 | 27 | if self.output_linear: 28 | self.linear_o = nn.Linear(num_head * self.value_dim, output_dim, bias=True) 29 | 30 | self.initialize() 31 | 32 | def initialize(self): 33 | 34 | nn.init.xavier_uniform_(self.linear_q.weight) 35 | nn.init.xavier_uniform_(self.linear_k.weight) 36 | nn.init.xavier_uniform_(self.linear_v.weight) 37 | 38 | if self.output_linear and self.zero_init: # guards fix a latent AttributeError: linear_o only exists when output_linear=True 39 | nn.init.constant_(self.linear_o.weight, 0.0) 40 | elif self.output_linear: 41 | nn.init.xavier_uniform_(self.linear_o.weight) 42 | if self.output_linear: nn.init.constant_(self.linear_o.bias, 0.0) 43 | 44 | def forward(self, q_data, m_data, mask): 45 | 46 | batch_size = q_data.size(0) 47 | N_q_seq = q_data.size(1) 48 | N_m_seq = m_data.size(1) 49 | q = self.linear_q(q_data).view(batch_size, N_q_seq, self.num_head, self.key_dim).permute(0, 2, 1, 50 | 3) * self.key_dim_scaler 51 | k = self.linear_k(m_data).view(batch_size, N_m_seq, self.num_head, self.value_dim).permute(0, 2, 3, 1) 52 | v = self.linear_v(m_data).view(batch_size, N_m_seq, self.num_head, self.value_dim).permute(0, 2, 1, 3) 53 | 54 | logits = torch.matmul(q, k) / self.scale # note: q is already scaled by key_dim ** -0.5 above, so logits carry a 1/key_dim factor overall 55 | 56 | if torch.is_autocast_enabled(): 57 | bias = (1e4 * (mask - 1.))[:, None, :, :] 58 | else: 59 | bias = (1e9 * (mask - 1.))[:, None, :, :] 60 | logits = logits + bias 61 | 62 | weights = F.softmax(logits, dim=-1) 63 | 64 | weighted_avg = torch.matmul(weights, v).permute(0, 2, 1, 3) 65 | 66 | if self.output_linear: 67 | output = self.linear_o(weighted_avg.reshape(batch_size, N_q_seq, self.num_head * self.value_dim)) 68 | else: 69 | output = weighted_avg.reshape(batch_size, N_q_seq, self.num_head * self.value_dim) 70 | 71 | return output 72 | 73 | 74 | class PreNormAttention(nn.Module): 75 | 76 | def __init__(self, model_dim, num_head, encoder=False, zero_init=True): 77 | super().__init__() 78 | 79 | self.model_dim = model_dim 80 | self.num_head = num_head 81 | self.encoder = encoder 82 | 83 | self.src_ln = nn.LayerNorm(model_dim) 84 | 85 | if encoder: 86 | self.enc_ln = nn.LayerNorm(model_dim) 87 | 88 | self.attn = Attention(q_data_dim=model_dim, m_data_dim=model_dim, output_dim=model_dim, num_head=num_head, 89 | zero_init=zero_init) 90 | self.initialize() 91 | 92 | def initialize(self): 93 | pass 94 | 95 | def forward(self, src_act, enc_act=None, mask=None): 96 | 97 | src_act = self.src_ln(src_act) 98 | if self.encoder: 99 | enc_act = self.enc_ln(enc_act) 100 | else: 101 | enc_act = src_act 102 | src_act = self.attn(src_act, enc_act, mask) 103 | return src_act 104 | -------------------------------------------------------------------------------- /prob_transformer/module/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 |
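# Editor's note: PosEmbedding below replaces the fixed sinusoidal encoding of the
# original Transformer with a learned token embedding (scaled by sqrt(model_dim // 2))
# plus a learned linear projection of a one-hot relative position encoding.
# A minimal usage sketch (shapes and values are illustrative only):
#
#     embed = PosEmbedding(vocab=6, model_dim=16, max_len=32)
#     src = torch.randint(0, 6, (2, 10))  # batch of 2 token sequences of length 10
#     out = embed(src)                    # -> (2, 10, 16): token + position term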
6 | class PosEmbedding(nn.Module): 7 | def __init__(self, vocab, model_dim, max_len): 8 | super().__init__() 9 | self.max_len = max_len 10 | self.embed_seq = nn.Embedding(vocab, model_dim) 11 | self.scale = nn.Parameter(torch.sqrt(torch.FloatTensor([model_dim // 2])), requires_grad=False) 12 | self.embed_pair_pos = nn.Linear(max_len + 1, model_dim) 13 | 14 | def relative_position_encoding(self, src_seq): 15 | residue_index = torch.arange(src_seq.size()[1], device=src_seq.device).expand(src_seq.size()) 16 | rel_pos = F.one_hot(torch.clip(residue_index, min=0, max=self.max_len), self.max_len + 1).type( 17 | torch.float32).to(src_seq.device) 18 | pos_encoding = self.embed_pair_pos(rel_pos) 19 | return pos_encoding 20 | 21 | def forward(self, src_seq): 22 | seq_embed = self.embed_seq(src_seq) * self.scale 23 | seq_embed = seq_embed + self.relative_position_encoding(src_seq) 24 | return seq_embed 25 | -------------------------------------------------------------------------------- /prob_transformer/module/feed_forward.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FeedForward(nn.Module): 5 | 6 | def __init__(self, model_dim, ff_dim, zero_init=True): 7 | super(FeedForward, self).__init__() 8 | 9 | self.zero_init = zero_init 10 | 11 | self.input_norm = nn.LayerNorm(model_dim) 12 | self.linear_1 = nn.Linear(model_dim, ff_dim) 13 | self.linear_2 = nn.Linear(ff_dim, model_dim) 14 | self.act = nn.SiLU() 15 | 16 | self.initialize() 17 | 18 | def initialize(self): 19 | 20 | nn.init.kaiming_normal_(self.linear_1.weight) 21 | nn.init.constant_(self.linear_1.bias, 0.0) 22 | nn.init.constant_(self.linear_2.bias, 0.0) 23 | 24 | if self.zero_init: 25 | nn.init.constant_(self.linear_2.weight, 0.0) 26 | else: 27 | nn.init.xavier_normal_(self.linear_2.weight) 28 | 29 | def forward(self, x): 30 | 31 | x = self.input_norm(x) 32 | 33 | x = self.act(self.linear_1(x)) 34 | 35 | return self.linear_2(x) 36 | -------------------------------------------------------------------------------- /prob_transformer/module/geco_criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GECOLoss(nn.Module): 7 | 8 | def __init__(self, model, kappa, lagmul_rate, ma_decay): 9 | super(GECOLoss, self).__init__() 10 | 11 | self.register_buffer('kappa', torch.nn.Parameter(torch.FloatTensor([kappa]), requires_grad=False), persistent=True) 12 | self.decay = ma_decay 13 | 14 | self.lagmul_rate = lagmul_rate 15 | lagmul_init = torch.FloatTensor([1.0]) 16 | lagmul_init = torch.log(torch.exp(torch.sqrt(lagmul_init)) - 1) # inverse_softplus( sqrt(x) ) 17 | lagmul = nn.Parameter(lagmul_init, requires_grad=True) 18 | self.lagmul = self.scale_gradients(lagmul, -lagmul_rate) 19 | model.register_parameter("lagmul", self.lagmul) 20 | 21 | self.t = 0 22 | self.ce_ma = 0 23 | 24 | @staticmethod 25 | def scale_gradients(v, weights): 26 | def hook(g): 27 | return g * weights 28 | v.register_hook(hook) 29 | return v 30 | 31 | def _moving_average(self, ce_loss): 32 | if self.t == 0: 33 | self.ce_ma = ce_loss.detach() 34 | self.t += 1 35 | return ce_loss 36 | else: 37 | self.ce_ma = self.decay * self.ce_ma + (1 - self.decay) * ce_loss.detach() 38 | self.t += 1 39 | return ce_loss + (self.ce_ma - ce_loss).detach() 40 | 41 | def forward(self, crit_set, trg_set): 42 | 43 | z_lists, ce_loss = crit_set 44 | trg_seq, trg_len = trg_set 45 | 46 | mask = 
torch.arange(trg_seq.size()[1], device=trg_seq.device).expand(trg_seq.size()) 47 | mask = mask < trg_len[:, None] 48 | mask = mask.type(trg_seq.type()) 49 | 50 | if z_lists[0][0].mean.shape[1] != trg_seq.size()[1]: # correct in case of props or scaffold 51 | large_mask = torch.zeros(z_lists[0][0].mean.shape[:2], device=trg_seq.device).type(trg_seq.type()) 52 | large_mask[:, -trg_seq.size()[1]:] = mask 53 | mask = large_mask 54 | 55 | z_list, p_z_list = z_lists 56 | 57 | kl_list = [] 58 | mean_list = [] 59 | mean_max_list = [] 60 | stddev_list = [] 61 | stddev_max_list = [] 62 | 63 | for idx, (z, p_z) in enumerate(zip(z_list, p_z_list)): 64 | mean_list.append(torch.mean(z.mean).detach()) 65 | mean_max_list.append(torch.max(torch.abs(z.mean)).detach()) 66 | stddev_list.append(torch.mean(z.stddev).detach()) 67 | stddev_max_list.append(torch.max(torch.abs(z.stddev)).detach()) 68 | 69 | kl_dist = torch.distributions.kl_divergence(p_z, z) 70 | 71 | kl_dist = kl_dist.sum(-1) 72 | kl_dist = kl_dist * mask 73 | kl_dist = torch.sum(kl_dist, dim=-1) / trg_len 74 | kl_dist = kl_dist.mean() 75 | 76 | kl_list.append(kl_dist) 77 | 78 | kl_loss = torch.stack(kl_list, dim=-1).sum() 79 | ma_ce_loss = self._moving_average(ce_loss) 80 | rec_constraint = ma_ce_loss - self.kappa 81 | 82 | lamb = F.softplus(self.lagmul) ** 2 83 | 84 | loss = lamb * rec_constraint + kl_loss 85 | 86 | summary = {"mean_list": [k.detach() for k in mean_list], 87 | "stddev_list": [k.detach() for k in stddev_list], 88 | "mean_max_list": [k.detach() for k in mean_max_list], 89 | "stddev_max_list": [k.detach() for k in stddev_max_list], 90 | "kl_loss": kl_loss.detach(), 91 | "ce_loss": ce_loss.detach(), 92 | "lagmul": self.lagmul.detach(), 93 | "ma_ce_loss": ma_ce_loss.detach(), 94 | "rec_constraint": rec_constraint.detach(), 95 | "lamb": lamb.detach(), "t": self.t, } 96 | 97 | return loss, summary 98 | -------------------------------------------------------------------------------- /prob_transformer/module/mat_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from prob_transformer.module.embedding import PosEmbedding 5 | 6 | 7 | class ResNetBlock(nn.Module): 8 | def __init__(self, in_channel_size, out_channel_size, kernel, residual): 9 | super(ResNetBlock, self).__init__() 10 | 11 | self.residual = residual 12 | self.norm1 = nn.InstanceNorm2d(in_channel_size) 13 | self.norm2 = nn.InstanceNorm2d(in_channel_size) 14 | 15 | self.acti = nn.SiLU() 16 | if kernel == 1: 17 | self.conv1 = nn.Conv2d(in_channel_size, in_channel_size, kernel_size=1) 18 | self.conv2 = nn.Conv2d(in_channel_size, out_channel_size, kernel_size=1) 19 | else: 20 | self.conv1 = nn.Conv2d(in_channel_size, in_channel_size, kernel_size=kernel, padding=(kernel - 1) // 2) 21 | if in_channel_size == out_channel_size: 22 | self.conv2 = nn.Conv2d(in_channel_size, out_channel_size, kernel_size=kernel, padding=(kernel - 1) // 2) 23 | else: 24 | self.conv2 = nn.Conv2d(in_channel_size, out_channel_size, kernel_size=1) 25 | 26 | def initialize(self): # note: not invoked in __init__, so PyTorch's default init applies unless called explicitly 27 | nn.init.kaiming_normal_(self.conv1.weight) 28 | nn.init.constant_(self.conv2.weight, 0.0) 29 | nn.init.constant_(self.conv1.bias, 0.0) 30 | nn.init.constant_(self.conv2.bias, 0.0) # fixed: conv1.bias was initialized twice while conv2.bias was never initialized 31 | 32 | def forward(self, x): 33 | 34 | x_hat = self.norm1(x) 35 | x_hat = self.acti(x_hat) 36 | x_hat = self.conv1(x_hat) 37 | x_hat = self.norm2(x_hat) 38 | x_hat = self.acti(x_hat) 39 | x_hat = self.conv2(x_hat) 40 | 41 | if self.residual: 42 | return x_hat + x
43 | else: 44 | return x_hat 45 | 46 | 47 | class SimpleMatrixHead(nn.Module): 48 | 49 | def __init__(self, src_vocab_size, latent_dim, dropout, model_dim, out_channels, 50 | res_layer, kernel, max_len): 51 | 52 | super(SimpleMatrixHead, self).__init__() 53 | 54 | self.row_latent_linear = nn.Linear(latent_dim, model_dim) 55 | self.col_latent_linear = nn.Linear(latent_dim, model_dim) 56 | 57 | self.latent_normal = nn.LayerNorm(latent_dim) 58 | 59 | self.row_src_embed = PosEmbedding(src_vocab_size, model_dim, max_len) 60 | self.col_src_embed = PosEmbedding(src_vocab_size, model_dim, max_len) 61 | self.row_pred_embed = PosEmbedding(13, model_dim, max_len) 62 | self.col_pred_embed = PosEmbedding(13, model_dim, max_len) 63 | 64 | conv_net_list = [] 65 | for _ in range(res_layer): 66 | conv_net_list.append(ResNetBlock(model_dim, model_dim, kernel, residual=True)) 67 | 68 | self.conv_net = nn.Sequential(*conv_net_list) 69 | 70 | self.generator = nn.Conv2d(model_dim, out_channels, kernel_size=1) 71 | 72 | self.initialize() 73 | 74 | def initialize(self): 75 | nn.init.kaiming_normal_(self.row_latent_linear.weight) 76 | nn.init.kaiming_normal_(self.col_latent_linear.weight) 77 | nn.init.constant_(self.row_latent_linear.bias, 0.0) 78 | nn.init.constant_(self.col_latent_linear.bias, 0.0) 79 | 80 | def forward(self, latent, src, pred, src_len): 81 | 82 | src_mask = self.make_mask(src_len) 83 | 84 | row_seq = self.row_src_embed(src) 85 | col_seq = self.col_src_embed(src) 86 | 87 | row_pred = self.row_pred_embed(pred) 88 | col_pred = self.col_pred_embed(pred) 89 | 90 | latent = self.latent_normal(latent) 91 | row_latent = self.row_latent_linear(latent) 92 | col_latent = self.col_latent_linear(latent) 93 | 94 | row_seq = row_seq.transpose(1, 2) 95 | col_seq = col_seq.transpose(1, 2) 96 | 97 | row_pred = row_pred.transpose(1, 2) 98 | col_pred = col_pred.transpose(1, 2) 99 | 100 | row_latent = row_latent.transpose(1, 2) 101 | col_latent = col_latent.transpose(1, 2) 102 | 103 | row_seq = row_seq.unsqueeze(2).repeat(1, 1, row_seq.shape[2], 1) 104 | col_seq = col_seq.unsqueeze(3).repeat(1, 1, 1, col_seq.shape[2]) 105 | 106 | row_pred = row_pred.unsqueeze(2).repeat(1, 1, row_pred.shape[2], 1) 107 | col_pred = col_pred.unsqueeze(3).repeat(1, 1, 1, col_pred.shape[2]) 108 | 109 | row_latent = row_latent.unsqueeze(2).repeat(1, 1, row_latent.shape[2], 1) 110 | col_latent = col_latent.unsqueeze(3).repeat(1, 1, 1, col_latent.shape[2]) 111 | 112 | latent = row_seq + col_seq + row_pred + col_pred + row_latent + col_latent * src_mask 113 | 114 | output_mat = self.conv_net(latent) 115 | 116 | output_mat = self.generator(output_mat) 117 | 118 | return output_mat.permute(0, 2, 3, 1), src_mask 119 | 120 | def make_mask(self, src_len): 121 | with torch.no_grad(): 122 | max_len = torch.max(src_len).item() 123 | mask = [] 124 | for l in src_len: 125 | m = torch.ones([max_len, max_len]).to(src_len.device) 126 | m = torch.triu(m, diagonal=1) 127 | m[l:, :] = 0 128 | m[:, l:] = 0 129 | mask.append(m) 130 | return torch.stack(mask, dim=0).unsqueeze(1) 131 | -------------------------------------------------------------------------------- /prob_transformer/module/optim_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class OptiMaster(): 5 | def __init__(self, model, epochs, iter_per_epoch, optimizer, scheduler, warmup_epochs, lr_low, lr_high, beta1, 6 | beta2, weight_decay): 7 | 8 | self.lr_low = lr_low 9 | self.lr_high = lr_high 10 | self.scheduler = scheduler 
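# Editor's note on the schedule design, inferred from the code below: OptiMaster
# drives two phases. For the first `warmup_epochs` epochs a LambdaLR warmup is
# stepped per iteration via train_step(); afterwards the main scheduler
# (step / linear / inv_sqrt / const / cosine) is stepped once per epoch via
# epoch_step(). Hypothetical driver loop, matching how train_transformer.py uses it:
#
#     for epoch in range(cfg.train.epochs + 1):
#         ...  # run_epoch(...) calls opti.train_step() after each optimizer step
#         opti.epoch_step(epoch - 1)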
11 | self.weight_decay = weight_decay 12 | 13 | self.iter_per_epoch = iter_per_epoch 14 | self.epochs = epochs 15 | self.model = model 16 | 17 | self.warmup_low = 1e-9 18 | self.epoch = -1 19 | 20 | init_lr = self.lr_high 21 | 22 | self.optimizer = self._get_optimizer(optimizer, model, lr=init_lr, beta1=beta1, beta2=beta2, 23 | weight_decay=weight_decay) 24 | 25 | if warmup_epochs > 0: 26 | self.warmup_steps = iter_per_epoch * (warmup_epochs) 27 | lr_func = lambda step: step / self.warmup_steps 28 | warmup_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_func) 29 | self.warmup_schedule = warmup_schedule 30 | self.warmup_epochs = warmup_epochs # + 1 31 | else: 32 | self.warmup_epochs = 0 33 | self.train_epochs = epochs - self.warmup_epochs 34 | 35 | max_train_epoch = self.train_epochs 36 | 37 | main_schedule = self._get_schedule(scheduler, max_epoch=max_train_epoch) 38 | self.main_schedule = main_schedule 39 | 40 | if warmup_epochs > 0: 41 | self.optimizer.param_groups[0]['lr'] = self.warmup_low 42 | 43 | def epoch_step(self, epoch): 44 | self.epoch = epoch # + 1 45 | 46 | if self.epoch < self.warmup_epochs - 1: 47 | pass 48 | elif self.epoch > self.warmup_epochs - 1: 49 | self.main_schedule.step() 50 | 51 | def train_step(self): 52 | if self.epoch < self.warmup_epochs - 1: 53 | self.warmup_schedule.step() 54 | 55 | @property 56 | def lr(self): 57 | return self.optimizer.param_groups[0]['lr'] 58 | 59 | def config_weight_decay(self, model): 60 | 61 | decay = set() 62 | no_decay = set() 63 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d) 64 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 65 | for mn, m in model.named_modules(): 66 | for pn, p in m.named_parameters(): 67 | fpn = '%s.%s' % (mn, pn) if mn else pn 68 | if pn.endswith('bias') or ('bias' in pn): 69 | no_decay.add(fpn) 70 | elif (pn.endswith('weight') or ('weight' in pn)) and isinstance(m, whitelist_weight_modules): 71 | decay.add(fpn) 72 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 73 | no_decay.add(fpn) 74 | elif pn.endswith('scale') or pn.endswith('key_dim_scaler'): 75 | no_decay.add(fpn) 76 | elif 'lagmul' in pn: 77 | no_decay.add(fpn) 78 | 79 | param_dict = {pn: p for pn, p in model.named_parameters()} 80 | 81 | optim_groups = [ 82 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.weight_decay}, 83 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 84 | ] 85 | return optim_groups 86 | 87 | def _get_optimizer(self, optim_name, model, lr, beta1, beta2, weight_decay): 88 | 89 | if self.weight_decay == 0 or self.weight_decay == False: 90 | params = model.parameters() 91 | else: 92 | params = self.config_weight_decay(model) 93 | 94 | if optim_name == "adam": 95 | return torch.optim.Adam(params, lr=lr, betas=(beta1, beta2), eps=1e-9, weight_decay=weight_decay) 96 | elif optim_name == "adamW": 97 | return torch.optim.AdamW(params, lr=lr, betas=(beta1, beta2), eps=1e-9, weight_decay=weight_decay) 98 | elif optim_name == "rmsprop": 99 | return torch.optim.RMSprop(params, lr=lr, alpha=0.98, momentum=0.1, eps=1e-9, weight_decay=weight_decay) 100 | 101 | def _get_schedule(self, schedule_name, max_epoch): 102 | if schedule_name == "step": 103 | train_gamma = (self.lr_low / self.lr_high) ** (1 / max_epoch) 104 | return torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=train_gamma) 105 | 106 | elif schedule_name == "linear": 107 | lr_func = lambda epoch: 
(self.lr_low / self.lr_high - 1) * epoch / max_epoch + 1 108 | return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_func) 109 | 110 | elif schedule_name == "inv_sqrt": 111 | lr_func = lambda epoch: self.warmup_steps ** 0.5 / ( 112 | (self.warmup_epochs + epoch) * self.iter_per_epoch) ** 0.5 113 | return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_func) 114 | 115 | elif schedule_name == "const": 116 | lr_func = lambda epoch: 1 117 | return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_func) 118 | 119 | elif schedule_name == "cosine": 120 | return torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, max_epoch, eta_min=self.lr_low) 121 | -------------------------------------------------------------------------------- /prob_transformer/module/probabilistic_forward.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class ProbabilisticForward(nn.Module): 7 | 8 | def __init__(self, model_dim, z_dim, last_layer=False, softplus=False, zero_init=True): 9 | super(ProbabilisticForward, self).__init__() 10 | 11 | self.last_layer = last_layer 12 | self.softplus = softplus 13 | self.zero_init = zero_init 14 | 15 | self.input_norm = nn.LayerNorm(model_dim) 16 | 17 | self.linear_z1 = nn.Linear(model_dim, z_dim) 18 | self.act_z = nn.SiLU() 19 | self.linear_z2_mean = nn.Linear(z_dim, z_dim) 20 | self.linear_z2_logvar = nn.Linear(z_dim, z_dim) 21 | 22 | if not last_layer: 23 | self.linear_out = nn.Linear(z_dim, model_dim) 24 | 25 | self.initialize() 26 | 27 | def initialize(self): 28 | 29 | nn.init.kaiming_normal_(self.linear_z1.weight) 30 | nn.init.constant_(self.linear_z2_mean.weight, 0.0) 31 | 32 | nn.init.normal_(self.linear_z2_logvar.weight, mean=0, 33 | std=0.01 * np.sqrt( 34 | 2 / (self.linear_z2_logvar.weight.shape[0] * self.linear_z2_logvar.weight.shape[1]))) 35 | 36 | nn.init.constant_(self.linear_z1.bias, 0.0) 37 | nn.init.constant_(self.linear_z2_mean.bias, 0.0) 38 | nn.init.constant_(self.linear_z2_logvar.bias, 0.0) 39 | 40 | if not self.last_layer: 41 | nn.init.constant_(self.linear_out.bias, 0.0) 42 | 43 | if self.zero_init: 44 | nn.init.constant_(self.linear_out.weight, 0.0) 45 | else: 46 | nn.init.xavier_normal_(self.linear_out.weight) 47 | 48 | def forward(self, x, p_z=None, infer_mean=False): 49 | 50 | z_raw_n = self.input_norm(x) 51 | z_raw_l = self.linear_z1(z_raw_n) 52 | z_raw = self.act_z(z_raw_l) 53 | 54 | z_mean = self.linear_z2_mean(z_raw) 55 | logvar = self.linear_z2_logvar(z_raw) 56 | 57 | z_std = torch.exp(logvar) # note: despite the name, the linear output is used as log-std here, i.e. std = exp(logvar) 58 | 59 | z = torch.distributions.Normal(z_mean, z_std) 60 | 61 | if p_z is not None: 62 | if infer_mean: 63 | z_out = p_z.mean 64 | else: 65 | z_out = p_z.rsample() 66 | else: 67 | if infer_mean: 68 | z_out = z.mean 69 | else: 70 | z_out = z.rsample() 71 | 72 | if self.last_layer: 73 | out = torch.zeros_like(x) 74 | else: 75 | out = self.linear_out(z_out) 76 | return out, z 77 | -------------------------------------------------------------------------------- /prob_transformer/module/probformer_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from prob_transformer.module.feed_forward import FeedForward 4 | from prob_transformer.module.probabilistic_forward import ProbabilisticForward 5 | from
prob_transformer.module.attention import PreNormAttention 6 | 7 | 8 | class ProbFormerBlock(nn.Module): 9 | 10 | def __init__(self, model_dim, num_head, ff_factor, z_factor, dropout, zero_init, cross_attention, 11 | probabilistic, last_layer): 12 | super().__init__() 13 | 14 | ff_dim = int(ff_factor * model_dim) 15 | z_dim = int(model_dim * z_factor) 16 | 17 | self.cross_attention = cross_attention 18 | self.probabilistic = probabilistic 19 | 20 | self.dropout = nn.Dropout(p=dropout) 21 | self.self_attn = PreNormAttention(model_dim, num_head, encoder=False, zero_init=zero_init) 22 | 23 | if cross_attention: 24 | self.coder_attn = PreNormAttention(model_dim, num_head, encoder=True, zero_init=zero_init) 25 | 26 | if probabilistic: 27 | self.prob_layer = ProbabilisticForward(model_dim, z_dim, 28 | last_layer=last_layer, 29 | softplus=False, zero_init=zero_init) 30 | self.transition = FeedForward(model_dim, ff_dim, zero_init) 31 | else: 32 | self.transition = FeedForward(model_dim, ff_dim, zero_init) 33 | 34 | def forward(self, src_act, src_mask, enc_act=None, enc_mask=None, p_z=None, infer_mean=False): 35 | 36 | src_act = src_act + self.dropout(self.self_attn(src_act, enc_act=None, mask=src_mask)) 37 | 38 | if self.cross_attention: 39 | src_act = src_act + self.dropout(self.coder_attn(src_act, enc_act=enc_act, mask=enc_mask)) 40 | 41 | if self.probabilistic: 42 | src_act = src_act + self.dropout(self.transition(src_act)) 43 | act_z, z = self.prob_layer(src_act, p_z, infer_mean) 44 | src_act = src_act + act_z 45 | return src_act, z 46 | else: 47 | src_act = src_act + self.dropout(self.transition(src_act)) 48 | return src_act 49 | -------------------------------------------------------------------------------- /prob_transformer/module/probformer_stack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from prob_transformer.module.probformer_block import ProbFormerBlock 5 | 6 | 7 | class ProbFormerStack(nn.Module): 8 | 9 | def __init__(self, n_layers, model_dim, num_head, ff_factor, z_factor, dropout, zero_init, 10 | cross_attention, posterior, prob_layer): 11 | """Builds a stack of ProbFormerBlock layers.
12 | """ 13 | super().__init__() 14 | 15 | self.posterior = posterior 16 | self.prob_layer = prob_layer 17 | 18 | module_list = [] 19 | for idx in range(n_layers): 20 | last_layer = posterior and idx == max(prob_layer) 21 | layer = ProbFormerBlock(model_dim=model_dim, num_head=num_head, ff_factor=ff_factor, z_factor=z_factor, 22 | dropout=dropout, zero_init=zero_init, 23 | cross_attention=cross_attention, probabilistic=idx in prob_layer, 24 | last_layer=last_layer) 25 | module_list.append(layer) 26 | self.layers = nn.ModuleList(module_list) 27 | 28 | def forward(self, src_act, src_mask, enc_act=None, enc_mask=None, p_z_list=None, infer_mean=False): 29 | 30 | z_list = [] 31 | p_z_index = 0 32 | mask_list = [] 33 | 34 | for idx, layer in enumerate(self.layers): 35 | if idx in self.prob_layer: 36 | if p_z_list is not None: 37 | src_act_new, z = layer(src_act, src_mask, enc_act, enc_mask, 38 | p_z=p_z_list[p_z_index], 39 | infer_mean=False) 40 | p_z_index = p_z_index + 1 41 | else: 42 | 43 | src_act_new, z = layer(src_act, src_mask, enc_act, enc_mask, infer_mean=infer_mean) 44 | z_list.append(z) 45 | mask_list.append(src_mask[:, 0, :].detach()) 46 | else: 47 | src_act_new = layer(src_act, src_mask, enc_act, enc_mask) 48 | src_act = src_act + src_act_new 49 | 50 | if len(self.prob_layer) > 0: 51 | return src_act, z_list, torch.stack(mask_list, dim=0) 52 | else: 53 | return src_act 54 | -------------------------------------------------------------------------------- /prob_transformer/routine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/routine/__init__.py -------------------------------------------------------------------------------- /prob_transformer/routine/training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.cuda.amp as amp 5 | from prob_transformer.utils.summary import SummaryDict 6 | 7 | 8 | def run_epoch(rank, cfg, data_iter, model, geco_criterion=None, opti=None): 9 | model_type = cfg.model.model_type 10 | 11 | if opti is None: 12 | model.eval() 13 | is_train = False 14 | else: 15 | model.train() 16 | is_train = True 17 | 18 | epoch_summary = SummaryDict() 19 | 20 | batch_size_list = [] 21 | seq_len_list = [] 22 | 23 | if cfg.train.amp and cfg.train.grad_scale: 24 | scaler = amp.GradScaler(init_scale=cfg.train.grad_scale, growth_factor=2.0, backoff_factor=0.5, 25 | growth_interval=2000, enabled=True) 26 | 27 | criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=-1).to(rank) 28 | 29 | with torch.set_grad_enabled(is_train): 30 | for i, batch in enumerate(data_iter): 31 | 32 | if is_train and i > cfg.train.iter_per_epoch: 33 | break 34 | 35 | if is_train: 36 | for param in model.parameters(): 37 | param.grad = None 38 | 39 | if "encoder" in model_type: 40 | batch_size_list.append(batch.src_seq.size()[0]) 41 | seq_len_list.append(batch.src_seq.size()[1]) 42 | else: 43 | batch_size_list.append(batch.trg_seq.size()[0]) 44 | seq_len_list.append(batch.trg_seq.size()[1]) 45 | 46 | with amp.autocast(enabled=cfg.train.amp): 47 | 48 | if "encoder" in model_type: 49 | if "prob" in model_type and is_train: 50 | pred_seq, z_lists = model(batch.src_seq, batch.src_len, 51 | post_trg_seq=batch.post_seq, 52 | infer_mean=False) 53 | else: 54 | pred_seq = model(batch.src_seq, batch.src_len, infer_mean=True) 55 | 56 
| elif "decoder" in model_type: # decoder only 57 | props, scaffold = None, None 58 | if cfg.data.type == "mol": 59 | if cfg.data.mol.props: 60 | props = batch.props 61 | 62 | if "prob" in model_type and is_train: 63 | pred_seq, z_lists = model(post_trg_seq=batch.post_seq, 64 | trg_shf_seq=batch.trg_shf_seq, 65 | trg_len=batch.trg_len, 66 | props=props, infer_mean=False) 67 | else: 68 | pred_seq = model(trg_shf_seq=batch.trg_shf_seq, trg_len=batch.trg_len, 69 | props=props, infer_mean=True) 70 | 71 | sequence_loss = criterion(pred_seq.contiguous().view(-1, pred_seq.size(-1)), 72 | batch.trg_seq.contiguous().view(-1)).view(batch.trg_seq.shape) 73 | sequence_loss = torch.sum(sequence_loss, dim=-1) / batch.trg_len 74 | sequence_loss = sequence_loss.mean() 75 | 76 | if is_train and "prob" in model_type: 77 | epoch_summary["pre_geco_loss"] = sequence_loss.detach() 78 | 79 | crit_set = z_lists, sequence_loss 80 | trg_set = batch.trg_seq, batch.trg_len 81 | geco_loss, summary = geco_criterion(crit_set, trg_set) 82 | 83 | epoch_summary["geco_loss"] = geco_loss.detach() 84 | loss = geco_loss 85 | else: 86 | loss = sequence_loss 87 | 88 | if opti is not None and 0 < loss.cpu().item(): 89 | 90 | if cfg.train.amp and cfg.train.grad_scale: 91 | loss = scaler.scale(loss) 92 | loss.backward() 93 | scaler.unscale_(opti.optimizer) 94 | else: 95 | loss.backward() 96 | 97 | if cfg.optim.clip_grad: 98 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.optim.clip_grad) 99 | 100 | if cfg.train.amp and cfg.train.grad_scale: 101 | scaler.step(opti.optimizer) 102 | scaler.update() 103 | else: 104 | opti.optimizer.step() 105 | 106 | opti.train_step() 107 | opti.optimizer.zero_grad() 108 | 109 | epoch_summary["loss"] = loss.detach() 110 | epoch_summary["step"] = i 111 | 112 | if "prob" in model_type and is_train: 113 | epoch_summary(summary) 114 | 115 | loss = np.mean(epoch_summary['loss']) 116 | step = np.max(epoch_summary['step']) 117 | batch = np.mean(batch_size_list) 118 | seq_len = np.mean(seq_len_list) 119 | stats = {"mean_batch_size": batch, "step": step, "mean_seq_len": seq_len, "loss": loss} 120 | 121 | if "prob" in model_type and is_train: 122 | stats['ma_ce_loss'] = np.mean(epoch_summary['ma_ce_loss']) 123 | stats['rec_constraint'] = np.mean(epoch_summary['rec_constraint']) 124 | stats['lamb'] = np.mean(epoch_summary['lamb']) 125 | 126 | return stats, epoch_summary 127 | -------------------------------------------------------------------------------- /prob_transformer/train_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import pathlib 3 | import socket 4 | import numpy as np 5 | import torch 6 | 7 | from prob_transformer.utils.supporter import Supporter 8 | from prob_transformer.utils.summary import SummaryDict 9 | from prob_transformer.utils.config_init import cinit 10 | from prob_transformer.utils.torch_utils import count_parameters 11 | 12 | from prob_transformer.module.optim_builder import OptiMaster 13 | from prob_transformer.model.probtransformer import ProbTransformer 14 | from prob_transformer.data.iterator import MyIterator 15 | from prob_transformer.data.rna_handler import RNAHandler 16 | from prob_transformer.data.ssd_handler import SSDHandler 17 | from prob_transformer.data.mol_handler import MolHandler 18 | 19 | from prob_transformer.module.geco_criterion import GECOLoss 20 | from prob_transformer.routine.evaluation import run_evaluation 21 | from prob_transformer.routine.training import run_epoch 22 | 23 
| device = 0 if torch.cuda.is_available() else "cpu" 24 | 25 | 26 | def train_prob_transformer(config): 27 | expt_dir = pathlib.Path("experiment") 28 | sup = Supporter(experiments_dir=expt_dir, config_dict=config) 29 | 30 | cfg = sup.get_config() 31 | log = sup.get_logger() 32 | ckp = sup.ckp 33 | log.print_config(cfg) 34 | 35 | rank = 0 36 | 37 | np.random.seed(cfg.train.seed + rank) 38 | torch.manual_seed(cfg.train.seed + rank) 39 | 40 | if torch.cuda.is_available(): 41 | torch.cuda.set_device(rank) 42 | torch.cuda.manual_seed(cfg.train.seed + rank) 43 | else: 44 | rank = 'cpu' 45 | 46 | torch.backends.cudnn.deterministic = False 47 | torch.backends.cudnn.benchmark = True 48 | 49 | log.log(f"rank {rank} ### START TRAINING ### at {socket.gethostname()}") 50 | 51 | ############################################################ 52 | ####### DATA ITERATOR ######## 53 | ############################################################ 54 | log.log(f"### load data", rank=rank) 55 | 56 | num_props = False 57 | 58 | if cfg.data.type == "rna": 59 | 60 | ignore_index = -1 61 | pad_index = 0 62 | 63 | train_data = cinit(RNAHandler, cfg.data.rna, sub_set='train', prob_training="prob" in cfg.model.model_type, 64 | device=rank, seed=cfg.data.seed, ignore_index=ignore_index) 65 | 66 | valid_data = cinit(RNAHandler, cfg.data.rna, sub_set='valid', prob_training=False, device=rank, 67 | seed=cfg.data.seed, ignore_index=ignore_index) 68 | 69 | test_data = cinit(RNAHandler, cfg.data.rna, sub_set='test', prob_training=False, device=rank, 70 | seed=cfg.data.seed, ignore_index=ignore_index) 71 | 72 | seq_vocab_size = train_data.seq_vocab_size 73 | trg_vocab_size = train_data.struct_vocab_size 74 | 75 | elif cfg.data.type == "ssd": 76 | 77 | ignore_index = -1 78 | pad_index = 0 79 | 80 | train_data = cinit(SSDHandler, cfg.data.ssd, sample_amount=cfg.data.ssd.sample_amount, device=rank, 81 | pre_src_vocab=None, pre_trg_vocab=None, token_dict=None) 82 | valid_data = cinit(SSDHandler, cfg.data.ssd, sample_amount=cfg.data.ssd.sample_amount // 10, device=rank, 83 | pre_src_vocab=train_data.pre_src_vocab, pre_trg_vocab=train_data.pre_trg_vocab, 84 | token_dict=train_data.token_dict, seed=cfg.data.ssd.seed+1) 85 | test_data = cinit(SSDHandler, cfg.data.ssd, sample_amount=cfg.data.ssd.sample_amount // 10, device=rank, 86 | pre_src_vocab=train_data.pre_src_vocab, pre_trg_vocab=train_data.pre_trg_vocab, 87 | token_dict=train_data.token_dict, seed=cfg.data.ssd.seed+2) 88 | 89 | seq_vocab_size = train_data.source_vocab_size 90 | trg_vocab_size = train_data.target_vocab_size 91 | 92 | 93 | elif cfg.data.type == "mol": 94 | 95 | train_data = cinit(MolHandler, cfg.data.mol, split="train", device=rank) 96 | valid_data = cinit(MolHandler, cfg.data.mol, split="valid", device=rank) 97 | test_data = cinit(MolHandler, cfg.data.mol, split="test", device=rank) 98 | 99 | ignore_index = -1 100 | pad_index = train_data.ignore_index 101 | 102 | if isinstance(cfg.data.mol.props, List): 103 | num_props = len(cfg.data.mol.props) 104 | 105 | if "decoder" in cfg.model.model_type: 106 | seq_vocab_size = train_data.target_vocab_size 107 | else: 108 | seq_vocab_size = 1 109 | trg_vocab_size = train_data.target_vocab_size 110 | log(f"trg_vocab_size: {trg_vocab_size}") 111 | 112 | else: 113 | raise UserWarning(f"data type unknown: {cfg.data.type}") 114 | 115 | log.log(f"### load iterator", rank=rank) 116 | 117 | train_iter = MyIterator(data_handler=train_data, batch_size=cfg.data.batch_size, repeat=True, shuffle=True, 118 | batching=True, 
pre_sort_samples=True, 119 | device=rank, seed=cfg.data.seed + rank, ignore_index=ignore_index, pad_index=pad_index) 120 | 121 | valid_iter = MyIterator(data_handler=valid_data, batch_size=cfg.data.batch_size, repeat=False, shuffle=False, 122 | batching=True, pre_sort_samples=False, 123 | device=rank, seed=cfg.data.seed + rank, ignore_index=ignore_index, pad_index=pad_index) 124 | 125 | test_iter = MyIterator(data_handler=test_data, batch_size=cfg.data.batch_size, repeat=False, shuffle=False, 126 | batching=False, pre_sort_samples=False, 127 | device=rank, seed=cfg.data.seed + rank, ignore_index=ignore_index, pad_index=pad_index) 128 | 129 | log.log("train_set_size", train_iter.set_size, rank=rank) 130 | log.log("valid_set_size", valid_iter.set_size, rank=rank) 131 | log.log("test_set_size", test_iter.set_size, rank=rank) 132 | 133 | log("src_vocab_len", seq_vocab_size) 134 | log("tgt_vocab_len", trg_vocab_size) 135 | 136 | ############################################################ 137 | ####### BUILD MODEL ######## 138 | ############################################################ 139 | model = cinit(ProbTransformer, cfg.model, seq_vocab_size=seq_vocab_size, trg_vocab_size=trg_vocab_size, 140 | props=num_props) 141 | 142 | log.log("model_parameters", count_parameters(model.parameters()), rank=rank) 143 | 144 | model = model.to(rank) 145 | 146 | train_summary = SummaryDict() 147 | eval_summary = SummaryDict() 148 | start_epoch = 0 149 | 150 | if "prob" in cfg.model.model_type: 151 | geco_criterion = cinit(GECOLoss, cfg.geco_criterion, model=model).to(rank) 152 | else: 153 | geco_criterion = None 154 | 155 | log.log(f"rank {rank} start GPU training ") 156 | 157 | log.log("trainable_parameters", count_parameters(model.parameters()), rank=rank) 158 | 159 | optima = cinit(OptiMaster, cfg.optim, model=model, epochs=cfg.train.epochs, iter_per_epoch=cfg.train.iter_per_epoch) 160 | 161 | ############################################################ 162 | ####### START TRAINING ######## 163 | ############################################################ 164 | log.start_timer(f"total", rank=rank) 165 | for epoch in range(start_epoch, cfg.train.epochs + 1): 166 | log.start_timer(f"epoch", rank=rank) 167 | 168 | log(f"#{rank}## START epoch {epoch} " + '#' * 36) 169 | 170 | if epoch != 0: # validate untrained model before start train 171 | log.start_timer(f"train", rank=rank) 172 | log("## Start training") 173 | stats, summary = run_epoch(rank, cfg, train_iter, model, geco_criterion=geco_criterion, opti=optima) 174 | for name, value in stats.items(): 175 | log(f'train_' + name, value, epoch, rank=rank) 176 | 177 | train_summary(summary) 178 | eval_summary["step_count"] = stats['step'] 179 | eval_summary["step"] = epoch 180 | eval_summary["train_loss"] = stats['loss'] 181 | log.timer(f"train", epoch, rank=rank) 182 | 183 | ### Update Kappa 184 | if cfg.geco_criterion.kappa_adaption and model.probabilistic: 185 | kappa = geco_criterion.kappa.data.item() 186 | if summary["rec_constraint"].mean() < 0 and summary["lamb"].mean() < 1: 187 | kappa = kappa + summary["rec_constraint"].mean() 188 | log(f"kappa_update: rc mean {summary['rec_constraint'].mean():4.3f} new_kappa {kappa:6.6f}", 189 | rank=rank) 190 | geco_criterion.kappa.data = torch.FloatTensor([kappa]).to(rank) 191 | else: 192 | eval_summary["step_count"] = 0 193 | eval_summary["step"] = epoch 194 | eval_summary["train_loss"] = 0 195 | 196 | optima.epoch_step(epoch - 1) 197 | 198 | if model.probabilistic: 199 | eval_summary["kappa"] = 
geco_criterion.kappa 200 | eval_summary['learning_rate'] = optima.lr 201 | 202 | if cfg.data.type != 'mol': 203 | log.start_timer(f"valid", rank=rank) 204 | log("## Start validation") 205 | stats, summary = run_epoch(rank, cfg, valid_iter, model) 206 | for name, value in stats.items(): log(f'valid_' + name, value, epoch, rank=rank) 207 | eval_summary["valid_loss"] = stats['loss'] 208 | log.timer(f"valid", epoch, rank=rank) 209 | 210 | if epoch % cfg.train.eval_freq == 0 and (rank == 0 or rank == 'cpu'): 211 | log.start_timer(f"eval", rank=rank) 212 | log("## Start valid evaluation") 213 | score_dict_valid = run_evaluation(cfg, valid_iter, model) 214 | for name, score in score_dict_valid.items(): 215 | log(f"{name}_valid", score, epoch, rank=rank) 216 | eval_summary[f"{name}_valid"] = score 217 | log.timer(f"eval", epoch, rank=rank) 218 | 219 | train_summary.save(ckp.dir / "train_summary.npy") 220 | eval_summary.save(ckp.dir / "eval_summary.npy") 221 | log.save_to_json(rank=rank) 222 | 223 | log.timer(f"epoch", epoch, rank=rank) 224 | if cfg.data.type != 'mol': 225 | if np.isnan(stats['loss']): # fixed: "stats['loss'] == np.nan" is always False 226 | log(f"### STOP TRAINING - loss is NaN -> {stats['loss']}", rank=rank) 227 | break 228 | 229 | if epoch % cfg.train.save_freq == 0 and epoch != 0 and rank == 0 and cfg.expt.save_model: 230 | log(f"#{rank}## Save Model - number {epoch}", rank=rank) 231 | 232 | checkpoint = {'state_dict': model.state_dict(), 'optimizer': optima.optimizer.state_dict(), 233 | "config": cfg.get_dict} 234 | torch.save(checkpoint, ckp.dir / f"checkpoint_{epoch}.pth") 235 | 236 | ########################################################### 237 | ###### TEST MODEL ######## 238 | ########################################################### 239 | if rank == 0 or rank == 'cpu': 240 | log(f"#{rank}## FINAL TEST MODEL") 241 | log.start_timer(f"test") 242 | score_dict = run_evaluation(cfg, test_iter, model) # , threshold=score_dict_valid['threshold']) 243 | 244 | for name, score in score_dict.items(): 245 | log(f"{name}_test", score, epoch) 246 | eval_summary[f"{name}_final"] = score 247 | log.timer(f"test", epoch) 248 | log.save_to_json(rank=rank) 249 | eval_summary.save(ckp.dir / "eval_summary.npy") 250 | 251 | if cfg.expt.save_model: 252 | log(f"#{rank}## Save Model - final", rank=rank) 253 | 254 | checkpoint = {'state_dict': model.state_dict(), 'optimizer': optima.optimizer.state_dict(), 255 | "config": cfg.get_dict} 256 | torch.save(checkpoint, ckp.dir / f"checkpoint_final.pth") 257 | 258 | log.timer(f"total", rank=rank) 259 | log(f"#{rank}## END TRAINING") 260 | 261 | 262 | if __name__ == "__main__": 263 | import argparse 264 | import yaml 265 | 266 | parser = argparse.ArgumentParser(description='Train the model as specified by the given configuration file.') 267 | parser.add_argument('-c', '--config', type=str, help='a configuration file') 268 | args = parser.parse_args() 269 | 270 | if args.config: 271 | config = yaml.load(open(pathlib.Path.cwd() / pathlib.Path(args.config)), Loader=yaml.Loader) 272 | else: 273 | 274 | config = { 275 | "expt": { 276 | "experiment_name": "test_training", 277 | "save_model": True, 278 | }, 279 | "train": { 280 | "eval_freq": 1, # evaluate the model every * epochs 281 | "save_freq": 1, 282 | "seed": 1, # random seed of numpy and torch 283 | "epochs": 10, # epochs to train 284 | "n_sampling": 1, # use dropout sampling during inference 285 | "iter_per_epoch": 100, # number of training iterations per epoch 286 | "amp": False, # automatic mixed precision 287 | "grad_scale": 2 ** 16, # 
initial scale for the amp.GradScaler 288 | }, 289 | "geco_criterion": { 290 | "kappa": 0.1, 291 | "kappa_adaption": True, 292 | "lagmul_rate": 0.01, 293 | "ma_decay": 0.99, 294 | }, 295 | "model": { 296 | "model_type": 'prob_encoder', 297 | # "model_type": 'prob_decoder', 298 | "model_dim": 256, # hidden dimension of transformer 299 | "max_len": 100, # Maximum length an input sequence can have. Required for positional encoding. 300 | "n_layers": 4, # number of transformer layers 301 | "num_head": 4, # number of heads per layer 302 | "ff_factor": 4, # hidden dim * ff_factor = size of feed-forward layer 303 | "z_factor": 1.0, # hidden dim * z_factor = size of prob layer 304 | "dropout": 0.1, 305 | "prob_layer": "all", # options: "middle" or "all" 306 | "zero_init": True, # init last layer per block before each residual connection 307 | }, 308 | "optim": { 309 | "optimizer": "adam", # adam adamW rmsprop 310 | "scheduler": "cosine", # cosine linear step inv_sqrt const 311 | "warmup_epochs": 1, 312 | "lr_low": 0.0001, 313 | "lr_high": 0.0005, 314 | "clip_grad": 100, 315 | "beta1": 0.9, 316 | "beta2": 0.98, 317 | "weight_decay": 1e-10, 318 | }, 319 | "data": { 320 | "type": 'rna', # rna, ssd, mol 321 | "batch_size": 500, 322 | "seed": 1, 323 | "rna": { 324 | "df_path": 'data/rna_data.plk', 325 | "df_set_name": 'train', 326 | "min_length": 20, 327 | "max_length": 100, 328 | "similarity": 80, 329 | }, 330 | "ssd": { 331 | "min_len": 15, 332 | "max_len": 30, 333 | "sample_amount": 10_000, 334 | "trg_vocab_size": 50, 335 | "src_vocab_size": 50, 336 | "sentence_len": 3, 337 | "n_sentence": 100, 338 | "sentence_variations": 10, 339 | "seed": 100, 340 | "n_eval": 10, 341 | }, 342 | "mol": { 343 | "data_dir": "data/guacamol2.csv", 344 | "min_length": 10, 345 | "max_length": 100, 346 | 'props': ["tpsa", "logp", "sas"], 347 | "gen_size": 100, 348 | "block_size": 100, 349 | "seed": 1, 350 | }, 351 | }, 352 | } 353 | train_prob_transformer(config=config) 354 | -------------------------------------------------------------------------------- /prob_transformer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .handler.config import ConfigHandler 2 | from .handler.folder import FolderHandler 3 | -------------------------------------------------------------------------------- /prob_transformer/utils/config_init.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import inspect 3 | from prob_transformer.utils.handler.config import ConfigHandler, AttributeDict 4 | 5 | 6 | def cinit(instance, config, **kwargs): 7 | """ 8 | Instantiates a class by selecting the required args from a ConfigHandler.
Omits kwargs that do not fit the signature. 9 | @param instance: class 10 | @param config: ConfigHandler object containing the class args 11 | @param kwargs: kwargs besides/replacing ConfigHandler args 12 | @return: class object 13 | """ 14 | 15 | if isinstance(instance, type): 16 | instance_args = inspect.signature(instance.__init__) 17 | instance_keys = list(instance_args.parameters.keys()) 18 | instance_keys.remove("self") 19 | else: 20 | instance_keys = inspect.getfullargspec(instance).args 21 | 22 | if isinstance(config, ConfigHandler) or isinstance(config, AttributeDict): 23 | config_dict = config.get_dict 24 | elif isinstance(config, Dict): 25 | config_dict = config 26 | elif isinstance(config, List): 27 | config_dict = {} 28 | for sub_conf in config: 29 | if isinstance(sub_conf, ConfigHandler) or isinstance(sub_conf, AttributeDict): 30 | config_dict.update(sub_conf.get_dict) 31 | elif isinstance(sub_conf, Dict): 32 | config_dict.update(sub_conf) 33 | else: 34 | raise UserWarning( 35 | f"cinit: Unknown config type. config must be Dict, AttributeDict or ConfigHandler but is {type(config)}") 36 | 37 | init_dict = {} 38 | 39 | for name, arg in kwargs.items(): 40 | if name in instance_keys: 41 | init_dict[name] = arg 42 | 43 | for name, arg in config_dict.items(): 44 | if name in instance_keys and name not in init_dict.keys(): 45 | init_dict[name] = arg 46 | 47 | init_keys = list(init_dict.keys()) 48 | missing_keys = list(set(instance_keys) - set(init_keys)) 49 | if len(missing_keys) > 0: 50 | raise UserWarning(f"cinit: keys missing {missing_keys}") 51 | 52 | return instance(**init_dict) 53 | -------------------------------------------------------------------------------- /prob_transformer/utils/handler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/ProbTransformer/b3d2669d4c290b327f19ee876dee1a0448792a2c/prob_transformer/utils/handler/__init__.py -------------------------------------------------------------------------------- /prob_transformer/utils/handler/base_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | import pathlib 4 | 5 | 6 | class Handler(): 7 | 8 | def __init__(self): 9 | pass 10 | 11 | def time_stamp(self) -> str: 12 | return datetime.utcnow().strftime('%Y-%m-%d_%H:%M:%S.%f')[:-4] 13 | 14 | def save_mkdir(self, dir): 15 | while not os.path.isdir(dir): 16 | try: 17 | os.mkdir(dir) 18 | except FileExistsError: 19 | pass 20 | 21 | def counting_name(self, dir, file_name, suffix=False): 22 | 23 | dir = pathlib.Path(dir) 24 | counter = 0 25 | split_file_name = file_name.split('.') 26 | if suffix: 27 | counting_file_name = '.'.join(split_file_name[:-1]) + f"-{counter}." + split_file_name[-1] 28 | else: 29 | counting_file_name = file_name + f"-{counter}" 30 | 31 | while os.path.isfile(dir / counting_file_name) or os.path.isdir(dir / counting_file_name): 32 | if suffix: 33 | counting_file_name = '.'.join(split_file_name[:-1]) + f"-{counter}." + split_file_name[-1] 34 | else: 35 | counting_file_name = file_name + f"-{counter}" 36 | counter += 1 37 | 38 | return counting_file_name 39 | 40 | def get_latest_name(self, dir, file_name, suffix=False): 41 | 42 | dir = pathlib.Path(dir) 43 | counter = 0 44 | split_file_name = file_name.split('.') 45 | if suffix: 46 | counting_file_name = '.'.join(split_file_name[:-1]) + f"-{counter}."
+ split_file_name[-1] 47 | else: 48 | counting_file_name = file_name + f"-{counter}" 49 | 50 | while os.path.isfile(dir / counting_file_name) or os.path.isdir(dir / counting_file_name): 51 | if suffix: 52 | counting_file_name = '.'.join(split_file_name[:-1]) + f"-{counter}." + split_file_name[-1] 53 | else: 54 | counting_file_name = file_name + f"-{counter}" 55 | counter += 1 56 | 57 | if counter == 0: 58 | return counting_file_name 59 | else: 60 | if suffix: 61 | counting_file_name = '.'.join(split_file_name[:-1]) + f"-{counter - 2}." + split_file_name[-1] 62 | else: 63 | counting_file_name = file_name + f"-{counter - 2}" 64 | return counting_file_name 65 | -------------------------------------------------------------------------------- /prob_transformer/utils/handler/checkpoint.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import torch 3 | import numpy as np 4 | import os, time 5 | 6 | """ 7 | save and restore checkpoints including parameters, rng states and env/data states 8 | """ 9 | 10 | 11 | class CheckpointHandler(): 12 | 13 | def __init__(self, checkpoint_dir, ): 14 | self.dir = pathlib.Path(checkpoint_dir) 15 | 16 | def save_training(self, mode_state_dict, optimizer_state_dict, epoch=None, loss=None, number=0): 17 | torch.save({ 18 | 'epoch': epoch, 19 | 'model_state_dict': mode_state_dict, 20 | 'optimizer_state_dict': optimizer_state_dict, 21 | 'loss': loss, 22 | }, self.dir / f"training_{number}.tar") 23 | 24 | def load_training(self, number=0): 25 | checkpoint = torch.load(self.dir / f"training_{number}.tar") 26 | mode_state_dict = checkpoint['model_state_dict'] 27 | optimizer_state_dict = checkpoint['optimizer_state_dict'] 28 | epoch = checkpoint['epoch'] 29 | loss = checkpoint['loss'] 30 | return mode_state_dict, optimizer_state_dict, epoch, loss 31 | 32 | def save_model(self, model, number=0): 33 | torch.save(model, self.dir / f"model_{number}.pth") 34 | 35 | def save_optimizer(self, optimizer, number=0): 36 | torch.save(optimizer, self.dir / f"optimizer_{number}.pth") 37 | 38 | def load_newest_optimizer(self, optimizer, map_location=None): 39 | newest_file = "" 40 | newest_age = 1e24 41 | for file in os.listdir(self.dir): 42 | if file.endswith(".pth") and "optimizer" in file.__str__(): 43 | file_stat = os.stat(self.dir / file) 44 | file_age = (time.time() - file_stat.st_mtime) 45 | if file_age < newest_age: 46 | newest_file = file 47 | newest_age = file_age 48 | print("load optimizer ", newest_file) 49 | optimizer.optimizer = torch.load(self.dir / newest_file, map_location=map_location) 50 | 51 | def load_optimizer(self, optimizer, number=0, map_location=None): 52 | optimizer.optimizer = torch.load(self.dir / f"optimizer_{number}.pth", map_location=map_location) 53 | 54 | def load_model(self, number=0, map_location=None): 55 | model = torch.load(self.dir / f"model_{number}.pth", map_location=map_location) 56 | return model 57 | 58 | def load_newest_model(self, map_location=None): 59 | newest_file = "" 60 | newest_age = 1e24 61 | for file in os.listdir(self.dir): 62 | if file.endswith(".pth") and "model" in file.__str__(): 63 | file_stat = os.stat(self.dir / file) 64 | file_age = (time.time() - file_stat.st_mtime) 65 | if file_age < newest_age: 66 | newest_file = file 67 | newest_age = file_age 68 | print("load file ", newest_file) 69 | model = torch.load(self.dir / newest_file, map_location=map_location) 70 | return model 71 | 72 | def model_exists(self): 73 | return os.path.isfile(self.dir / f"model_0.pth") 74 | 
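    # Editor's note: a hypothetical round trip with the state-dict helpers below
    # (directory and number are illustrative only):
    #
    #     ckp = CheckpointHandler("experiments/checkpoints")
    #     ckp.save_state_dict(model.state_dict(), number=3)
    #     model.load_state_dict(ckp.load_state_dict(number=3, cpu=True))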
75 | def save_state_dict(self, state_dict, number=0): 76 | torch.save(state_dict, self.dir / f"state_dict_{number}.pth") 77 | 78 | def load_state_dict(self, number=0, cpu=True): 79 | if cpu: 80 | state_dict = torch.load(self.dir / f"state_dict_{number}.pth", map_location=torch.device('cpu')) 81 | else: 82 | state_dict = torch.load(self.dir / f"state_dict_{number}.pth") 83 | return state_dict 84 | 85 | def save_object(self, object, name="object_0"): 86 | np.save(self.dir / f"{name}.npy", object, allow_pickle=True) 87 | 88 | def load_object(self, name="object_0"): 89 | return np.load(self.dir / f"{name}.npy", allow_pickle=True) 90 | -------------------------------------------------------------------------------- /prob_transformer/utils/handler/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import pathlib 4 | 5 | from prob_transformer.utils.handler.base_handler import Handler 6 | 7 | """ 8 | reads a yml config or a dict and saves it into the experiment folder 9 | """ 10 | 11 | 12 | class AttributeDict(Handler): 13 | def __init__(self, dictionary, name): 14 | super().__init__() 15 | 16 | for key in dictionary: 17 | if isinstance(dictionary[key], dict): 18 | if not hasattr(self, "sub_config"): 19 | self.sub_config = [] 20 | self.sub_config.append(key) 21 | setattr(self, key, AttributeDict(dictionary[key], key)) 22 | else: 23 | setattr(self, key, dictionary[key]) 24 | 25 | def __repr__(self): 26 | return str(self.__dict__) 27 | 28 | def __str__(self): 29 | return str(self.__dict__) 30 | 31 | @property 32 | def get_dict(self): 33 | return self.__dict__ 34 | 35 | @property 36 | def dict(self): 37 | return self.__dict__ 38 | 39 | def set_attr(self, name, value): 40 | if isinstance(value, pathlib.Path): 41 | value = value.as_posix() 42 | self.__setattr__(name, value) 43 | 44 | 45 | class ConfigHandler(AttributeDict): 46 | 47 | def __init__(self, config_file=None, config_dict=None): 48 | 49 | if config_file is None and config_dict is None: 50 | raise UserWarning("ConfigHandler: config_file and config_dict is None") 51 | 52 | elif config_file is not None and config_dict is None: 53 | with open(config_file, 'r') as f: 54 | config_dict = yaml.load(f, Loader=yaml.Loader) 55 | 56 | super().__init__(config_dict, "main") 57 | 58 | self.check_experiment_config() 59 | 60 | def check_experiment_config(self): 61 | if hasattr(self, "expt"): 62 | for attr_name in ['experiment_name']: 63 | if not hasattr(self.expt, attr_name): 64 | raise UserWarning(f"ConfigHandler: {attr_name} is missing") 65 | elif isinstance(self.expt.__getattribute__(attr_name), str): 66 | self.expt.__setattr__(attr_name, str(self.expt.__getattribute__(attr_name))) 67 | 68 | def save_config(self, dir, file_name="config.yml"): 69 | dir = pathlib.Path(dir) 70 | self.save_mkdir(dir) 71 | if os.path.isfile(dir / file_name): 72 | file_name = self.counting_name(dir, file_name, suffix=True) 73 | with open(dir / file_name, 'w+') as f: 74 | config_dict = self.get_dict 75 | yaml.dump(config_dict, f, default_flow_style=False, encoding='utf-8') 76 | return dir / file_name 77 | -------------------------------------------------------------------------------- /prob_transformer/utils/handler/folder.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from prob_transformer.utils.handler.base_handler import Handler 4 | 5
| """ 6 | Handle the location, new folders and experiments sub-folder structure. 7 | 8 | base_dir / project / session / experiment 9 | 10 | experiment will be increased 11 | 12 | """ 13 | 14 | 15 | class FolderHandler(Handler): 16 | 17 | def __init__(self, experiments_dir, session_name=None, experiment_name=None, count_expt=False, 18 | reload_expt=False): 19 | super().__init__() 20 | 21 | self.experiments_dir = pathlib.Path(experiments_dir) 22 | 23 | self.session_name = session_name 24 | self.experiment_name = experiment_name 25 | self.count_expt = count_expt 26 | self.reload_expt = reload_expt 27 | 28 | self.expt_dir = self.create_folder() 29 | 30 | def create_folder(self): 31 | 32 | dir = self.experiments_dir 33 | self.save_mkdir(dir) 34 | 35 | dir = dir / self.session_name 36 | self.save_mkdir(dir) 37 | 38 | if self.reload_expt: 39 | self.experiment_name = self.get_latest_name(dir, self.experiment_name) 40 | elif self.count_expt: 41 | self.experiment_name = self.counting_name(dir, self.experiment_name) 42 | 43 | dir = dir / self.experiment_name 44 | self.save_mkdir(dir) 45 | 46 | return dir 47 | 48 | @property 49 | def dir(self): 50 | return self.expt_dir 51 | -------------------------------------------------------------------------------- /prob_transformer/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from pathlib import Path 3 | import numpy as np 4 | import torch 5 | import time 6 | import datetime 7 | import json 8 | 9 | from prob_transformer.utils.handler.base_handler import Handler 10 | 11 | 12 | class Logger(Handler): 13 | """ 14 | call or .log() stores log massages in form (key, value, step) in json and prints them with timestamp. 15 | """ 16 | 17 | def __init__(self, log_dir, file_name="log_file.txt", json_name="json_log.json"): 18 | super().__init__() 19 | 20 | self.log_dir = Path(log_dir) 21 | self.log_file = self.log_dir / file_name 22 | self.json_file = self.log_dir / json_name 23 | 24 | self.timer_dict = {} 25 | self.line_dicts = [] 26 | 27 | self.start_log() 28 | 29 | def __call__(self, key, value=None, time_step=None, rank=0): 30 | self.log(key, value, time_step, rank) 31 | 32 | def log(self, key, value=None, time_step=None, rank=0): 33 | if rank != 0 and rank != 'cpu': 34 | return 35 | 36 | if isinstance(value, torch.Tensor): 37 | value = value.cpu().detach().numpy() 38 | 39 | if isinstance(value, np.ndarray): 40 | value = value.tolist() 41 | 42 | if isinstance(time_step, torch.Tensor): 43 | time_step = time_step.cpu().detach().numpy() 44 | 45 | time_stamp = self.time_stamp() 46 | 47 | dump_dict = {"t": time_stamp} 48 | 49 | if value is None: 50 | if time_step is None: 51 | string = f"{time_stamp} {key}" 52 | dump_dict["k"] = str(key) 53 | else: 54 | string = f"{time_stamp} {key} step:{time_step}" 55 | dump_dict["k"] = str(key) 56 | dump_dict["s"] = str(time_step) 57 | 58 | else: 59 | if isinstance(value, int): 60 | if value > 999: 61 | value = f"{value:,}" 62 | if time_step is None: 63 | string = f"{time_stamp} {key}: {value}" 64 | dump_dict["k"] = str(key) 65 | dump_dict["v"] = str(value) 66 | else: 67 | string = f"{time_stamp} {key}: {value} step:{time_step}" 68 | dump_dict["k"] = str(key) 69 | dump_dict["v"] = str(value) 70 | dump_dict["s"] = str(time_step) 71 | 72 | print(string) 73 | with open(self.log_file, 'a') as file: 74 | file.write(f"{string} \n") 75 | 76 | self.line_dicts.append(dump_dict) 77 | 78 | def start_log(self): 79 | if os.path.isfile(self.log_file) and 
--------------------------------------------------------------------------------
/prob_transformer/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from pathlib import Path
3 | import numpy as np
4 | import torch
5 | import time
6 | import datetime
7 | import json
8 | 
9 | from prob_transformer.utils.handler.base_handler import Handler
10 | 
11 | 
12 | class Logger(Handler):
13 |     """
14 |     Calling the instance or .log() stores log messages in the form (key, value, step) as JSON and prints them with a timestamp.
15 |     """
16 | 
17 |     def __init__(self, log_dir, file_name="log_file.txt", json_name="json_log.json"):
18 |         super().__init__()
19 | 
20 |         self.log_dir = Path(log_dir)
21 |         self.log_file = self.log_dir / file_name
22 |         self.json_file = self.log_dir / json_name
23 | 
24 |         self.timer_dict = {}
25 |         self.line_dicts = []
26 | 
27 |         self.start_log()
28 | 
29 |     def __call__(self, key, value=None, time_step=None, rank=0):
30 |         self.log(key, value, time_step, rank)
31 | 
32 |     def log(self, key, value=None, time_step=None, rank=0):
33 |         if rank != 0 and rank != 'cpu':
34 |             return
35 | 
36 |         if isinstance(value, torch.Tensor):
37 |             value = value.cpu().detach().numpy()
38 | 
39 |         if isinstance(value, np.ndarray):
40 |             value = value.tolist()
41 | 
42 |         if isinstance(time_step, torch.Tensor):
43 |             time_step = time_step.cpu().detach().numpy()
44 | 
45 |         time_stamp = self.time_stamp()
46 | 
47 |         dump_dict = {"t": time_stamp}
48 | 
49 |         if value is None:
50 |             if time_step is None:
51 |                 string = f"{time_stamp} {key}"
52 |                 dump_dict["k"] = str(key)
53 |             else:
54 |                 string = f"{time_stamp} {key} step:{time_step}"
55 |                 dump_dict["k"] = str(key)
56 |                 dump_dict["s"] = str(time_step)
57 | 
58 |         else:
59 |             if isinstance(value, int):
60 |                 if value > 999:
61 |                     value = f"{value:,}"
62 |             if time_step is None:
63 |                 string = f"{time_stamp} {key}: {value}"
64 |                 dump_dict["k"] = str(key)
65 |                 dump_dict["v"] = str(value)
66 |             else:
67 |                 string = f"{time_stamp} {key}: {value} step:{time_step}"
68 |                 dump_dict["k"] = str(key)
69 |                 dump_dict["v"] = str(value)
70 |                 dump_dict["s"] = str(time_step)
71 | 
72 |         print(string)
73 |         with open(self.log_file, 'a') as file:
74 |             file.write(f"{string} \n")
75 | 
76 |         self.line_dicts.append(dump_dict)
77 | 
78 |     def start_log(self):
79 |         if os.path.isfile(self.log_file) and os.access(self.log_file, os.R_OK):
80 |             self.log("LOGGER: continue logging")
81 |         else:
82 |             with open(self.log_file, 'w+') as file:
83 |                 file.write(
84 |                     f"{self.time_stamp()} LOGGER: start logging with Python version: {str(sys.version).split('(')[0]} \n")
85 | 
86 |     def print_config(self, config, name="main"):
87 |         if name == "main":
88 |             self.log("#" * 20 + " CONFIG:")
89 |         else:
90 |             self.log(f"sub config {name:8}",
91 |                      np.unique([f"{attr} : {str(value)} " for attr, value in config.get_dict.items()]).tolist())
92 | 
93 |         if hasattr(config, "sub_config"):
94 |             for cfg in config.sub_config:
95 |                 self.print_config(getattr(config, cfg), cfg)
96 | 
97 |     def start_timer(self, name, rank=0):
98 |         name = f"{name}_{str(rank)}"
99 |         self.timer_dict[name] = time.time()
100 | 
101 |     def timer(self, name, time_step=None, rank=0):
102 |         name = f"{name}_{str(rank)}"
103 |         if name not in self.timer_dict.keys():
104 |             self.log("!!!!!! UNKNOWN TIMER", name, time_step)
105 |         else:
106 |             duration = time.time() - self.timer_dict[name]
107 |             self.log(f"timer {name.split('_')[0]}", str(datetime.timedelta(seconds=duration)), time_step, rank)
108 | 
109 |     def save_to_json(self, rank=0):
110 |         if rank != 0:
111 |             return
112 |         with open(self.json_file, 'w') as file:
113 |             json.dump(self.line_dicts, file)
114 |         self.log("LOGGER: save log to json")
115 | 
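A minimal usage sketch for Logger (not part of the repository files; assumes the log directory already exists):

    from prob_transformer.utils.logger import Logger

    log = Logger("/tmp/demo_expt")
    log("train_loss", 0.42, time_step=100)  # __call__ forwards to .log()
    log.start_timer("epoch")
    # ... training work ...
    log.timer("epoch", time_step=1)         # logs the elapsed time since start_timer
    log.save_to_json()                      # dumps all collected lines to json_log.json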
9 | """ 10 | 11 | def __init__(self, summary=None): 12 | self.summary = {} 13 | if summary is not None: 14 | for key, value in summary.items(): 15 | self.summary[key] = value 16 | 17 | @property 18 | def keys(self): 19 | keys = list(self.summary.keys()) 20 | if "step" in keys: 21 | keys.remove("step") 22 | return keys 23 | 24 | def __call__(self, summary): 25 | 26 | if isinstance(summary, SummaryDict): 27 | for key, value_lists in summary.summary.items(): 28 | if key in self.summary.keys(): 29 | if key == "step": 30 | if min(value_lists) != 1 + max(self.summary['step']): 31 | value_lists = np.asarray(value_lists) + max(self.summary['step']) + 1 - min(value_lists) 32 | self.summary[key] = np.concatenate([self.summary[key], value_lists], axis=0) 33 | else: 34 | self.summary[key] = value_lists 35 | 36 | elif isinstance(summary, Dict): 37 | for name, value in summary.items(): 38 | self.__setitem__(name, value) 39 | elif isinstance(summary, (Tuple, List)): 40 | for l in summary: 41 | self.__call__(l) 42 | else: 43 | raise UserWarning(f"SummaryDict: call not implementet for type: {type(summary)}") 44 | 45 | def __setitem__(self, key, item): 46 | 47 | if isinstance(item, torch.Tensor): 48 | item = item.detach().cpu().numpy() 49 | if isinstance(item, List): 50 | if isinstance(item[0], torch.Tensor): 51 | item = [v.cpu().numpy() for v in item] 52 | if isinstance(item, np.ndarray): 53 | item = np.squeeze(item) 54 | 55 | item = np.expand_dims(np.asarray(item), axis=0) 56 | if item.shape.__len__() < 2: 57 | item = np.expand_dims(item, axis=0) 58 | 59 | if key not in self.summary.keys(): 60 | self.summary[key] = item 61 | else: 62 | self.summary[key] = np.concatenate([self.summary[key], item], axis=0) 63 | 64 | def __getitem__(self, key): 65 | return self.summary[key] 66 | 67 | def save(self, file): 68 | np.save(file, self.summary) 69 | 70 | def load(self, file): 71 | return np.load(file).tolist() 72 | -------------------------------------------------------------------------------- /prob_transformer/utils/supporter.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from prob_transformer.utils.handler.config import ConfigHandler 3 | from prob_transformer.utils.logger import Logger 4 | from prob_transformer.utils.handler.folder import FolderHandler 5 | from prob_transformer.utils.handler.checkpoint import CheckpointHandler 6 | 7 | 8 | class Supporter(): 9 | 10 | def __init__(self, experiments_dir=None, config_dir=None, config_dict=None, count_expt=True): 11 | 12 | self.cfg = ConfigHandler(config_dir, config_dict) 13 | 14 | if experiments_dir is None and self.cfg.expt.experiments_dir is None: 15 | raise UserWarning("ConfigHandler: experiment_dir and config.expt.experiment_dir is None") 16 | elif experiments_dir is not None: 17 | self.cfg.expt.set_attr("experiments_dir", experiments_dir) 18 | else: 19 | experiments_dir = pathlib.Path(self.cfg.expt.experiments_dir) 20 | 21 | session_name = f"{self.cfg.data.type}-experiments" 22 | 23 | self.folder = FolderHandler(experiments_dir, session_name, self.cfg.expt.experiment_name, count_expt) 24 | self.cfg.expt.experiment_name = self.folder.experiment_name 25 | self.cfg.expt.experiment_dir = self.folder.dir 26 | self.cfg.save_config(self.folder.dir) 27 | 28 | self.logger = Logger(self.folder.dir) 29 | self.ckp = CheckpointHandler(self.folder.dir) 30 | 31 | self.logger.log("session_name", session_name) 32 | self.logger.log("experiment_name", self.cfg.expt.experiment_name) 33 | 34 | def get_logger(self): 35 | 
--------------------------------------------------------------------------------
/prob_transformer/utils/supporter.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from prob_transformer.utils.handler.config import ConfigHandler
3 | from prob_transformer.utils.logger import Logger
4 | from prob_transformer.utils.handler.folder import FolderHandler
5 | from prob_transformer.utils.handler.checkpoint import CheckpointHandler
6 | 
7 | 
8 | class Supporter():
9 | 
10 |     def __init__(self, experiments_dir=None, config_dir=None, config_dict=None, count_expt=True):
11 | 
12 |         self.cfg = ConfigHandler(config_dir, config_dict)
13 | 
14 |         if experiments_dir is None and self.cfg.expt.experiments_dir is None:
15 |             raise UserWarning("Supporter: experiments_dir and config.expt.experiments_dir is None")
16 |         elif experiments_dir is not None:
17 |             self.cfg.expt.set_attr("experiments_dir", experiments_dir)
18 |         else:
19 |             experiments_dir = pathlib.Path(self.cfg.expt.experiments_dir)
20 | 
21 |         session_name = f"{self.cfg.data.type}-experiments"
22 | 
23 |         self.folder = FolderHandler(experiments_dir, session_name, self.cfg.expt.experiment_name, count_expt)
24 |         self.cfg.expt.experiment_name = self.folder.experiment_name
25 |         self.cfg.expt.experiment_dir = self.folder.dir
26 |         self.cfg.save_config(self.folder.dir)
27 | 
28 |         self.logger = Logger(self.folder.dir)
29 |         self.ckp = CheckpointHandler(self.folder.dir)
30 | 
31 |         self.logger.log("session_name", session_name)
32 |         self.logger.log("experiment_name", self.cfg.expt.experiment_name)
33 | 
34 |     def get_logger(self):
35 |         return self.logger
36 | 
37 |     def get_config(self):
38 |         return self.cfg
39 | 
40 |     def get_checkpoint_handler(self):
41 |         return self.ckp
42 | 
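A minimal sketch of wiring the helpers together via Supporter (not part of the repository files; the config keys shown are the ones the class reads, the values are illustrative):

    from prob_transformer.utils.supporter import Supporter

    sup = Supporter(experiments_dir="/tmp/experiments",
                    config_dict={"expt": {"experiment_name": "demo"},
                                 "data": {"type": "rna"}})
    logger = sup.get_logger()             # Logger writing into the created experiment folder
    cfg = sup.get_config()                # ConfigHandler with the (possibly counted) experiment name
    ckp = sup.get_checkpoint_handler()    # CheckpointHandler for the same folder
    logger.log("lr", 1e-4)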
checkpoint") 60 | wget.download("https://ml.informatik.uni-freiburg.de/research-artifacts/probtransformer/prob_transformer_final.pth", "RnaBench/lib/rna_folding_algorithms/DL/ProbTransformer/checkpoints/prob_transformer_final.pth") 61 | 62 | 63 | transformer_checkpoint = torch.load(model_path, map_location=torch.device(self.rank)) 64 | 65 | 66 | cfg = ConfigHandler(config_dict=transformer_checkpoint['config']) 67 | 68 | # self.rna_data = cinit(RNAHandler, cfg.data.rna.dict, df_path=rna_data, sub_set='valid', prob_training=True, 69 | # device=self.rank, seed=cfg.data.seed, ignore_index=-1, similarity='80', exclude=[], max_length=500) 70 | 71 | self.seq_vocab = ['A', 'C', 'G', 'U', 'N'] 72 | self.struct_vocab = ['.', '(0c', ')0c', '(1c', ')1c', '(2c', ')2c', '(0nc', ')0nc', '(1nc', ')1nc', '(2nc', 73 | ')2nc'] 74 | 75 | self.seq_stoi = dict(zip(self.seq_vocab, range(len(self.seq_vocab)))) 76 | self.seq_itos = dict((y, x) for x, y in self.seq_stoi.items()) 77 | 78 | for nuc, mapping in NUCS.items(): 79 | self.seq_stoi[nuc] = self.seq_stoi[mapping] 80 | 81 | self.struct_itos = dict(zip(range(len(self.struct_vocab)), self.struct_vocab)) 82 | self.struct_stoi = dict((y, x) for x, y in self.struct_itos.items()) 83 | 84 | self.seq_vocab_size = len(self.seq_vocab) 85 | self.struct_vocab_size = len(self.struct_vocab) 86 | 87 | self.model = cinit(ProbTransformer, cfg.model.dict, seq_vocab_size=self.seq_vocab_size, trg_vocab_size=self.struct_vocab_size, 88 | mat_config=None, mat_head=False, mat_input=False, prob_ff=False, 89 | scaffold=False, props=False).to(self.rank) 90 | 91 | self.model.load_state_dict(transformer_checkpoint['state_dict'], strict=False) 92 | 93 | self.cnn_head = torch.load(cnn_head_path, map_location=torch.device(self.rank)) 94 | 95 | def __name__(self): 96 | return 'ProbTransformer' 97 | 98 | def __repr__(self): 99 | return 'ProbTransformer' 100 | 101 | def __call__(self, sequence, id=0): 102 | 103 | self.cnn_head.eval() 104 | self.model.eval() 105 | 106 | if sequence is not None: 107 | # print(sequence) 108 | # print(len(sequence)) 109 | # print(f"Fold input sequence {sequence}") 110 | if sorted(set(sequence)) != ['A', 'C', 'G', 'U']: 111 | sequence = ''.join([NUCS[x] for x in sequence]) 112 | # raise UserWarning(f"unknown symbols in sequence: {set(sequence).difference('A', 'C', 'G', 'U')}. 
Please only use ACGU") 113 | 114 | src_seq = torch.LongTensor([[self.seq_stoi[s] for s in sequence]]).to(self.rank) 115 | src_len = torch.LongTensor([src_seq.shape[1]]).to(self.rank) 116 | 117 | raw_output, raw_latent = self.model(src_seq, src_len, infer_mean=True, output_latent=True) 118 | 119 | pred_dist = torch.nn.functional.one_hot(torch.argmax(raw_output, dim=-1), 120 | num_classes=raw_output.shape[-1]).to(torch.float).detach() 121 | pred_token = torch.argmax(raw_output, dim=-1).detach() 122 | 123 | b_pred_mat, mask = self.cnn_head(latent=raw_latent, src=src_seq, pred=pred_token, src_len=src_len) 124 | 125 | pred_dist = pred_dist[0, :, :].detach().cpu() 126 | pred_argmax = torch.argmax(pred_dist, keepdim=False, dim=-1).numpy().tolist() 127 | pred_struct = [self.struct_itos[i] for i in pred_argmax] 128 | # print("Predicted structure without CNN head:", pred_struct) 129 | if not is_valid_structure(pred_struct): 130 | pred_struct = correct_invalid_structure(pred_struct, pred_dist, self.struct_stoi, src_seq.shape[1]) 131 | print("correction pred_struct", pred_struct) 132 | 133 | pred_mat = torch.sigmoid(b_pred_mat[0, :, :, 1]) 134 | pred_mat = torch.triu(pred_mat, diagonal=1).t() + torch.triu(pred_mat, diagonal=1) 135 | bindings_idx = np.where(pred_mat.cpu().detach().numpy() > 0.5) 136 | # print("Predicted binding from CNN head, open :", bindings_idx[0].tolist()) 137 | # print("Predicted binding from CNN head, close:", bindings_idx[1].tolist()) 138 | # print(max(bindings_idx[0])) 139 | return [[o, c, 0] for o, c in zip(bindings_idx[0].tolist(), bindings_idx[1].tolist())] # add 0 to pairlist to be able to compute metrics 140 | 141 | 142 | 143 | 144 | 145 | 146 | # i __name__ == "__main__": 147 | # def eval_probtransformer(sequence): 148 | # parser = argparse.ArgumentParser(description='Using the ProbTransformer to fold an RNA sequence.') 149 | # parser.add_argument('-s', '--sequence', type=str, help='A RNA sequence as ACGU-string') 150 | # parser.add_argument('-m', '--model', default="checkpoints/prob_transformer_final.pth", type=str, 151 | # help='A checkpoint file for the model to use') 152 | # parser.add_argument('-c', '--cnn_head', default="checkpoints/cnn_head_final.pth", type=str, 153 | # help='A RNA sequence as ACGU-string') 154 | # parser.add_argument('-e', '--evaluate', action='store_true', help='Evaluates model on the test set TS0') 155 | # parser.add_argument('-d', '--rna_data', default="data/rna_data.plk", type=str, help='Path to rna dataframe') 156 | # parser.add_argument('-t', '--test_data', default="data/TS0.plk", type=str, help='Path to test dataframe TS0') 157 | # parser.add_argument('-r', '--rank', default="cuda", type=str, help='Device to infer the model, cuda or cpu') 158 | # args = parser.parse_args() 159 | 160 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from distutils.core import setup 4 | 5 | setup(name='prob-transformer', 6 | description="Code to NeurIPS 2022 Paper 'Probabilistic Transformer: Modelling Ambiguities and Distributions for RNA Folding and Molecule Design'", 7 | author='Joerg Franke', 8 | url='https://github.com/automl/probtransformer', 9 | version='1.0', 10 | dependency_links=[], 11 | packages = ['prob_transformer'] 12 | ) 13 | --------------------------------------------------------------------------------