├── .github └── workflows │ ├── deploy.yaml │ └── test.yaml ├── .gitignore ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── ProtMamba_ssm ├── __init__.py ├── dataloaders.py ├── fim.py ├── modules.py ├── trainer.py └── utils.py ├── README.html ├── README.md ├── README_files └── libs │ ├── bootstrap │ ├── bootstrap-icons.css │ ├── bootstrap-icons.woff │ ├── bootstrap.min.css │ └── bootstrap.min.js │ ├── clipboard │ └── clipboard.min.js │ └── quarto-html │ ├── anchor.min.js │ ├── popper.min.js │ ├── quarto-syntax-highlighting.css │ ├── quarto.js │ ├── tippy.css │ └── tippy.umd.min.js ├── configs ├── default_config.yaml └── default_config_fim-finetuning.yaml ├── data ├── cluster_testing_set.txt ├── cluster_training_set.txt ├── example_msa.a3m ├── example_pretrained_model │ ├── config.json │ ├── optimizer.pt │ ├── pytorch_model.bin │ ├── rng_state.pth │ ├── scheduler.pt │ └── trainer_state.json └── logo.jpeg ├── requirements.txt ├── settings.ini ├── setup.py ├── tests ├── generation │ ├── figures_sampled_sequences.ipynb │ ├── sample_sequences.ipynb │ ├── score_sequences.ipynb │ └── test_model.ipynb ├── proteingym │ ├── mmseqs_search.py │ ├── nb_proteingym.ipynb │ └── proteingym.py └── utils.py └── train.py /.github/workflows/deploy.yaml: -------------------------------------------------------------------------------- 1 | name: Deploy to GitHub Pages 2 | 3 | permissions: 4 | contents: write 5 | pages: write 6 | 7 | on: 8 | push: 9 | branches: [ "main", "master" ] 10 | workflow_dispatch: 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: [uses: fastai/workflows/quarto-ghp@master] 15 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: [workflow_dispatch, pull_request, push] 3 | 4 | jobs: 5 | test: 6 | runs-on: ubuntu-latest 7 | steps: [uses: fastai/workflows/nbdev-ci@master] 8 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | _docs/ 2 | _proc/ 3 | 4 | *.bak 5 | .gitattributes 6 | .last_checked 7 | .gitconfig 8 | *.bak 9 | *.log 10 | *~ 11 | ~* 12 | _tmp* 13 | tmp* 14 | tags 15 | *.pkg 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | env/ 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # dotenv 99 | .env 100 | 101 | # virtualenv 102 | .venv 103 | venv/ 104 | ENV/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | 
.mypy_cache/ 118 | 119 | .vscode 120 | *.swp 121 | 122 | # osx generated files 123 | .DS_Store 124 | .DS_Store? 125 | .Trashes 126 | ehthumbs.db 127 | Thumbs.db 128 | .idea 129 | 130 | # pytest 131 | .pytest_cache 132 | 133 | # tools/trust-doc-nbs 134 | docs_src/.last_checked 135 | 136 | # symlinks to fastai 137 | docs_src/fastai 138 | tools/fastai 139 | 140 | # link checker 141 | checklink/cookies.txt 142 | 143 | # .gitconfig is now autogenerated 144 | .gitconfig 145 | 146 | # Quarto installer 147 | .deb 148 | .pkg 149 | 150 | # Quarto 151 | .quarto 152 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Using an official Python runtime with CUDA support as a parent image (https://hub.docker.com/r/nvidia/cuda/) 2 | FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-devel as base 3 | 4 | ENV HOST docker 5 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 6 | ENV TZ Europe/Paris 7 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 8 | 9 | # TO CHANGE IF NEEDED 10 | RUN useradd -m -U -s /bin/bash user 11 | RUN apt-get update && apt-get install -y git 12 | 13 | # Set the working directory: TO CHANGE IF NEEDED 14 | USER user 15 | 16 | # Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for 17 | ENV PIP_NO_CACHE_DIR=1 18 | #install from requirements.txt 19 | COPY requirements.txt /home/user/requirements.txt 20 | 21 | #apt get git 22 | RUN pip install --no-cache-dir -r /home/user/requirements.txt 23 | 24 | WORKDIR /home/user 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022, fastai 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /ProtMamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | from .dataloaders import * 3 | from .fim import * 4 | from .modules import * 5 | from .trainer import * 6 | from .utils import * 7 | -------------------------------------------------------------------------------- /ProtMamba_ssm/dataloaders.py: -------------------------------------------------------------------------------- 1 | from .utils import AA_TO_ID 2 | from .fim import NoFIM, SingleSpanFIM, MultipleSpanFIM 3 | import pickle 4 | import torch 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | from torch.utils.data import DataLoader 8 | from dataclasses import dataclass 9 | from typing import Dict, 
Sequence 10 | 11 | # Make dataset 12 | class Uniclust30_Dataset(Dataset): 13 | """ 14 | Dataset class used to import the Uniclust30 folders. 15 | If `filename` = "encoded_MSAs.pkl", it will load the full dataset. 16 | If `filename` = "encoded_MSAs_subset.pkl", it will load a small subset of the dataset. 17 | If `sample` = True, it will sample a random number of sequences from each cluster. 18 | If `sample` = False, it will load all the sequences from each cluster (and shuffle them). 19 | To limit the length of the MSAs, set `max_msa_len` to a positive integer. 20 | If `reverse` = True, it will reverse the sequences with probability 0.5 and move the last token to the front. 21 | If `scrambling_strategy` = "no-scramble", it will not scramble the sequences and simply concatenate them. 22 | If `scrambling_strategy` = "OpenAI", it will scramble the sequences using the OpenAI strategy. 23 | If `scrambling_strategy` = "inpaint", it will scramble the sequences using the inpaint strategy. In this case it will use 24 | `max_patches` patches and mask `mask_fraction` of the patches. 
25 | """ 26 | _FIM = {"no-scramble": NoFIM, "one_span": SingleSpanFIM, "multiple_span": MultipleSpanFIM} 27 | _POSIDS = {"none", "1d", "2d"} 28 | 29 | def __init__(self, filename="encoded_MSAs_train.pkl", 30 | filepath="/nvme1/common/OpenProteinSet/", 31 | sample=False, 32 | max_msa_len=-1, 33 | reverse=False, 34 | seed=42, 35 | troubleshoot=False, 36 | fim_strategy="no-scramble", 37 | max_patches=5, 38 | mask_fraction=0.2, 39 | always_mask=False, 40 | max_position_embeddings=2048, 41 | max_seq_position_embeddings=512, 42 | add_position_ids="none", ): 43 | np.random.seed(seed) 44 | self.path = filepath 45 | # self.path_clusters = self.path + "OpenProteinSet_uniclust30-filtered/" 46 | if filename: 47 | self.dataset = pickle.load(open(self.path + filename, "rb")) 48 | self.cluster_names = list(self.dataset.keys()) 49 | else: 50 | self.dataset = None 51 | self.cluster_names = [] 52 | self.sample = sample 53 | self.max_msa_len = max_msa_len 54 | self.reverse = reverse 55 | self.fim_strategy = fim_strategy 56 | if fim_strategy in Uniclust30_Dataset._FIM: 57 | self.fim = Uniclust30_Dataset._FIM[fim_strategy](max_patches=max_patches, 58 | mask_fraction=mask_fraction, 59 | always_mask=always_mask, 60 | add_position_ids=add_position_ids != "none", 61 | troubleshoot=troubleshoot) 62 | else: 63 | raise ValueError(f'Fill in the middle stragy "{fim_strategy}" not recognized.') 64 | self.max_position_embeddings = max_position_embeddings 65 | self.max_seq_position_embeddings = max_seq_position_embeddings 66 | self.add_position_ids = add_position_ids 67 | 68 | self.troubleshoot = troubleshoot 69 | 70 | def __len__(self): 71 | return len(self.cluster_names) 72 | 73 | def __getitem__(self, idx, shuffle=True): 74 | # get all the sequences in the cluster 75 | sequences = self.get_sequences(idx) 76 | # get total number of sequences in the cluster and choose how many to sample 77 | orig_num_sequences = len(self.get_index_start_of_sequences(sequences)) 78 | num_sequences = 
np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences 79 | # sample the sequences 80 | sequences, position_ids = self.sample_sequences(sequences, num_sequences, shuffle=shuffle) 81 | # with probability 0.5, reverse the sequences and move the last token to the front 82 | sequences, position_ids = self.reverse_sequences(sequences, position_ids) if ( 83 | self.reverse and np.random.rand() > 0.5) else sequences, position_ids 84 | # limit the length of the MSA 85 | sequences = sequences[:self.max_msa_len] if self.max_msa_len > 0 else sequences 86 | if self.add_position_ids != "none": 87 | position_ids = position_ids[:self.max_msa_len] if self.max_msa_len > 0 else position_ids 88 | # convert to tensor 89 | sequences = torch.asarray(sequences, dtype=torch.int64) 90 | position_ids = torch.asarray(position_ids, dtype=torch.int64).clamp(0, 91 | self.max_position_embeddings - 1) if self.add_position_ids!="none" else None 92 | 93 | if self.troubleshoot: 94 | print( 95 | f"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. 
Total MSA length: {len(sequences)}") 96 | if self.add_position_ids == "1d": 97 | return dict(input_ids=sequences, position_ids=position_ids, labels=sequences) 98 | if self.add_position_ids == "2d": 99 | seq_position_ids = (sequences == AA_TO_ID[""]).int().cumsum(-1).clamp(0, 100 | self.max_seq_position_embeddings - 1).contiguous() 101 | return dict(input_ids=sequences, position_ids=position_ids, seq_position_ids=seq_position_ids, 102 | labels=sequences) 103 | return dict(input_ids=sequences, labels=sequences) 104 | 105 | def get_sequences(self, idx): 106 | """Get the sequences in the cluster with index `idx`.""" 107 | cluster_name = self.cluster_names[idx] 108 | sequences = self.dataset[cluster_name] 109 | return sequences 110 | 111 | def get_index_start_of_sequences(self, sequences): 112 | """Get the positions of the start of each sequence in the cluster.""" 113 | return np.where(sequences == 0)[0] 114 | 115 | def reverse_sequences(self, sequence, position_ids=None): 116 | """Reverse the sequences and move the last token to the front.""" 117 | sequence = sequence[::-1] 118 | if position_ids is not None: 119 | position_ids = position_ids[::-1] 120 | return np.concatenate([sequence[-1:], sequence[:-1]]), np.concatenate( 121 | [position_ids[-1:], position_ids[:-1]]) if position_ids is not None else None 122 | 123 | def sample_sequences(self, sequences, num_sequences, shuffle=True, which_seqs=None): 124 | """Sample `num_sequences` from the sequences in the cluster.""" 125 | L = len(sequences) 126 | # get the indexes of the start of each sequence 127 | inds = self.get_index_start_of_sequences(sequences) 128 | # check that there are sequences in the cluster and that there are enough of them 129 | assert len(inds) > 0, "No sequences found in cluster." 130 | assert len(inds) >= num_sequences, "Not enough sequences in cluster." 
131 | # sample n_sequences randomly from the sequences 132 | if which_seqs is None: 133 | if shuffle: 134 | which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False) 135 | else: 136 | which_seqs = np.arange(len(inds))[-num_sequences:] 137 | # get the tuples of start and end indexes of the sequences 138 | tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs] 139 | if self.troubleshoot: 140 | print(f"Sampled sequences: {tuples}") 141 | # concatenate the sequences 142 | sequences, position_ids = self.fim.apply(sequences, tuples) 143 | return sequences, position_ids 144 | 145 | 146 | def make_dataloader(dataset): 147 | """Basic function to make a dataloader. 148 | """ 149 | dataloader = DataLoader(dataset) 150 | 151 | @dataclass 152 | class DataCollatorForUniclust30Dataset(object): 153 | """ 154 | Collate examples into a batch, and pad batch to the maximum sequence length. 155 | """ 156 | 157 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 158 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "input_ids")) 159 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=AA_TO_ID[""]) 160 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) 161 | if "seq_position_ids" in instances[0] and "position_ids" in instances[0]: 162 | position_ids = torch.nn.utils.rnn.pad_sequence( 163 | [instance["position_ids"] for instance in instances], 164 | batch_first=True, padding_value=0) 165 | seq_position_ids = torch.nn.utils.rnn.pad_sequence( 166 | [instance["seq_position_ids"] for instance in instances], 167 | batch_first=True, padding_value=0) 168 | return dict( 169 | input_ids=input_ids, 170 | labels=labels, 171 | position_ids=position_ids, 172 | seq_position_ids=seq_position_ids, 173 | attention_mask=input_ids.ne(AA_TO_ID[""]), 174 | ) 175 | 176 | if "position_ids" in instances[0]: 
177 | position_ids = torch.nn.utils.rnn.pad_sequence( 178 | [instance["position_ids"] for instance in instances], 179 | batch_first=True, padding_value=0) 180 | return dict( 181 | input_ids=input_ids, 182 | labels=labels, 183 | position_ids=position_ids, 184 | attention_mask=input_ids.ne(AA_TO_ID[""]), 185 | ) 186 | 187 | return dict( 188 | input_ids=input_ids, 189 | labels=labels, 190 | attention_mask=input_ids.ne(AA_TO_ID[""]), 191 | ) 192 | -------------------------------------------------------------------------------- /ProtMamba_ssm/fim.py: -------------------------------------------------------------------------------- 1 | from .utils import MASK_TO_ID, AA_TO_ID 2 | import numpy as np 3 | 4 | class AbstractFIM(object): 5 | def __init__(self, 6 | max_patches=5, 7 | mask_fraction=0.2, 8 | always_mask=False, 9 | mask_tokens=MASK_TO_ID, 10 | eos_token=AA_TO_ID[""], 11 | add_position_ids=False, 12 | troubleshoot=False): 13 | """ 14 | This class is designed to concatenate sequences based on different scrambling strategies. 15 | It takes a list of sequences, tuples indicating the start and end indices of each sequence, 16 | an optional number of patches to sample, and a scrambling strategy as inputs. 17 | """ 18 | self.troubleshoot = troubleshoot 19 | self.max_patches = max_patches 20 | self.mask_fraction = mask_fraction 21 | self.mask_tokens = mask_tokens 22 | assert len( 23 | self.mask_tokens) >= self.max_patches, "Number of mask tokens must be bigger than max number of patches." 24 | self.eos_token = eos_token 25 | self.add_position_ids = add_position_ids 26 | self.always_mask = always_mask 27 | 28 | def apply(self, sequences, tuples): 29 | """ 30 | This function concatenates the sequences scrambling each one according to the scrambling strategy. 
31 | """ 32 | input_ids, position_ids = [], [] 33 | for t in tuples: 34 | seq, pos = self.fim(sequences, t) 35 | input_ids.extend(seq) 36 | if self.add_position_ids: 37 | position_ids.extend(pos) 38 | if self.add_position_ids: 39 | return input_ids, position_ids 40 | return input_ids, None 41 | 42 | def fim(self, sequences, t): 43 | """ 44 | This function concatenates the sequence's parts based on the scrambling strategy. 45 | """ 46 | raise NotImplementedError 47 | 48 | 49 | class NoFIM(AbstractFIM): 50 | def __init__(self, 51 | max_patches=5, 52 | mask_fraction=0.2, 53 | always_mask=False, 54 | mask_tokens=MASK_TO_ID, 55 | eos_token=AA_TO_ID[""], 56 | add_position_ids=False, 57 | troubleshoot=False): 58 | super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot) 59 | 60 | def fim(self, sequences, t): 61 | """ 62 | This function keeps the sequence identical without any scrambling. 63 | """ 64 | if self.add_position_ids: 65 | position_ids = np.arange(t[0], t[1]) - t[0] 66 | return sequences[t[0]:t[1]], position_ids 67 | return sequences[t[0]:t[1]], None 68 | 69 | 70 | class SingleSpanFIM(AbstractFIM): 71 | 72 | def __init__(self, 73 | max_patches=5, 74 | mask_fraction=0.2, 75 | always_mask=False, 76 | mask_tokens=MASK_TO_ID, 77 | eos_token=AA_TO_ID[""], 78 | add_position_ids=False, 79 | troubleshoot=False): 80 | super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot) 81 | 82 | def fim(self, sequences, t): 83 | """ 84 | This function creates and concatenates parts of the sequences based on the OpenAI scrambling strategy. 
85 | It randomly selects two indices within the range of the given tuple, 86 | splits the sequence into three parts based on these indices, and then concatenates them with the 87 | masked patch at the end 88 | """ 89 | new_tuple = tuple(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]), 2, replace=False))) 90 | part1 = sequences[t[0]:new_tuple[0]] 91 | part2 = sequences[new_tuple[0]:new_tuple[1]] 92 | part3 = sequences[new_tuple[1]:t[1]] 93 | sequence = np.concatenate([part1, [self.mask_tokens[""]], part3, [self.mask_tokens[""]], part2]) 94 | position_ids_sequence = None 95 | if self.add_position_ids: 96 | position_ids = np.arange(t[0], t[1]) - t[0] 97 | position_ids_part1 = position_ids[t[0]:new_tuple[0]] 98 | position_ids_part2 = position_ids[new_tuple[0]:new_tuple[1]] 99 | position_ids_part3 = position_ids[new_tuple[1]:t[1]] 100 | position_ids_sequence = np.concatenate( 101 | [position_ids_part1, [position_ids_part2[0]], position_ids_part3, [position_ids_part2[0]], 102 | position_ids_part2]) 103 | 104 | return sequence, position_ids_sequence 105 | 106 | 107 | class MultipleSpanFIM(AbstractFIM): 108 | def __init__(self, 109 | max_patches=5, 110 | mask_fraction=0.2, 111 | always_mask=False, 112 | mask_tokens=MASK_TO_ID, 113 | eos_token=AA_TO_ID[""], 114 | add_position_ids=False, 115 | troubleshoot=False): 116 | super().__init__(max_patches, mask_fraction, always_mask, mask_tokens, eos_token, add_position_ids, troubleshoot) 117 | 118 | def fim(self, sequences, t): 119 | """ 120 | This function creates and concatenates parts of the sequences based on the inpaint scrambling strategy. 121 | It randomly selects `2*num_patches` indices within the range of the given tuple, 122 | splits the sequence into unmasked and masked parts based on these indices, and then concatenates them. 123 | The number of patches is sampled from a poisson distribution with upper limit `self.max_patches` and average 1. 
124 | The concatenation is done by joining all unmaksed parts (interleaved with mask tokens) and afterwards 125 | all masked parts (interleaved with mask tokens). At the end of the unmasked parts, a special token is added 126 | to indicate the end of the unmasked parts, and at the end of the masked parts, a special token is added 127 | to indicate the end of the masked parts. 128 | """ 129 | # sample num_patches from a discrete poisson distribution with upper limit L 130 | def sample_lengths(start, end): 131 | """ 132 | Sample a length uniformly from 1 to max_L*self.mask_fraction (must be bigger than 1). 133 | If the length is larger than max_L, return max_L. 134 | """ 135 | max_L = end - start 136 | length = np.random.randint(1, max(int(max_L * self.mask_fraction), 2)) 137 | return min(length, max_L) 138 | 139 | # sample num_patches from a discrete poisson distribution with upper limit max_patches 140 | num_patches = 1000 141 | while num_patches > self.max_patches: 142 | num_patches = np.random.poisson(1) 143 | if self.always_mask: 144 | num_patches = max(num_patches, 1) 145 | # sample num_patches starting points for the masked positions (+ final position) 146 | start_patches = list(np.sort(np.random.choice(np.arange(t[0] + 1, t[1]), 147 | num_patches, 148 | replace=False))) + [t[1]] 149 | # sample num_patches lengths of the patches 150 | len_patches = [sample_lengths(start_patches[i], start_patches[i + 1]) 151 | for i in range(len(start_patches) - 1)] 152 | # create masked tuples with start and end indices of the patches 153 | masked_tuples = [(start_patches[i], start_patches[i] + len_patches[i]) for i in range(len(start_patches) - 1)] 154 | # split the sequences into unmasked and masked parts 155 | unmasked_sequence, masked_sequence, unmasked_position_ids, masked_position_ids = self.split_sequences(sequences, 156 | t, 157 | masked_tuples) 158 | 159 | if self.troubleshoot: 160 | print(f"For sequence in {t}: sampled {num_patches=}, {start_patches=}, 
{len_patches=}, {masked_tuples=}") 161 | # concatenate the unmasked and masked parts 162 | return unmasked_sequence + masked_sequence, unmasked_position_ids + masked_position_ids if self.add_position_ids else None 163 | 164 | def split_sequences(self, sequences, t, masked_tuples): 165 | """ 166 | This function splits the sequences into unmasked and masked parts based on the given tuples. 167 | Args: 168 | t (tuple): The start and end index of each sequence. 169 | masked_tuples (list): A list of tuples specifying the indices for masked regions. 170 | Returns: 171 | unmasked_parts (list): The unmasked parts of the sequences interleaved with mask_tokens. 172 | masked_parts (list): The masked parts of the sequences interleaved with mask_tokens. 173 | """ 174 | unmasked_parts, masked_parts = [], [] 175 | unmasked_positions, masked_positions = [], [] 176 | position_ids = None 177 | start, end = t 178 | if self.add_position_ids: 179 | position_ids = np.arange(start, end) - start 180 | for i, region in enumerate(masked_tuples): 181 | mask_token = self.mask_tokens[f""] 182 | unmasked_parts.extend(sequences[start:region[0]]) 183 | unmasked_parts.append(mask_token) 184 | masked_parts.append(mask_token) 185 | masked_parts.extend(sequences[region[0]:region[1]]) 186 | if self.add_position_ids: 187 | unmasked_positions.extend(position_ids[start-t[0]:region[0]-t[0]]) 188 | unmasked_positions.append(position_ids[region[0]-t[0]]) 189 | masked_positions.append(position_ids[region[0]-t[0]]) 190 | masked_positions.extend(position_ids[region[0]-t[0]:region[1]-t[0]]) 191 | 192 | start = region[1] 193 | unmasked_parts.extend(sequences[start:end]) 194 | if self.add_position_ids: 195 | unmasked_positions.extend(position_ids[start-t[0]:end-t[0]]) 196 | if len(masked_tuples) > 0: 197 | unmasked_parts.append(self.eos_token) 198 | if self.add_position_ids: 199 | unmasked_positions.append(0) 200 | return unmasked_parts, masked_parts, unmasked_positions, masked_positions 201 | 
# --------------------------------------------------------------------------------
# /ProtMamba_ssm/trainer.py
# --------------------------------------------------------------------------------
from transformers import Trainer, TrainerCallback
from .utils import *
import re
import torch
import os


class MambaTrainer(Trainer):
    """HuggingFace ``Trainer`` subclass used for training ProtMamba models.

    Adapted from https://github.com/havenhq/mamba-chat/blob/main/trainer/mamba_trainer.py

    Args:
        compute_only_fim_loss (bool): if True, the cross-entropy loss is computed
            only on the fill-in-the-middle (FIM) spans (plus the cls tokens)
            instead of on every position of the sequence.
    """
    def __init__(self, compute_only_fim_loss, **kwargs):
        super().__init__(**kwargs)
        self.compute_only_fim_loss = compute_only_fim_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        """Next-token cross-entropy loss, optionally restricted to FIM spans.

        The labels are the input ids shifted by one position (causal LM).
        """
        input_ids = inputs.pop("input_ids")
        # Forward pass with whichever positional information the batch provides.
        if "seq_position_ids" in inputs and "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            seq_position_ids = inputs.pop("seq_position_ids")
            output = model(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids)
        elif "position_ids" in inputs:
            position_ids = inputs.pop("position_ids")
            output = model(input_ids, position_ids=position_ids)
        else:
            output = model(input_ids)
        lm_logits = output.logits

        # Shift so that logits at position t predict the token at position t+1.
        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        if self.compute_only_fim_loss:
            # Restrict the loss to FIM spans delimited by the cls/eos tokens.
            # NOTE(review): the special-token names in AA_TO_ID appear stripped to
            # empty strings by the text extraction ("<cls>"-style names lost) --
            # confirm the keys against the original repository.
            is_cls_tokens = (labels == AA_TO_ID[""])
            is_eos_tokens = (labels == AA_TO_ID[""])
            bool_fim = find_fim_indices(is_cls_tokens, is_eos_tokens)
            # include also the cls token
            bool_fim = bool_fim | is_cls_tokens
            if bool_fim.any():
                inds = torch.where(bool_fim)
                lm_loss = loss_fct(shift_logits[inds[0], inds[1], :], labels[bool_fim])
            else:
                # Fix: a batch with no FIM span used to produce a NaN loss
                # (cross-entropy over zero elements); fall back to the full loss.
                lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        else:
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return (lm_loss, output) if return_outputs else lm_loss

    def save_model(self, output_dir, _internal_call=False):
        """Save only the model weights (``save_pretrained``) to ``output_dir``.

        ``_internal_call`` mirrors the base-class signature and is unused; it now
        has a default so the method can also be called directly.
        """
        # exist_ok avoids a race between the existence check and the creation.
        os.makedirs(output_dir, exist_ok=True)
        self.model.save_pretrained(output_dir)


PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")


def get_last_checkpoint(folder, max_steps=None):
    """Return the path of the latest ``checkpoint-<step>`` directory in ``folder``
    whose step is strictly below ``max_steps`` (no limit if ``max_steps`` is None).

    Returns:
        str | None: full path of the eligible checkpoint, or None when no
        checkpoint directory qualifies.
    """
    content = os.listdir(folder)
    checkpoints = [
        path
        for path in content
        if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return

    max_steps = max_steps if max_steps is not None else float("inf")

    def step_of(name):
        return int(_re_checkpoint.search(name).groups()[0])

    # Fix: previously every checkpoint at or above max_steps was mapped to -1 and
    # max() still returned one of them (an arbitrary, too-new checkpoint) when no
    # checkpoint was below the limit; now None is returned in that case.
    eligible = [c for c in checkpoints if step_of(c) < max_steps]
    if not eligible:
        return None
    return os.path.join(folder, max(eligible, key=step_of))


class EarlyStoppingCallback(TrainerCallback):
    """Stop training when the monitored metric degrades for ``patience``
    consecutive evaluations, and record which checkpoint to restart from.

    Args:
        train_path (str): directory containing the ``checkpoint-*`` folders.
        config (dict): must provide "patience", "early_stopping_metric",
            "eval_steps" and "loss_increase_factor".
    """
    def __init__(self, train_path, config=None):
        self.step_counter = 0
        self.best_loss = None
        self.train_path = train_path
        self.patience = config["patience"]
        self.metric_name = config["early_stopping_metric"]
        self.checkpoint_path = None
        self.should_restart = False
        self.eval_steps = config["eval_steps"]
        self.loss_increase_factor = config["loss_increase_factor"]

    def get_checkpoint_path(self, max_steps):
        """Latest checkpoint strictly before ``max_steps`` (None if absent)."""
        last_checkpoint = None
        if os.path.exists(self.train_path):
            last_checkpoint = get_last_checkpoint(self.train_path, max_steps)
            if last_checkpoint is None:
                print("No checkpoint found, starting training from scratch.")
            else:
                print(f"Max checkpoint allowed: {max_steps}, restarting from {last_checkpoint}.")
        return last_checkpoint

    def on_evaluate(self, args, state, control, model, metrics, **kwargs):
        """Track the metric; after ``patience`` bad evaluations request a stop."""
        if self.metric_name in metrics:
            if self.best_loss is None:
                self.best_loss = metrics[self.metric_name]
            elif self.best_loss * self.loss_increase_factor < metrics[self.metric_name]:
                # Metric degraded beyond the tolerated factor.
                self.step_counter += 1
                if self.step_counter >= self.patience:
                    # Roll back to a checkpoint taken before the degradation began.
                    checkpoint_path = self.get_checkpoint_path(
                        max_steps=(state.global_step - self.patience * self.eval_steps))
                    control.should_training_stop = True
                    self.checkpoint_path = checkpoint_path
                    self.should_restart = True
            else:
                self.step_counter = 0
                self.best_loss = min(self.best_loss, metrics[self.metric_name])
                self.should_restart = False

    def on_train_begin(self, args, state, control, **kwargs):
        """Reset the early-stopping state at the start of (re)training."""
        self.step_counter = 0
        self.best_loss = None
        self.should_restart = False
# --------------------------------------------------------------------------------
# /ProtMamba_ssm/utils.py
# --------------------------------------------------------------------------------
__all__ = ['AA_TO_ID', 'MASK_TO_ID', 'ID_TO_AA', 'encode_sequence', 'decode_sequence', 'clean_sequence', 'tokenizer',
           'reorder_masked_sequence', 'load_from_file', 'generate_sequence', 'prepare_dataset_for_fim_generation',
           'prepare_tokens', 'prepare_target', 'load_tensorboard_data', 'filter_datapoints', 'save_to_tensorboard',
           'merge_loggings', 'concatenate_loggings', 'print_number_of_parameters', 'find_fim_indices',
           'compute_metrics']

# Constants: amino-acid / special-token vocabulary -> integer ids.
# NOTE(review): the keys for ids 0-3, 31-32 and all mask tokens appear as empty
# strings, most likely because angle-bracketed names such as "<cls>" were
# stripped during text extraction; as written the duplicate '' keys collapse in
# the dict literal. Restore the original token names from the upstream repo.
AA_TO_ID = {'': 0,
            '': 1,
            '': 2,
            '': 3,
            'L': 4,
            'A': 5,
            'G': 6,
            'V': 7,
            'S': 8,
            'E': 9,
            'R': 10,
            'T': 11,
            'I': 12,
            'D': 13,
            'P': 14,
            'K': 15,
            'Q': 16,
            'N': 17,
            'F': 18,
            'Y': 19,
            'M': 20,
            'H': 21,
            'W': 22,
            'C': 23,
            'X': 24,
            'B': 25,
            'U': 26,
            'Z': 27,
            'O': 28,
            '.': 29,
            '-': 30,
            '': 31,
            '': 32}

# Mask-token ids used by the fill-in-the-middle (FIM) objective.
MASK_TO_ID = {"": 33,
              "": 34,
              "": 35,
              "": 36,
              "": 37,}

AA_TO_ID.update(MASK_TO_ID)

# Inverse vocabulary for decoding model outputs back to characters.
ID_TO_AA = {v: k for k, v in AA_TO_ID.items()}
import numpy as np 53 | import torch 54 | from Bio import SeqIO 55 | 56 | # Encoder 57 | def encode_sequence(sequence): 58 | """Tokenize a sequence of amino acids and add a cls token at the beginning.""" 59 | tokenized_sequence = [AA_TO_ID[aa] if aa in AA_TO_ID else AA_TO_ID[''] for aa in sequence] 60 | return [AA_TO_ID['']] + tokenized_sequence 61 | 62 | def decode_sequence(sequence): 63 | """Decode a sequence of tokens.""" 64 | return "".join([ID_TO_AA[token] if token in ID_TO_AA else "" for token in sequence]) 65 | 66 | def clean_sequence(sequence): 67 | """Remove gaps and convert all residues to upper case.""" 68 | return sequence.replace("-", "").upper() 69 | 70 | def tokenizer(sequence_list, concatenate=True): 71 | """Tokenize a collection of sequences. If the sequences are aligned, the gaps will be removed 72 | and the insertions (lower case) will be promoted to upper case.""" 73 | # clean and encode all sequences 74 | sequence_list = [encode_sequence(clean_sequence(sequence)) for sequence in sequence_list] 75 | if concatenate: 76 | # concatenate all sequences 77 | sequences = np.concatenate(sequence_list) 78 | # convert to tensor and add batch dimension 79 | return torch.asarray(sequences, dtype=torch.int64)[None,:] 80 | else: 81 | return [torch.asarray(sequence, dtype=torch.int64) for sequence in sequence_list] 82 | 83 | 84 | def reorder_masked_sequence(mask_seq, return_ids=False): 85 | """ 86 | Reorder a masked sequence to fill the masked positions with the tokens 87 | that should be there but are positioned after the token. 
88 | """ 89 | mask_seq = mask_seq.split("")[0] 90 | try: 91 | # Split the sequence and masks 92 | seq, masks = mask_seq.split("") 93 | except: 94 | return mask_seq 95 | full_seq = "" 96 | ids_mask = [] 97 | # Iterate over each mask tag 98 | for mm in ["", "", "", "", "",""]: 99 | try: 100 | # Split the sequence in before and after the mask tag 101 | seq1, seq2 = seq.split(mm) 102 | if mm=="": 103 | # If the mask is the first one, add the sequence before the mask and update the masks 104 | masks = masks.split("")[1] 105 | full_seq += seq1 106 | else: 107 | # If the mask is not the first one, insert the mask between the two sequence parts 108 | masks1, masks2 = masks.split(mm) 109 | ids_mask += [(len(full_seq), len(full_seq)+len(masks1))] 110 | full_seq += masks1 + seq1 111 | # Update the masks 112 | masks = masks2 113 | # Update the sequence with the part after the mask 114 | seq = seq2 115 | except: 116 | # If the mask is not found, add the remaining sequence 117 | ids_mask += [(len(full_seq), len(full_seq)+len(masks))] 118 | full_seq += masks + seq 119 | break 120 | if return_ids: 121 | return full_seq, ids_mask 122 | return full_seq 123 | 124 | def load_from_file(file_path): 125 | """Load a collection of sequences from an a3m file.""" 126 | with open(file_path, "r") as f: 127 | sequences = [str(record.seq) for record in SeqIO.parse(f, "fasta")] 128 | return sequences 129 | 130 | def generate_sequence(model, tokens, position_ids=None, seq_position_ids=None, is_fim=False, max_length=1, 131 | temperature=1., top_p=0.0, top_k=1, 132 | return_dict_in_generate=False, output_scores=False, teacher_outputs=None, 133 | eos_token_id=AA_TO_ID[""], 134 | device="cuda"): 135 | """Generating, either greedy or with top-k or top-p sampling. 136 | If top-k = 0, don't limit the number of candidates (pure sampling). 137 | Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, 138 | then top-p. 
We assume that all sequences in the same batch have the same length. 139 | """ 140 | input_ids = tokens.to(device) 141 | position_ids = position_ids.to(device) if position_ids is not None else None 142 | seq_position_ids = seq_position_ids.to(device) if seq_position_ids is not None else None 143 | teacher_outputs = teacher_outputs.to(device) if teacher_outputs is not None else None 144 | # generate sequence 145 | out = model.generate(input_ids=input_ids, 146 | position_ids=position_ids.clamp(0, 2000), 147 | seq_position_ids=seq_position_ids, 148 | is_fim=is_fim, 149 | max_length=max_length, 150 | temperature=temperature, 151 | top_p=top_p, 152 | top_k=top_k, 153 | return_dict_in_generate=return_dict_in_generate, 154 | output_scores=output_scores, 155 | eos_token_id=eos_token_id, 156 | teacher_outputs=teacher_outputs) 157 | sequences = out.sequences 158 | dic = {"input": [decode_sequence(seq) for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()], 159 | "generated": [decode_sequence(seq) for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()], 160 | "input_tokens": [seq for seq in sequences[:, :input_ids.shape[-1]].cpu().numpy()], 161 | "generated_tokens": [seq for seq in sequences[:, input_ids.shape[-1]:].cpu().numpy()]} 162 | if output_scores: 163 | dic["scores"] = np.array([el.to(torch.float32).cpu().numpy() for el in out.scores]).transpose(1, 0, 2) 164 | return dic 165 | 166 | 167 | def prepare_dataset_for_fim_generation(tokens, pos_ids): 168 | """ 169 | Function to transform the tokenized training dataset into a format that can be used for FIM generation. 170 | Splits the input tokens and pos_ids into the FIM part (of the last sequence) and the context part (all 171 | the previous sequences and the masked part of the last sequence). 172 | Also returns a dictionary with the positions of the mask tokens in the FIM part. 
173 | """ 174 | def find_mask_positions(tokens_fim): 175 | """ 176 | Function to find the positions of the mask tokens in the FIM part of the last sequence. 177 | """ 178 | bool_mask = None 179 | inds_masks = [] 180 | for ind in MASK_TO_ID.values(): 181 | tmp_bool = tokens_fim[0].cpu().numpy() == ind 182 | bool_mask = tmp_bool if bool_mask is None else bool_mask | tmp_bool 183 | inds_masks += [ind] 184 | return bool_mask, inds_masks 185 | # find where the FIM part of the last sequence starts 186 | start_last_fim = np.where(tokens[0].cpu().numpy() == AA_TO_ID[""])[0][-1] 187 | start_next_seqs = np.where(tokens[0,start_last_fim+1:].cpu().numpy() == AA_TO_ID[""])[0] 188 | end_last_fim = start_last_fim+ 1 +start_next_seqs[0] if len(start_next_seqs) > 0 else tokens.shape[1] 189 | # split tokens and pos_ids into FIM part and context part 190 | tokens_to_fim = tokens[:,:start_last_fim+1] 191 | pos_ids_to_fim = pos_ids[:,:start_last_fim+1] 192 | tokens_fim = tokens[:,start_last_fim+1:end_last_fim] 193 | pos_ids_fim = pos_ids[:,start_last_fim+1:end_last_fim] 194 | # find positions of mask tokens 195 | bool_mask, inds_masks = find_mask_positions(tokens_fim) 196 | masked_positions = pos_ids_fim[0,bool_mask] 197 | mask_dict = {ind: int(pos) for ind, pos in zip(inds_masks, masked_positions)} 198 | return tokens_to_fim, pos_ids_to_fim, tokens_fim, pos_ids_fim, mask_dict 199 | 200 | def prepare_tokens(context_tokens, 201 | target_tokens, 202 | target_pos_ids, 203 | DatasetClass, 204 | num_sequences=1, 205 | fim_strategy="no-scramble", # "multiple_span" 206 | mask_fraction=0.2, 207 | max_patches=5, 208 | add_position_ids="1d", 209 | shuffle=True): 210 | """Prepare the tokens for the model by applying the FIM strategy and masking the tokens. 
211 | It uses custom tokenized sequences and position ids.""" 212 | 213 | data_class = DatasetClass(None, 214 | fim_strategy=fim_strategy, 215 | mask_fraction=mask_fraction, 216 | max_patches=max_patches, 217 | add_position_ids=add_position_ids) 218 | if len(context_tokens) >= 1: 219 | seq, pos_ids = data_class.sample_sequences(context_tokens.numpy()[0], num_sequences=num_sequences, shuffle=shuffle) 220 | # convert to tensor and add batch dimension 221 | seq = torch.asarray(seq, dtype=torch.int64)[None,:] 222 | pos_ids = torch.asarray(pos_ids, dtype=torch.int64)[None,:] 223 | seq = torch.cat([seq, target_tokens], dim=1) 224 | pos_ids = torch.cat([pos_ids, target_pos_ids], dim=1) 225 | return seq, pos_ids 226 | else: 227 | return target_tokens, target_pos_ids 228 | 229 | def prepare_target(target, use_fim=None): 230 | """Prepare the target sequence for the model using a custom tokenized sequence. 231 | use_fim is a dictionary with the positions that should be masked. 232 | use_fim = {"": 1} _-> start to generate autoregressively from 1st position 233 | use_fim = {"": 10} -> start to generate autoregressively from 10th position 234 | use_fim = {"": ((10,13), 6)} -> mask positions from 10 to 13 (i.e. 
10,11,12) and fill it with 6 tokens, 235 | use_fim = {"": ((10,13), 6), "": ((15,20), 2)} -> mask positions from 10 to 13 and 15 to 20 and fill it with 6 and 2 tokens 236 | """ 237 | if "" in use_fim: 238 | target = target[:,:use_fim[""]] 239 | pos_ids = torch.arange(target.shape[1], dtype=torch.int64)[None,:] 240 | assert target.shape[1] == pos_ids.shape[1] 241 | return target, pos_ids 242 | else: 243 | is_fim_dict = {} 244 | pos_ids = torch.arange(target.shape[1], dtype=torch.int64)[None,:] # default position ids 245 | diff_length = 0 246 | for mask in use_fim: 247 | assert "mask" in mask 248 | mask_positions, length = use_fim[mask] 249 | # update mask_positions to take into account the inserted parts 250 | mask_positions = (mask_positions[0]-diff_length, mask_positions[1]-diff_length) 251 | diff_length = mask_positions[1] - mask_positions[0] 252 | new_target = torch.cat([target[:,:mask_positions[0]], 253 | torch.full((target.shape[0], 1), AA_TO_ID[mask], dtype=torch.int64), 254 | target[:,mask_positions[1]:]], dim=1) 255 | new_pos_ids = torch.cat([pos_ids[:,:mask_positions[0]+1], 256 | pos_ids[:,mask_positions[1]:]+length-diff_length], dim=1) 257 | is_fim_dict[AA_TO_ID[mask]] = pos_ids[:,mask_positions[0]].squeeze().item() 258 | target = new_target 259 | pos_ids = new_pos_ids 260 | diff_length -= 1 261 | 262 | new_target = torch.cat([target, 263 | torch.full((target.shape[0], 1), AA_TO_ID[""], dtype=torch.int64)], dim=1) 264 | new_pos_ids = torch.cat([pos_ids, 265 | torch.full((target.shape[0], 1), 0, dtype=torch.int64)], dim=1) 266 | assert new_target.shape[1] == new_pos_ids.shape[1] 267 | return new_target, new_pos_ids, is_fim_dict 268 | 269 | from tensorboard.backend.event_processing import event_accumulator 270 | from tensorboard.backend.event_processing.event_accumulator import ScalarEvent 271 | from torch.utils.tensorboard import SummaryWriter 272 | import glob 273 | import matplotlib.pyplot as plt 274 | 275 | def load_tensorboard_data(path): 276 | """ 277 
| Load the TensorBoard data 278 | """ 279 | ea = event_accumulator.EventAccumulator(path) 280 | ea.Reload() 281 | # Assuming you're interested in scalar summaries; adjust if otherwise 282 | # list all tags 283 | tags = ea.Tags()['scalars'] 284 | 285 | scalars = {tag: ea.scalars.Items(tag) for tag in tags} 286 | return scalars 287 | 288 | def filter_datapoints(scalars, condition): 289 | """ 290 | Filter out the datapoint you want to delete (customize this logic) 291 | """ 292 | # Example condition: lambda x: x.step != step_to_delete 293 | return {tag: [s for s in scalars if not condition(s)] for tag, scalars in scalars.items()} 294 | 295 | def save_to_tensorboard(filtered_data, output_path): 296 | """ 297 | Save modified data back to a new TensorBoard file 298 | """ 299 | writer = SummaryWriter(output_path) 300 | for tag, scalars in filtered_data.items(): 301 | for data in scalars: 302 | writer.add_scalar(tag, data.value, global_step=data.step, walltime=data.wall_time) 303 | 304 | def merge_loggings(directory, output_path, plot_metric=None): 305 | """ 306 | Merge all the TensorBoard files in a directory into a single file. 307 | Keeps only the metrics with latest wall time for each step (i.e. 
the last logged value for each step) 308 | """ 309 | # Find all the TensorBoard files 310 | def find_files(directory, pattern): 311 | return glob.glob(f"{directory}/{pattern}") 312 | all_paths = find_files(directory, "events.out.tfevents.*") 313 | # Merge all the data 314 | best_wall_times = {} 315 | updated_metrics = {} 316 | for elem in all_paths: 317 | try: 318 | # Load the data from one logging 319 | scalars = load_tensorboard_data(elem) 320 | # Make a dictionary with step number as key 321 | all_metrics = {k: {s.step: s for s in scalars[k]} for k in scalars.keys()} 322 | if plot_metric is not None: 323 | plt.plot([s.wall_time for s in scalars[plot_metric]], [s.value for s in scalars[plot_metric]]) 324 | # iterate over all the metrics 325 | for k in all_metrics.keys(): 326 | if k not in updated_metrics.keys(): 327 | updated_metrics[k] = {} 328 | if k not in best_wall_times: 329 | best_wall_times[k] = {} 330 | # Get wall time of each step 331 | steps_time = {step: s.wall_time for step,s in all_metrics[k].items()} 332 | # iterate over steps and pick only the metrics associated with the best wall time 333 | for key, value in steps_time.items(): 334 | if key not in updated_metrics[k]: 335 | best_wall_times[k][key] = value 336 | updated_metrics[k][key] = all_metrics[k][key] 337 | elif value > best_wall_times[k][key]: 338 | best_wall_times[k][key] = value 339 | updated_metrics[k][key] = all_metrics[k][key] 340 | else: 341 | continue 342 | except: 343 | print("Could not load.\t", elem.split("/")[-1]) 344 | # Sort the metrics by step 345 | new_logging = {k: list(updated_metrics[k].values()) for k in updated_metrics.keys()} 346 | new_logging = {k: sorted(v, key=lambda x: x.step) for k, v in new_logging.items()} 347 | # Save the merged data 348 | save_to_tensorboard(new_logging, output_path) 349 | plt.title(f"Metric: {plot_metric}") 350 | plt.show() 351 | return new_logging 352 | 353 | def concatenate_loggings(logging1_path, logging2_path, step_range1, step_range2, 
output_path): 354 | """ 355 | Concatenate the two loggings, assuming they have the same metrics. Use steps from step_range1[0] to step_range1[1] 356 | for logging1 and from step_range2[0] to step_range2[1] for logging2. 357 | Change the step numbers of logging2 to be continuous with logging1. and verify that the steps taken in each logging 358 | are the ones specified by step_range1 and step_range2. 359 | """ 360 | logging1 = load_tensorboard_data(logging1_path) 361 | logging2 = load_tensorboard_data(logging2_path) 362 | if step_range1 is None: 363 | k = list(logging1.keys())[0] 364 | step_range1 = (logging1[k][0].step, logging1[k][-1].step) 365 | if step_range2 is None: 366 | k = list(logging2.keys())[0] 367 | step_range2 = (logging2[k][0].step, logging2[k][-1].step) 368 | new_logging = {} 369 | for key in logging1.keys(): 370 | new_logging[key] = [el for el in logging1[key] if el.step >= step_range1[0] and el.step < step_range1[1]] 371 | for el in logging2[key]: 372 | if el.step >= step_range2[0] and el.step <= step_range2[1]: 373 | # not possible to assign to the step attribute of the object, make a new ScalarEvent object identical 374 | # to el but with the step attribute changed 375 | new_step_value = el.step - step_range2[0] + step_range1[1] 376 | new_el = ScalarEvent(step=new_step_value, wall_time=el.wall_time, value=el.value) 377 | 378 | new_logging[key].append(new_el) 379 | new_logging[key] = sorted(new_logging[key], key=lambda x: x.step) 380 | save_to_tensorboard(new_logging, output_path) 381 | return new_logging 382 | 383 | # %% ../nbs/99_utils.ipynb 6 384 | def print_number_of_parameters(model): 385 | print("Number of trainable parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad)) 386 | 387 | def find_fim_indices(is_cls_tokens, is_eos_tokens): 388 | """Function to find the indices of the FIM tokens in the sequences. 
389 | """ 390 | # add a cls token at the beginning 391 | is_cls_tokens = torch.cat([torch.ones_like(is_cls_tokens[:, :1]), is_cls_tokens], dim=1) 392 | is_eos_tokens = torch.cat([torch.zeros_like(is_eos_tokens[:, :1]), is_eos_tokens], dim=1) 393 | # both eos and cls tokens 394 | bol = is_cls_tokens | is_eos_tokens 395 | tmp = torch.zeros_like(is_cls_tokens, dtype=torch.int) 396 | tmp[torch.nonzero(is_cls_tokens, as_tuple=True)] = 1 397 | tmp[torch.nonzero(is_eos_tokens, as_tuple=True)] = -1 398 | bol1 = torch.clone(bol) 399 | for batch_ind in range(tmp.size(0)): 400 | tmp1 = tmp[batch_ind,bol[batch_ind]] 401 | # find all positions where a 1 if preceeded by a -1 402 | tmp1 = tmp1[:-1]*tmp1[1:] 403 | # add the first element to make the sequence start with a 1 404 | tmp1 = torch.cat([torch.ones_like(tmp1[:1]).to(tmp1.device), tmp1]) 405 | new_bol = tmp1<0 406 | # bool array True only in the positions where a 1 is preceeded by a -1 407 | bol1[batch_ind,bol[batch_ind]] = False if new_bol.size(0) == 0 else new_bol 408 | cumulative_sum = torch.cumsum(bol1, dim=1) 409 | # Use modulo operation to get the desired tensor 410 | bol2 = cumulative_sum % 2 == 1 411 | bol2[is_eos_tokens]= False 412 | return bol2[:,1:] 413 | 414 | def compute_metrics(eval_pred): 415 | predictions, labels = eval_pred 416 | predictions = torch.tensor(predictions).permute(0, 2, 1) 417 | labels = torch.tensor(labels) 418 | # shift labels to align them with predictions and remove last prediction to match the length 419 | predictions = predictions[:, :, :-1].contiguous() 420 | labels = labels[:, 1:].contiguous() 421 | # compute unreduced elementwise loss 422 | unreduced_loss = torch.nn.functional.cross_entropy(predictions, labels, reduction="none") 423 | # compute reconstruction accuracy 424 | reconstruction = (predictions.argmax(1) == labels) 425 | 426 | # start and end tokens 427 | is_cls_tokens = (labels == AA_TO_ID[""]) 428 | is_eos_tokens = (labels == AA_TO_ID[""]) 429 | # fill in the middle tokens 
430 | if False: 431 | fim_tokens = torch.zeros(is_cls_tokens.size(0), is_cls_tokens.size(1), dtype=torch.bool) 432 | in_mask_vector = torch.zeros(is_cls_tokens.size(0), dtype=torch.bool) 433 | for j in range(is_cls_tokens.size(1)): 434 | in_mask_vector = in_mask_vector & ~is_cls_tokens[:, j] 435 | fim_tokens[:, j] = in_mask_vector 436 | in_mask_vector = in_mask_vector | is_eos_tokens[:, j] 437 | fim_tokens = find_fim_indices(is_cls_tokens, is_eos_tokens) 438 | 439 | number_sequences = torch.cumsum(torch.cat([torch.zeros(is_cls_tokens.size(0),1, dtype=torch.int32), is_cls_tokens[:,:-1]],1), -1) 440 | # fist, second and last sequence tokens 441 | first_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 0) 442 | second_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == 1) 443 | last_sequence_tokens = ((~fim_tokens & (labels < 33)) | fim_tokens) & (number_sequences == (number_sequences.max(1).values[:, None] - 1)) 444 | # end of mask tokens 445 | end_of_masks = (fim_tokens & (labels > 33)) | is_cls_tokens | is_eos_tokens 446 | return {"loss/all": torch.mean(unreduced_loss).item(), 447 | "loss/end_span": torch.mean(unreduced_loss[end_of_masks]).item(), 448 | "perplexity/seq": torch.mean(torch.exp(torch.mean(unreduced_loss, dim=1))).item(), 449 | "perplexity/end_span": torch.exp(torch.mean(unreduced_loss[end_of_masks])).item(), 450 | "perplexity/batch": torch.exp(torch.mean(unreduced_loss)).item(), 451 | "perplexity/first_seq": torch.exp(torch.mean(unreduced_loss[first_sequence_tokens])).item(), 452 | "perplexity/second_seq": torch.exp(torch.mean(unreduced_loss[second_sequence_tokens])).item(), 453 | "perplexity/last_seq": torch.exp(torch.mean(unreduced_loss[last_sequence_tokens])).item(), 454 | "perplexity/fim": torch.exp(torch.mean(unreduced_loss[fim_tokens])).item(), 455 | "reconstruction/all": torch.mean(reconstruction.float()).item(), 456 | "reconstruction/end_span": 
torch.mean(reconstruction[end_of_masks].float()).item(), 457 | "reconstruction/first_seq": torch.mean(reconstruction[first_sequence_tokens].float()).item(), 458 | "reconstruction/second_seq": torch.mean(reconstruction[second_sequence_tokens].float()).item(), 459 | "reconstruction/last_seq": torch.mean(reconstruction[last_sequence_tokens].float()).item(), 460 | "reconstruction/fim": torch.mean(reconstruction[fim_tokens].float()).item(),} 461 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProtMamba 2 | 3 | 4 | 5 | **A Homology-Aware but Alignment-Free Protein State Space Model**: 6 | [ProtMamba](https://www.biorxiv.org/content/early/2024/05/25/2024.05.24.595730) 7 | is a novel protein language model designed to facilitate protein design. 8 | Unlike traditional models that rely on multiple sequence alignments 9 | (MSAs) to capture evolutionary information, ProtMamba can use unaligned 10 | homologous sequences, avoiding the imperfections associated with MSAs. 11 | 12 | 13 | 14 | ### Features 15 | 16 | ProtMamba is based on the Mamba architecture, a state space model that 17 | efficiently handles very long sequences. The model uses a 18 | fill-in-the-middle (FIM) training objective, combining autoregressive 19 | modeling and masked language modeling to predict amino acids conditioned 20 | on the given context sequences. This makes ProtMamba particularly 21 | well-suited for generating novel protein sequences, filling in specific 22 | regions of sequences, and predicting the fitness of protein variants. 23 | 24 | - **Homology-Aware but Alignment-Free**: Captures evolutionary 25 | information without relying on MSAs and use it to condition the 26 | generation process. 27 | - **Efficient Long-Context handling**: Uses Mamba blocks to handle long 28 | sequences with linear memory scaling. 
29 | - **Different training objective**: Combines autoregressive and masked 30 | language modeling through a fill-in-the-middle objective. 31 | - **Sequence-level positional embeddings**: Enhances the model’s ability 32 | to reason about in-sequence dependencies and allows for precise 33 | inpainting. 34 | 35 | ### Applications 36 | 37 | - **Sequence Generation**: Generate novel protein sequences from scratch 38 | conditioned on specific homologs. 39 | - **Sequence Inpainting**: Fill in specific masked regions within a 40 | sequence for targeted protein design. 41 | - **Fitness Prediction**: Predict the probability distribution of 42 | mutations to assess the functional impact of variants. 43 | 44 | ### Repository Structure 45 | 46 | `configs/`: Configuration files for model training and evaluation. 47 | 48 | `data/`: Example dataset and full train/test split used to train the model (list of OpenProteinSet cluster names). 49 | 50 | `nbs/`: Implementation of the ProtMamba model architecture in Jupyter 51 | notebooks. 52 | 53 | `ProtMamba_ssm/`: Implementation of the ProtMamba model architecture. 54 | 55 | `tests/`: Scripts to sample from ProtMamba and evaluate the model’s 56 | performance. 57 | 58 | ### Model weights 59 | 60 | The model weights are available in the latest release of the repository 61 | [here](https://github.com/Bitbol-Lab/ProtMamba-ssm/releases/tag/v1.0) 62 | 63 | ## Install Repository 64 | 65 | ``` sh 66 | pip install -e . 67 | ``` 68 | 69 | ## How to tokenize all your MSAs to make a training dataset 70 | 71 | IMPORTANT: the sequences should be in `a3m` files but they do not need 72 | to be aligned.
73 | 74 | ``` python 75 | from ProtMamba_ssm.dataloaders import * 76 | from ProtMamba_ssm.utils import * 77 | from ProtMamba_ssm.modules import * 78 | from train import run 79 | import torch 80 | from matplotlib import pyplot as plt 81 | import numpy as np 82 | ``` 83 | 84 | ``` python 85 | import pickle 86 | msa_paths = {"name-of-msa" : "../data/example_msa.a3m"} 87 | # path saving directory 88 | filepath = input("What is the path to the folder where you want to save the dataset?") # example: "../data/" 89 | dataset_name = input("How do you want to name the dataset file?") # example: "encoded_MSAs_train.pkl" 90 | 91 | dataset_dictionary = {} 92 | for msa_name, msa_path in msa_paths.items(): 93 | # Load an a3m file with all the context sequences 94 | msa = load_from_file(msa_path) 95 | # Tokenize the sequences and concatenate them into a single array 96 | tokens = tokenizer(msa, concatenate=True) 97 | tokens = tokens.numpy()[0] 98 | dataset_dictionary[msa_name] = tokens 99 | 100 | with open(filepath+dataset_name, "wb") as f: 101 | pickle.dump(dataset_dictionary, f) 102 | ``` 103 | 104 | ## How to train 105 | 106 | ``` python 107 | import yaml 108 | 109 | if int(use_one_gpu) >= 0: 110 | print(f"Using gpu {use_one_gpu}") 111 | print("Number of gpus used: ", torch.cuda.device_count()) 112 | 113 | # Load the default config file (change it and add the path to the training dataset) 114 | with open("../configs/default_config.yaml", "r") as file: 115 | defaultconfig = yaml.safe_load(file) 116 | 117 | namedir = input("Enter name of directory to save results: ") 118 | finetune_path = input("If you want to finetune a model, enter the relative path to the model's checkpoint, otherwise press enter:") 119 | finetune_path = finetune_path if finetune_path else None 120 | 121 | # Run the trainer with the selected training configuration 122 | trainer = run(defaultconfig, namedir, finetune_model_path=finetune_path) 123 | ``` 124 | 125 | ## How to sample from a pretrained model 126 | 
127 | ``` python 128 | use_custom = input("Do you want to use a custom MSA? (y/n): ") 129 | if use_custom == "y": 130 | # Load an a3m file with all the context sequences 131 | msa = load_from_file("../data/example_msa.a3m") 132 | target_sequence = msa[:1] 133 | context_msa = msa[1:] 134 | # Tokenize the sequences and concatenate them into a single array 135 | target = tokenizer(target_sequence, concatenate=True) 136 | tokens = tokenizer(context_msa, concatenate=True) 137 | fim_gen = input("Do you want to generate using FIM? (y/n): ") 138 | if fim_gen=="n": 139 | # AUTOREGRESSIVE, no-FIM generation 140 | # generate the full sequence autoregressively starting from residue 10 in the sequence `target` 141 | gen_dictionary = {"<cls>": 10} 142 | input_seq, targ_pos = prepare_target(target, use_fim={"<cls>": 10}) 143 | if fim_gen=="y": 144 | # FIM generation 145 | # mask_dictionary is a dictionary of the positions in the sequence that you want to mask, in this example there will be 146 | # - a mask that covers the residues 4,5,6 and the model will fill it by sampling 10 residues 147 | # - a mask that covers the residues 30,31,32,33,34 and the model will fill it by sampling 3 residues 148 | mask_dictionary = {"<mask-1>": ((4,7),10),"<mask-2>": ((30,35),3)} 149 | input_seq, targ_pos, is_fim_dict = prepare_target(target, use_fim=mask_dictionary) 150 | context_tokens, context_pos_ids = prepare_tokens(tokens, 151 | target_tokens=input_seq, 152 | target_pos_ids=targ_pos, 153 | DatasetClass=Uniclust30_Dataset, 154 | num_sequences=50, 155 | fim_strategy="multiple_span", 156 | mask_fraction=0.2, 157 | max_patches=5, 158 | add_position_ids="1d") 159 | if use_custom == "n": 160 | is_fim = input("Do you want to use FIM? 
(y/n): ") 161 | filepath = input("What is the path to the folder with the dataset?") 162 | is_fim = True if is_fim == "y" else False 163 | # Load the dataset used for training 164 | dataset_name = input("What is the name of the dataset file?") # example: "encoded_MSAs_subset-100.pkl", "encoded_MSAs_train.pkl" 165 | fim_strategy = "multiple_span" if is_fim else "no-scramble" 166 | dataset = Uniclust30_Dataset(filename=dataset_name, 167 | filepath=filepath, 168 | sample=False, 169 | mask_fraction=0.2, 170 | fim_strategy=fim_strategy, 171 | max_position_embeddings=2048, 172 | add_position_ids="1d") 173 | # Select a sample of the dataset to be the input 174 | data = dataset[1] 175 | tokens = data["input_ids"][None,:].to("cuda") 176 | pos_ids = data["position_ids"][None,:].to("cuda") 177 | 178 | model_name = input("What is the path to the folder with the checkpoint of the model?") # example: "results/train_100M_FIM_restart-spikes_merged/checkpoint_131k-750" 179 | # Load pretrained model 180 | model = load_model(model_name, 181 | model_class=MambaLMHeadModelwithPosids, 182 | device="cuda", 183 | dtype=torch.bfloat16, 184 | checkpoint_mixer=False # Must be False when using model for Inference 185 | ) 186 | ``` 187 | 188 | ### Generate new sequences starting from a custom MSA 189 | 190 | ``` python 191 | print("Number of tokens in the MSA (target, context): ", len(input_seq[0]), ",", len(tokens[0])) 192 | print("Target:") 193 | print("sequence:", decode_sequence(input_seq[0].numpy())) 194 | print("original sequence:", decode_sequence(target[0].numpy())) 195 | print("pos ids:", list(targ_pos[0].numpy())) 196 | print("Context:") 197 | print("sequence:", decode_sequence(context_tokens[0].numpy())) 198 | print("pos ids:", list(context_pos_ids[0].numpy())) 199 | print("Mask positions:", is_fim_dict) 200 | ``` 201 | 202 | ``` python 203 | # Generate the new sequence 204 | output = generate_sequence(model, 205 | context_tokens, 206 | position_ids=context_pos_ids, 207 | 
is_fim=is_fim_dict, 208 | max_length=20000, 209 | temperature=1., 210 | top_k=3, 211 | top_p=0.0, 212 | return_dict_in_generate=True, 213 | output_scores=True, 214 | eos_token_id=AA_TO_ID["<cls>"], 215 | device="cuda") 216 | 217 | input_seq, output_seq = output["input"], output["generated"] 218 | logits = output["scores"] 219 | print(f"All input context (len = {len(input_seq[0])}):\n", input_seq[0]) 220 | print("Last sequence where the masked parts should be predicted:\n", input_seq[0].split("<cls>")[-1]) 221 | print(f"Generated (len = {len(output_seq[0])}):\n", output_seq[0]) 222 | input_continuation = decode_sequence(target[0].numpy())+"<cls>" 223 | print(f"Continuation of input:\n", input_continuation) 224 | print(f"\nLogits (shape = {logits.shape})") 225 | ``` 226 | 227 | ### Generate new sequences using the Dataset and by filling in the masked parts (FIM) 228 | 229 | ``` python 230 | context_tokens, context_pos_ids, tokens_fim, pos_ids_fim, is_fim_dict = prepare_dataset_for_fim_generation(tokens, pos_ids) 231 | print("Length of context information", context_tokens.shape[1]) 232 | print("Masked part of the sequence to predict and its position indices:\n", tokens_fim, "\n", pos_ids_fim) 233 | print("Masked tokens and their positions in the input sequence:", is_fim_dict) 234 | ``` 235 | 236 | ``` python 237 | # Generate the new sequence 238 | output = generate_sequence(model, 239 | context_tokens, 240 | position_ids=context_pos_ids, 241 | is_fim=is_fim_dict, 242 | max_length=1570, 243 | temperature=1., 244 | top_k=3, 245 | top_p=0.0, 246 | return_dict_in_generate=True, 247 | output_scores=True, 248 | eos_token_id=AA_TO_ID["<cls>"], 249 | device="cuda") 250 | 251 | input_seq, output_seq = output["input"], output["generated"] 252 | logits = output["scores"] 253 | print(f"All input context (len = {len(input_seq[0])}):\n", input_seq[0]) 254 | print("Last sequence where the masked parts should be predicted:\n", input_seq[0].split("<cls>")[-1]) 255 | print(f"Generated (len = 
{len(output_seq[0])}):\n", output_seq[0]) 256 | input_continuation = decode_sequence(tokens_fim[0].cpu().numpy())+"<cls>" 257 | print(f"Continuation of input:\n", input_continuation) 258 | print(f"\nLogits (shape = {logits.shape})") 259 | ``` 260 | 261 | ``` python 262 | total = len(tokens_fim[0])+1 263 | if total>4: 264 | fig, axs = plt.subplots(total//4,4, figsize=(20,5*total//4)) 265 | else: 266 | fig, axs = plt.subplots(1,total, figsize=(5*total,5)) 267 | axs = axs[None,:] 268 | for el in range(total): 269 | ax = axs[el//4,el%4] 270 | ax.bar(np.arange(logits.shape[-1]), 271 | torch.softmax(torch.from_numpy(logits[0,el,:]), dim=0)) 272 | ax.axvline(output["generated_tokens"][0][el], color="red", label="Prediction: "+ID_TO_AA[output["generated_tokens"][0][el]] + f" ({output['generated_tokens'][0][el]})", linewidth=0.5) 273 | # ax.axvline(tokens_fim[0,el].cpu().numpy(), color="k",label="Original: "+ID_TO_AA[tokens_fim[0,el].cpu().numpy()] +f" ({tokens_fim[0,el].cpu().numpy()})", linewidth=0.5) 274 | ax.legend() 275 | fig.suptitle(f"Real sequence: {input_continuation}\nPred sequence: {output_seq[0]}") 276 | plt.show() 277 | ``` 278 | 279 | ### Generate new sequences using the Dataset and by sampling amino acids autoregressively from `<cls>` 280 | 281 | ``` python 282 | # Generate the new sequence 283 | L = 650#628# 284 | output = generate_sequence(model, 285 | tokens[:,:L], 286 | position_ids=pos_ids[:,:L], 287 | is_fim=False, 288 | max_length=1570, 289 | temperature=1., 290 | top_k=10, 291 | top_p=0.0, 292 | return_dict_in_generate=True, 293 | output_scores=True, 294 | eos_token_id=torch.tensor([AA_TO_ID["<cls>"],AA_TO_ID["<mask-1>"], AA_TO_ID["<mask-2>"], AA_TO_ID["<mask-3>"], 295 | AA_TO_ID["<mask-4>"], AA_TO_ID["<mask-5>"]]).to("cuda"), 296 | device="cuda") 297 | 298 | input_seq, output_seq = output["input"], output["generated"] 299 | logits = output["scores"] 300 | 301 | print(f"All input context (len = {len(input_seq[0])}):\n", input_seq[0]) 302 | print("Last sequence where the masked parts should be 
predicted:\n", input_seq[0].split("<cls>")[-1]) 303 | print(f"Generated (len = {len(output_seq[0])}):\n", output_seq[0]) 304 | input_continuation = decode_sequence(tokens[0,L:].cpu().numpy()).split("<cls>")[0]+"<cls>" 305 | print(f"Continuation of input:\n", input_continuation) 306 | print(f"\nLogits (shape = {logits.shape})") 307 | print("\nStops if the model predicts a mask token") 308 | ``` 309 | 310 | ``` python 311 | fig, axs = plt.subplots(4,4, figsize=(20,10)) 312 | for el in range(16): 313 | ax = axs[el//4,el%4] 314 | ax.bar(np.arange(logits.shape[-1]), 315 | torch.softmax(torch.from_numpy(logits[0,el,:]), dim=0)) 316 | ax.axvline(output["generated_tokens"][0][el], color="red", label="Prediction: "+output_seq[0][el] + f" ({AA_TO_ID[output_seq[0][el]]})", linewidth=0.5) 317 | ax.axvline(tokens[0,L+el].cpu().numpy(), color="k",label="Original: "+input_continuation[el] +f" ({AA_TO_ID[input_continuation[el]]})", linewidth=0.5) 318 | ax.legend() 319 | fig.suptitle(f"Real sequence: {input_continuation[:16]}\nPred sequence: {output_seq[0][:16]}") 320 | plt.show() 321 | ``` 322 | 323 | ### Get the hidden states at each layer 324 | 325 | ``` python 326 | which_layers = list(range(1,model.config.n_layer+1)) 327 | 328 | def get_hidden_states(model, tokens, which_layers, position_ids=None, seq_position_ids=None): 329 | hidden_states = model(input_ids=tokens, 330 | save_layer=which_layers, 331 | position_ids=position_ids, 332 | seq_position_ids=seq_position_ids) 333 | return hidden_states 334 | 335 | hidden_states = get_hidden_states(model, tokens[:,:10], which_layers, pos_ids[:,:10]) 336 | print(f"Saved hidden states for layers: {hidden_states.keys()}") 337 | print(f"Shape of hidden state of layer 1: {hidden_states[1].shape}") 338 | ``` 339 | 340 | ## Citation 341 | 342 | To cite this work, please refer to the following publication: 343 | 344 | ``` bibtex 345 | @article{sgarbossa2024protmamba, 346 | title={{ProtMamba}: -- a homology-aware but alignment-free protein state space 
model}, 347 | author={Damiano Sgarbossa and Cyril Malbranke and Anne-Florence Bitbol}, 348 | journal={bioRxiv}, 349 | doi = {10.1101/2024.05.24.595730}, 350 | year={2024}, 351 | url={https://www.biorxiv.org/content/early/2024/05/25/2024.05.24.595730} 352 | } 353 | ``` 354 | 355 | ## nbdev 356 | 357 | Project developed using [nbdev](https://nbdev.fast.ai/). 358 | -------------------------------------------------------------------------------- /README_files/libs/bootstrap/bootstrap-icons.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bitbol-Lab/ProtMamba-ssm/4dfb7e199542a88babbc4f6b724b09dad8eeacbb/README_files/libs/bootstrap/bootstrap-icons.woff -------------------------------------------------------------------------------- /README_files/libs/clipboard/clipboard.min.js: -------------------------------------------------------------------------------- 1 | /*! 2 | * clipboard.js v2.0.11 3 | * https://clipboardjs.com/ 4 | * 5 | * Licensed MIT © Zeno Rocha 6 | */ 7 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.ClipboardJS=e():t.ClipboardJS=e()}(this,function(){return n={686:function(t,e,n){"use strict";n.d(e,{default:function(){return b}});var e=n(279),i=n.n(e),e=n(370),u=n.n(e),e=n(817),r=n.n(e);function c(t){try{return document.execCommand(t)}catch(t){return}}var a=function(t){t=r()(t);return c("cut"),t};function o(t,e){var n,o,t=(n=t,o="rtl"===document.documentElement.getAttribute("dir"),(t=document.createElement("textarea")).style.fontSize="12pt",t.style.border="0",t.style.padding="0",t.style.margin="0",t.style.position="absolute",t.style[o?"right":"left"]="-9999px",o=window.pageYOffset||document.documentElement.scrollTop,t.style.top="".concat(o,"px"),t.setAttribute("readonly",""),t.value=n,t);return e.container.appendChild(t),e=r()(t),c("copy"),t.remove(),e}var 
f=function(t){var e=1.anchorjs-link,.anchorjs-link:focus{opacity:1}",u.sheet.cssRules.length),u.sheet.insertRule("[data-anchorjs-icon]::after{content:attr(data-anchorjs-icon)}",u.sheet.cssRules.length),u.sheet.insertRule('@font-face{font-family:anchorjs-icons;src:url(data:n/a;base64,AAEAAAALAIAAAwAwT1MvMg8yG2cAAAE4AAAAYGNtYXDp3gC3AAABpAAAAExnYXNwAAAAEAAAA9wAAAAIZ2x5ZlQCcfwAAAH4AAABCGhlYWQHFvHyAAAAvAAAADZoaGVhBnACFwAAAPQAAAAkaG10eASAADEAAAGYAAAADGxvY2EACACEAAAB8AAAAAhtYXhwAAYAVwAAARgAAAAgbmFtZQGOH9cAAAMAAAAAunBvc3QAAwAAAAADvAAAACAAAQAAAAEAAHzE2p9fDzz1AAkEAAAAAADRecUWAAAAANQA6R8AAAAAAoACwAAAAAgAAgAAAAAAAAABAAADwP/AAAACgAAA/9MCrQABAAAAAAAAAAAAAAAAAAAAAwABAAAAAwBVAAIAAAAAAAIAAAAAAAAAAAAAAAAAAAAAAAMCQAGQAAUAAAKZAswAAACPApkCzAAAAesAMwEJAAAAAAAAAAAAAAAAAAAAARAAAAAAAAAAAAAAAAAAAAAAQAAg//0DwP/AAEADwABAAAAAAQAAAAAAAAAAAAAAIAAAAAAAAAIAAAACgAAxAAAAAwAAAAMAAAAcAAEAAwAAABwAAwABAAAAHAAEADAAAAAIAAgAAgAAACDpy//9//8AAAAg6cv//f///+EWNwADAAEAAAAAAAAAAAAAAAAACACEAAEAAAAAAAAAAAAAAAAxAAACAAQARAKAAsAAKwBUAAABIiYnJjQ3NzY2MzIWFxYUBwcGIicmNDc3NjQnJiYjIgYHBwYUFxYUBwYGIwciJicmNDc3NjIXFhQHBwYUFxYWMzI2Nzc2NCcmNDc2MhcWFAcHBgYjARQGDAUtLXoWOR8fORYtLTgKGwoKCjgaGg0gEhIgDXoaGgkJBQwHdR85Fi0tOAobCgoKOBoaDSASEiANehoaCQkKGwotLXoWOR8BMwUFLYEuehYXFxYugC44CQkKGwo4GkoaDQ0NDXoaShoKGwoFBe8XFi6ALjgJCQobCjgaShoNDQ0NehpKGgobCgoKLYEuehYXAAAADACWAAEAAAAAAAEACAAAAAEAAAAAAAIAAwAIAAEAAAAAAAMACAAAAAEAAAAAAAQACAAAAAEAAAAAAAUAAQALAAEAAAAAAAYACAAAAAMAAQQJAAEAEAAMAAMAAQQJAAIABgAcAAMAAQQJAAMAEAAMAAMAAQQJAAQAEAAMAAMAAQQJAAUAAgAiAAMAAQQJAAYAEAAMYW5jaG9yanM0MDBAAGEAbgBjAGgAbwByAGoAcwA0ADAAMABAAAAAAwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABAAH//wAP) format("truetype")}',u.sheet.cssRules.length)),u=document.querySelectorAll("[id]"),t=[].map.call(u,function(A){return A.id}),i=0;i\]./()*\\\n\t\b\v\u00A0]/g,"-").replace(/-{2,}/g,"-").substring(0,this.options.truncate).replace(/^-+|-+$/gm,"").toLowerCase()},this.hasAnchorJSLink=function(A){var e=A.firstChild&&-1<(" "+A.firstChild.className+" ").indexOf(" anchorjs-link 
"),A=A.lastChild&&-1<(" "+A.lastChild.className+" ").indexOf(" anchorjs-link ");return e||A||!1}}}); 9 | // @license-end -------------------------------------------------------------------------------- /README_files/libs/quarto-html/popper.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @popperjs/core v2.11.4 - MIT License 3 | */ 4 | 5 | !function(e,t){"object"==typeof exports&&"undefined"!=typeof module?t(exports):"function"==typeof define&&define.amd?define(["exports"],t):t((e="undefined"!=typeof globalThis?globalThis:e||self).Popper={})}(this,(function(e){"use strict";function t(e){if(null==e)return window;if("[object Window]"!==e.toString()){var t=e.ownerDocument;return t&&t.defaultView||window}return e}function n(e){return e instanceof t(e).Element||e instanceof Element}function r(e){return e instanceof t(e).HTMLElement||e instanceof HTMLElement}function o(e){return"undefined"!=typeof ShadowRoot&&(e instanceof t(e).ShadowRoot||e instanceof ShadowRoot)}var i=Math.max,a=Math.min,s=Math.round;function f(e,t){void 0===t&&(t=!1);var n=e.getBoundingClientRect(),o=1,i=1;if(r(e)&&t){var a=e.offsetHeight,f=e.offsetWidth;f>0&&(o=s(n.width)/f||1),a>0&&(i=s(n.height)/a||1)}return{width:n.width/o,height:n.height/i,top:n.top/i,right:n.right/o,bottom:n.bottom/i,left:n.left/o,x:n.left/o,y:n.top/i}}function c(e){var n=t(e);return{scrollLeft:n.pageXOffset,scrollTop:n.pageYOffset}}function p(e){return e?(e.nodeName||"").toLowerCase():null}function u(e){return((n(e)?e.ownerDocument:e.document)||window.document).documentElement}function l(e){return f(u(e)).left+c(e).scrollLeft}function d(e){return t(e).getComputedStyle(e)}function h(e){var t=d(e),n=t.overflow,r=t.overflowX,o=t.overflowY;return/auto|scroll|overlay|hidden/.test(n+o+r)}function m(e,n,o){void 0===o&&(o=!1);var i,a,d=r(n),m=r(n)&&function(e){var t=e.getBoundingClientRect(),n=s(t.width)/e.offsetWidth||1,r=s(t.height)/e.offsetHeight||1;return 
1!==n||1!==r}(n),v=u(n),g=f(e,m),y={scrollLeft:0,scrollTop:0},b={x:0,y:0};return(d||!d&&!o)&&(("body"!==p(n)||h(v))&&(y=(i=n)!==t(i)&&r(i)?{scrollLeft:(a=i).scrollLeft,scrollTop:a.scrollTop}:c(i)),r(n)?((b=f(n,!0)).x+=n.clientLeft,b.y+=n.clientTop):v&&(b.x=l(v))),{x:g.left+y.scrollLeft-b.x,y:g.top+y.scrollTop-b.y,width:g.width,height:g.height}}function v(e){var t=f(e),n=e.offsetWidth,r=e.offsetHeight;return Math.abs(t.width-n)<=1&&(n=t.width),Math.abs(t.height-r)<=1&&(r=t.height),{x:e.offsetLeft,y:e.offsetTop,width:n,height:r}}function g(e){return"html"===p(e)?e:e.assignedSlot||e.parentNode||(o(e)?e.host:null)||u(e)}function y(e){return["html","body","#document"].indexOf(p(e))>=0?e.ownerDocument.body:r(e)&&h(e)?e:y(g(e))}function b(e,n){var r;void 0===n&&(n=[]);var o=y(e),i=o===(null==(r=e.ownerDocument)?void 0:r.body),a=t(o),s=i?[a].concat(a.visualViewport||[],h(o)?o:[]):o,f=n.concat(s);return i?f:f.concat(b(g(s)))}function x(e){return["table","td","th"].indexOf(p(e))>=0}function w(e){return r(e)&&"fixed"!==d(e).position?e.offsetParent:null}function O(e){for(var n=t(e),i=w(e);i&&x(i)&&"static"===d(i).position;)i=w(i);return i&&("html"===p(i)||"body"===p(i)&&"static"===d(i).position)?n:i||function(e){var t=-1!==navigator.userAgent.toLowerCase().indexOf("firefox");if(-1!==navigator.userAgent.indexOf("Trident")&&r(e)&&"fixed"===d(e).position)return null;var n=g(e);for(o(n)&&(n=n.host);r(n)&&["html","body"].indexOf(p(n))<0;){var i=d(n);if("none"!==i.transform||"none"!==i.perspective||"paint"===i.contain||-1!==["transform","perspective"].indexOf(i.willChange)||t&&"filter"===i.willChange||t&&i.filter&&"none"!==i.filter)return n;n=n.parentNode}return null}(e)||n}var j="top",E="bottom",D="right",A="left",L="auto",P=[j,E,D,A],M="start",k="end",W="viewport",B="popper",H=P.reduce((function(e,t){return e.concat([t+"-"+M,t+"-"+k])}),[]),T=[].concat(P,[L]).reduce((function(e,t){return 
e.concat([t,t+"-"+M,t+"-"+k])}),[]),R=["beforeRead","read","afterRead","beforeMain","main","afterMain","beforeWrite","write","afterWrite"];function S(e){var t=new Map,n=new Set,r=[];function o(e){n.add(e.name),[].concat(e.requires||[],e.requiresIfExists||[]).forEach((function(e){if(!n.has(e)){var r=t.get(e);r&&o(r)}})),r.push(e)}return e.forEach((function(e){t.set(e.name,e)})),e.forEach((function(e){n.has(e.name)||o(e)})),r}function C(e){return e.split("-")[0]}function q(e,t){var n=t.getRootNode&&t.getRootNode();if(e.contains(t))return!0;if(n&&o(n)){var r=t;do{if(r&&e.isSameNode(r))return!0;r=r.parentNode||r.host}while(r)}return!1}function V(e){return Object.assign({},e,{left:e.x,top:e.y,right:e.x+e.width,bottom:e.y+e.height})}function N(e,r){return r===W?V(function(e){var n=t(e),r=u(e),o=n.visualViewport,i=r.clientWidth,a=r.clientHeight,s=0,f=0;return o&&(i=o.width,a=o.height,/^((?!chrome|android).)*safari/i.test(navigator.userAgent)||(s=o.offsetLeft,f=o.offsetTop)),{width:i,height:a,x:s+l(e),y:f}}(e)):n(r)?function(e){var t=f(e);return t.top=t.top+e.clientTop,t.left=t.left+e.clientLeft,t.bottom=t.top+e.clientHeight,t.right=t.left+e.clientWidth,t.width=e.clientWidth,t.height=e.clientHeight,t.x=t.left,t.y=t.top,t}(r):V(function(e){var t,n=u(e),r=c(e),o=null==(t=e.ownerDocument)?void 0:t.body,a=i(n.scrollWidth,n.clientWidth,o?o.scrollWidth:0,o?o.clientWidth:0),s=i(n.scrollHeight,n.clientHeight,o?o.scrollHeight:0,o?o.clientHeight:0),f=-r.scrollLeft+l(e),p=-r.scrollTop;return"rtl"===d(o||n).direction&&(f+=i(n.clientWidth,o?o.clientWidth:0)-a),{width:a,height:s,x:f,y:p}}(u(e)))}function I(e,t,o){var s="clippingParents"===t?function(e){var t=b(g(e)),o=["absolute","fixed"].indexOf(d(e).position)>=0&&r(e)?O(e):e;return n(o)?t.filter((function(e){return n(e)&&q(e,o)&&"body"!==p(e)})):[]}(e):[].concat(t),f=[].concat(s,[o]),c=f[0],u=f.reduce((function(t,n){var r=N(e,n);return 
t.top=i(r.top,t.top),t.right=a(r.right,t.right),t.bottom=a(r.bottom,t.bottom),t.left=i(r.left,t.left),t}),N(e,c));return u.width=u.right-u.left,u.height=u.bottom-u.top,u.x=u.left,u.y=u.top,u}function _(e){return e.split("-")[1]}function F(e){return["top","bottom"].indexOf(e)>=0?"x":"y"}function U(e){var t,n=e.reference,r=e.element,o=e.placement,i=o?C(o):null,a=o?_(o):null,s=n.x+n.width/2-r.width/2,f=n.y+n.height/2-r.height/2;switch(i){case j:t={x:s,y:n.y-r.height};break;case E:t={x:s,y:n.y+n.height};break;case D:t={x:n.x+n.width,y:f};break;case A:t={x:n.x-r.width,y:f};break;default:t={x:n.x,y:n.y}}var c=i?F(i):null;if(null!=c){var p="y"===c?"height":"width";switch(a){case M:t[c]=t[c]-(n[p]/2-r[p]/2);break;case k:t[c]=t[c]+(n[p]/2-r[p]/2)}}return t}function z(e){return Object.assign({},{top:0,right:0,bottom:0,left:0},e)}function X(e,t){return t.reduce((function(t,n){return t[n]=e,t}),{})}function Y(e,t){void 0===t&&(t={});var r=t,o=r.placement,i=void 0===o?e.placement:o,a=r.boundary,s=void 0===a?"clippingParents":a,c=r.rootBoundary,p=void 0===c?W:c,l=r.elementContext,d=void 0===l?B:l,h=r.altBoundary,m=void 0!==h&&h,v=r.padding,g=void 0===v?0:v,y=z("number"!=typeof g?g:X(g,P)),b=d===B?"reference":B,x=e.rects.popper,w=e.elements[m?b:d],O=I(n(w)?w:w.contextElement||u(e.elements.popper),s,p),A=f(e.elements.reference),L=U({reference:A,element:x,strategy:"absolute",placement:i}),M=V(Object.assign({},x,L)),k=d===B?M:A,H={top:O.top-k.top+y.top,bottom:k.bottom-O.bottom+y.bottom,left:O.left-k.left+y.left,right:k.right-O.right+y.right},T=e.modifiersData.offset;if(d===B&&T){var R=T[i];Object.keys(H).forEach((function(e){var t=[D,E].indexOf(e)>=0?1:-1,n=[j,E].indexOf(e)>=0?"y":"x";H[e]+=R[n]*t}))}return H}var G={placement:"bottom",modifiers:[],strategy:"absolute"};function J(){for(var e=arguments.length,t=new Array(e),n=0;n=0?-1:1,i="function"==typeof n?n(Object.assign({},t,{placement:e})):n,a=i[0],s=i[1];return 
a=a||0,s=(s||0)*o,[A,D].indexOf(r)>=0?{x:s,y:a}:{x:a,y:s}}(n,t.rects,i),e}),{}),s=a[t.placement],f=s.x,c=s.y;null!=t.modifiersData.popperOffsets&&(t.modifiersData.popperOffsets.x+=f,t.modifiersData.popperOffsets.y+=c),t.modifiersData[r]=a}},ie={left:"right",right:"left",bottom:"top",top:"bottom"};function ae(e){return e.replace(/left|right|bottom|top/g,(function(e){return ie[e]}))}var se={start:"end",end:"start"};function fe(e){return e.replace(/start|end/g,(function(e){return se[e]}))}function ce(e,t){void 0===t&&(t={});var n=t,r=n.placement,o=n.boundary,i=n.rootBoundary,a=n.padding,s=n.flipVariations,f=n.allowedAutoPlacements,c=void 0===f?T:f,p=_(r),u=p?s?H:H.filter((function(e){return _(e)===p})):P,l=u.filter((function(e){return c.indexOf(e)>=0}));0===l.length&&(l=u);var d=l.reduce((function(t,n){return t[n]=Y(e,{placement:n,boundary:o,rootBoundary:i,padding:a})[C(n)],t}),{});return Object.keys(d).sort((function(e,t){return d[e]-d[t]}))}var pe={name:"flip",enabled:!0,phase:"main",fn:function(e){var t=e.state,n=e.options,r=e.name;if(!t.modifiersData[r]._skip){for(var o=n.mainAxis,i=void 0===o||o,a=n.altAxis,s=void 0===a||a,f=n.fallbackPlacements,c=n.padding,p=n.boundary,u=n.rootBoundary,l=n.altBoundary,d=n.flipVariations,h=void 0===d||d,m=n.allowedAutoPlacements,v=t.options.placement,g=C(v),y=f||(g===v||!h?[ae(v)]:function(e){if(C(e)===L)return[];var t=ae(e);return[fe(e),t,fe(t)]}(v)),b=[v].concat(y).reduce((function(e,n){return e.concat(C(n)===L?ce(t,{placement:n,boundary:p,rootBoundary:u,padding:c,flipVariations:h,allowedAutoPlacements:m}):n)}),[]),x=t.rects.reference,w=t.rects.popper,O=new Map,P=!0,k=b[0],W=0;W=0,S=R?"width":"height",q=Y(t,{placement:B,boundary:p,rootBoundary:u,altBoundary:l,padding:c}),V=R?T?D:A:T?E:j;x[S]>w[S]&&(V=ae(V));var N=ae(V),I=[];if(i&&I.push(q[H]<=0),s&&I.push(q[V]<=0,q[N]<=0),I.every((function(e){return e}))){k=B,P=!1;break}O.set(B,I)}if(P)for(var F=function(e){var t=b.find((function(t){var n=O.get(t);if(n)return 
n.slice(0,e).every((function(e){return e}))}));if(t)return k=t,"break"},U=h?3:1;U>0;U--){if("break"===F(U))break}t.placement!==k&&(t.modifiersData[r]._skip=!0,t.placement=k,t.reset=!0)}},requiresIfExists:["offset"],data:{_skip:!1}};function ue(e,t,n){return i(e,a(t,n))}var le={name:"preventOverflow",enabled:!0,phase:"main",fn:function(e){var t=e.state,n=e.options,r=e.name,o=n.mainAxis,s=void 0===o||o,f=n.altAxis,c=void 0!==f&&f,p=n.boundary,u=n.rootBoundary,l=n.altBoundary,d=n.padding,h=n.tether,m=void 0===h||h,g=n.tetherOffset,y=void 0===g?0:g,b=Y(t,{boundary:p,rootBoundary:u,padding:d,altBoundary:l}),x=C(t.placement),w=_(t.placement),L=!w,P=F(x),k="x"===P?"y":"x",W=t.modifiersData.popperOffsets,B=t.rects.reference,H=t.rects.popper,T="function"==typeof y?y(Object.assign({},t.rects,{placement:t.placement})):y,R="number"==typeof T?{mainAxis:T,altAxis:T}:Object.assign({mainAxis:0,altAxis:0},T),S=t.modifiersData.offset?t.modifiersData.offset[t.placement]:null,q={x:0,y:0};if(W){if(s){var V,N="y"===P?j:A,I="y"===P?E:D,U="y"===P?"height":"width",z=W[P],X=z+b[N],G=z-b[I],J=m?-H[U]/2:0,K=w===M?B[U]:H[U],Q=w===M?-H[U]:-B[U],Z=t.elements.arrow,$=m&&Z?v(Z):{width:0,height:0},ee=t.modifiersData["arrow#persistent"]?t.modifiersData["arrow#persistent"].padding:{top:0,right:0,bottom:0,left:0},te=ee[N],ne=ee[I],re=ue(0,B[U],$[U]),oe=L?B[U]/2-J-re-te-R.mainAxis:K-re-te-R.mainAxis,ie=L?-B[U]/2+J+re+ne+R.mainAxis:Q+re+ne+R.mainAxis,ae=t.elements.arrow&&O(t.elements.arrow),se=ae?"y"===P?ae.clientTop||0:ae.clientLeft||0:0,fe=null!=(V=null==S?void 0:S[P])?V:0,ce=z+ie-fe,pe=ue(m?a(X,z+oe-fe-se):X,z,m?i(G,ce):G);W[P]=pe,q[P]=pe-z}if(c){var le,de="x"===P?j:A,he="x"===P?E:D,me=W[k],ve="y"===k?"height":"width",ge=me+b[de],ye=me-b[he],be=-1!==[j,A].indexOf(x),xe=null!=(le=null==S?void 0:S[k])?le:0,we=be?ge:me-B[ve]-H[ve]-xe+R.altAxis,Oe=be?me+B[ve]+H[ve]-xe-R.altAxis:ye,je=m&&be?function(e,t,n){var r=ue(e,t,n);return 
r>n?n:r}(we,me,Oe):ue(m?we:ge,me,m?Oe:ye);W[k]=je,q[k]=je-me}t.modifiersData[r]=q}},requiresIfExists:["offset"]};var de={name:"arrow",enabled:!0,phase:"main",fn:function(e){var t,n=e.state,r=e.name,o=e.options,i=n.elements.arrow,a=n.modifiersData.popperOffsets,s=C(n.placement),f=F(s),c=[A,D].indexOf(s)>=0?"height":"width";if(i&&a){var p=function(e,t){return z("number"!=typeof(e="function"==typeof e?e(Object.assign({},t.rects,{placement:t.placement})):e)?e:X(e,P))}(o.padding,n),u=v(i),l="y"===f?j:A,d="y"===f?E:D,h=n.rects.reference[c]+n.rects.reference[f]-a[f]-n.rects.popper[c],m=a[f]-n.rects.reference[f],g=O(i),y=g?"y"===f?g.clientHeight||0:g.clientWidth||0:0,b=h/2-m/2,x=p[l],w=y-u[c]-p[d],L=y/2-u[c]/2+b,M=ue(x,L,w),k=f;n.modifiersData[r]=((t={})[k]=M,t.centerOffset=M-L,t)}},effect:function(e){var t=e.state,n=e.options.element,r=void 0===n?"[data-popper-arrow]":n;null!=r&&("string"!=typeof r||(r=t.elements.popper.querySelector(r)))&&q(t.elements.popper,r)&&(t.elements.arrow=r)},requires:["popperOffsets"],requiresIfExists:["preventOverflow"]};function he(e,t,n){return void 0===n&&(n={x:0,y:0}),{top:e.top-t.height-n.y,right:e.right-t.width+n.x,bottom:e.bottom-t.height+n.y,left:e.left-t.width-n.x}}function me(e){return[j,D,E,A].some((function(t){return e[t]>=0}))}var ve={name:"hide",enabled:!0,phase:"main",requiresIfExists:["preventOverflow"],fn:function(e){var 
t=e.state,n=e.name,r=t.rects.reference,o=t.rects.popper,i=t.modifiersData.preventOverflow,a=Y(t,{elementContext:"reference"}),s=Y(t,{altBoundary:!0}),f=he(a,r),c=he(s,o,i),p=me(f),u=me(c);t.modifiersData[n]={referenceClippingOffsets:f,popperEscapeOffsets:c,isReferenceHidden:p,hasPopperEscaped:u},t.attributes.popper=Object.assign({},t.attributes.popper,{"data-popper-reference-hidden":p,"data-popper-escaped":u})}},ge=K({defaultModifiers:[Z,$,ne,re]}),ye=[Z,$,ne,re,oe,pe,le,de,ve],be=K({defaultModifiers:ye});e.applyStyles=re,e.arrow=de,e.computeStyles=ne,e.createPopper=be,e.createPopperLite=ge,e.defaultModifiers=ye,e.detectOverflow=Y,e.eventListeners=Z,e.flip=pe,e.hide=ve,e.offset=oe,e.popperGenerator=K,e.popperOffsets=$,e.preventOverflow=le,Object.defineProperty(e,"__esModule",{value:!0})})); 6 | 7 | -------------------------------------------------------------------------------- /README_files/libs/quarto-html/quarto-syntax-highlighting.css: -------------------------------------------------------------------------------- 1 | /* quarto syntax highlight colors */ 2 | :root { 3 | --quarto-hl-ot-color: #003B4F; 4 | --quarto-hl-at-color: #657422; 5 | --quarto-hl-ss-color: #20794D; 6 | --quarto-hl-an-color: #5E5E5E; 7 | --quarto-hl-fu-color: #4758AB; 8 | --quarto-hl-st-color: #20794D; 9 | --quarto-hl-cf-color: #003B4F; 10 | --quarto-hl-op-color: #5E5E5E; 11 | --quarto-hl-er-color: #AD0000; 12 | --quarto-hl-bn-color: #AD0000; 13 | --quarto-hl-al-color: #AD0000; 14 | --quarto-hl-va-color: #111111; 15 | --quarto-hl-bu-color: inherit; 16 | --quarto-hl-ex-color: inherit; 17 | --quarto-hl-pp-color: #AD0000; 18 | --quarto-hl-in-color: #5E5E5E; 19 | --quarto-hl-vs-color: #20794D; 20 | --quarto-hl-wa-color: #5E5E5E; 21 | --quarto-hl-do-color: #5E5E5E; 22 | --quarto-hl-im-color: #00769E; 23 | --quarto-hl-ch-color: #20794D; 24 | --quarto-hl-dt-color: #AD0000; 25 | --quarto-hl-fl-color: #AD0000; 26 | --quarto-hl-co-color: #5E5E5E; 27 | --quarto-hl-cv-color: #5E5E5E; 28 | 
--quarto-hl-cn-color: #8f5902; 29 | --quarto-hl-sc-color: #5E5E5E; 30 | --quarto-hl-dv-color: #AD0000; 31 | --quarto-hl-kw-color: #003B4F; 32 | } 33 | 34 | /* other quarto variables */ 35 | :root { 36 | --quarto-font-monospace: SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; 37 | } 38 | 39 | pre > code.sourceCode > span { 40 | color: #003B4F; 41 | } 42 | 43 | code span { 44 | color: #003B4F; 45 | } 46 | 47 | code.sourceCode > span { 48 | color: #003B4F; 49 | } 50 | 51 | div.sourceCode, 52 | div.sourceCode pre.sourceCode { 53 | color: #003B4F; 54 | } 55 | 56 | code span.ot { 57 | color: #003B4F; 58 | font-style: inherit; 59 | } 60 | 61 | code span.at { 62 | color: #657422; 63 | font-style: inherit; 64 | } 65 | 66 | code span.ss { 67 | color: #20794D; 68 | font-style: inherit; 69 | } 70 | 71 | code span.an { 72 | color: #5E5E5E; 73 | font-style: inherit; 74 | } 75 | 76 | code span.fu { 77 | color: #4758AB; 78 | font-style: inherit; 79 | } 80 | 81 | code span.st { 82 | color: #20794D; 83 | font-style: inherit; 84 | } 85 | 86 | code span.cf { 87 | color: #003B4F; 88 | font-style: inherit; 89 | } 90 | 91 | code span.op { 92 | color: #5E5E5E; 93 | font-style: inherit; 94 | } 95 | 96 | code span.er { 97 | color: #AD0000; 98 | font-style: inherit; 99 | } 100 | 101 | code span.bn { 102 | color: #AD0000; 103 | font-style: inherit; 104 | } 105 | 106 | code span.al { 107 | color: #AD0000; 108 | font-style: inherit; 109 | } 110 | 111 | code span.va { 112 | color: #111111; 113 | font-style: inherit; 114 | } 115 | 116 | code span.bu { 117 | font-style: inherit; 118 | } 119 | 120 | code span.ex { 121 | font-style: inherit; 122 | } 123 | 124 | code span.pp { 125 | color: #AD0000; 126 | font-style: inherit; 127 | } 128 | 129 | code span.in { 130 | color: #5E5E5E; 131 | font-style: inherit; 132 | } 133 | 134 | code span.vs { 135 | color: #20794D; 136 | font-style: inherit; 137 | } 138 | 139 | code span.wa { 140 | color: #5E5E5E; 141 | font-style: 
italic; 142 | } 143 | 144 | code span.do { 145 | color: #5E5E5E; 146 | font-style: italic; 147 | } 148 | 149 | code span.im { 150 | color: #00769E; 151 | font-style: inherit; 152 | } 153 | 154 | code span.ch { 155 | color: #20794D; 156 | font-style: inherit; 157 | } 158 | 159 | code span.dt { 160 | color: #AD0000; 161 | font-style: inherit; 162 | } 163 | 164 | code span.fl { 165 | color: #AD0000; 166 | font-style: inherit; 167 | } 168 | 169 | code span.co { 170 | color: #5E5E5E; 171 | font-style: inherit; 172 | } 173 | 174 | code span.cv { 175 | color: #5E5E5E; 176 | font-style: italic; 177 | } 178 | 179 | code span.cn { 180 | color: #8f5902; 181 | font-style: inherit; 182 | } 183 | 184 | code span.sc { 185 | color: #5E5E5E; 186 | font-style: inherit; 187 | } 188 | 189 | code span.dv { 190 | color: #AD0000; 191 | font-style: inherit; 192 | } 193 | 194 | code span.kw { 195 | color: #003B4F; 196 | font-style: inherit; 197 | } 198 | 199 | .prevent-inlining { 200 | content: ".tippy-arrow{bottom:0}.tippy-box[data-placement^=top]>.tippy-arrow:before{bottom:-7px;left:0;border-width:8px 8px 0;border-top-color:initial;transform-origin:center top}.tippy-box[data-placement^=bottom]>.tippy-arrow{top:0}.tippy-box[data-placement^=bottom]>.tippy-arrow:before{top:-7px;left:0;border-width:0 8px 8px;border-bottom-color:initial;transform-origin:center bottom}.tippy-box[data-placement^=left]>.tippy-arrow{right:0}.tippy-box[data-placement^=left]>.tippy-arrow:before{border-width:8px 0 8px 8px;border-left-color:initial;right:-7px;transform-origin:center left}.tippy-box[data-placement^=right]>.tippy-arrow{left:0}.tippy-box[data-placement^=right]>.tippy-arrow:before{left:-7px;border-width:8px 8px 8px 0;border-right-color:initial;transform-origin:center 
right}.tippy-box[data-inertia][data-state=visible]{transition-timing-function:cubic-bezier(.54,1.5,.38,1.11)}.tippy-arrow{width:16px;height:16px;color:#333}.tippy-arrow:before{content:"";position:absolute;border-color:transparent;border-style:solid}.tippy-content{position:relative;padding:5px 9px;z-index:1} -------------------------------------------------------------------------------- /README_files/libs/quarto-html/tippy.umd.min.js: -------------------------------------------------------------------------------- 1 | !function(e,t){"object"==typeof exports&&"undefined"!=typeof module?module.exports=t(require("@popperjs/core")):"function"==typeof define&&define.amd?define(["@popperjs/core"],t):(e=e||self).tippy=t(e.Popper)}(this,(function(e){"use strict";var t={passive:!0,capture:!0},n=function(){return document.body};function r(e,t,n){if(Array.isArray(e)){var r=e[t];return null==r?Array.isArray(n)?n[t]:n:r}return e}function o(e,t){var n={}.toString.call(e);return 0===n.indexOf("[object")&&n.indexOf(t+"]")>-1}function i(e,t){return"function"==typeof e?e.apply(void 0,t):e}function a(e,t){return 0===t?e:function(r){clearTimeout(n),n=setTimeout((function(){e(r)}),t)};var n}function s(e,t){var n=Object.assign({},e);return t.forEach((function(e){delete n[e]})),n}function u(e){return[].concat(e)}function c(e,t){-1===e.indexOf(t)&&e.push(t)}function p(e){return e.split("-")[0]}function f(e){return[].slice.call(e)}function l(e){return Object.keys(e).reduce((function(t,n){return void 0!==e[n]&&(t[n]=e[n]),t}),{})}function d(){return document.createElement("div")}function v(e){return["Element","Fragment"].some((function(t){return o(e,t)}))}function m(e){return o(e,"MouseEvent")}function g(e){return!(!e||!e._tippy||e._tippy.reference!==e)}function h(e){return v(e)?[e]:function(e){return o(e,"NodeList")}(e)?f(e):Array.isArray(e)?e:f(document.querySelectorAll(e))}function b(e,t){e.forEach((function(e){e&&(e.style.transitionDuration=t+"ms")}))}function 
y(e,t){e.forEach((function(e){e&&e.setAttribute("data-state",t)}))}function w(e){var t,n=u(e)[0];return null!=n&&null!=(t=n.ownerDocument)&&t.body?n.ownerDocument:document}function E(e,t,n){var r=t+"EventListener";["transitionend","webkitTransitionEnd"].forEach((function(t){e[r](t,n)}))}function O(e,t){for(var n=t;n;){var r;if(e.contains(n))return!0;n=null==n.getRootNode||null==(r=n.getRootNode())?void 0:r.host}return!1}var x={isTouch:!1},C=0;function T(){x.isTouch||(x.isTouch=!0,window.performance&&document.addEventListener("mousemove",A))}function A(){var e=performance.now();e-C<20&&(x.isTouch=!1,document.removeEventListener("mousemove",A)),C=e}function L(){var e=document.activeElement;if(g(e)){var t=e._tippy;e.blur&&!t.state.isVisible&&e.blur()}}var D=!!("undefined"!=typeof window&&"undefined"!=typeof document)&&!!window.msCrypto,R=Object.assign({appendTo:n,aria:{content:"auto",expanded:"auto"},delay:0,duration:[300,250],getReferenceClientRect:null,hideOnClick:!0,ignoreAttributes:!1,interactive:!1,interactiveBorder:2,interactiveDebounce:0,moveTransition:"",offset:[0,10],onAfterUpdate:function(){},onBeforeUpdate:function(){},onCreate:function(){},onDestroy:function(){},onHidden:function(){},onHide:function(){},onMount:function(){},onShow:function(){},onShown:function(){},onTrigger:function(){},onUntrigger:function(){},onClickOutside:function(){},placement:"top",plugins:[],popperOptions:{},render:null,showOnCreate:!1,touch:!0,trigger:"mouseenter focus",triggerTarget:null},{animateFill:!1,followCursor:!1,inlinePositioning:!1,sticky:!1},{allowHTML:!1,animation:"fade",arrow:!0,content:"",inertia:!1,maxWidth:350,role:"tooltip",theme:"",zIndex:9999}),k=Object.keys(R);function P(e){var t=(e.plugins||[]).reduce((function(t,n){var r,o=n.name,i=n.defaultValue;o&&(t[o]=void 0!==e[o]?e[o]:null!=(r=R[o])?r:i);return t}),{});return Object.assign({},e,t)}function j(e,t){var 
n=Object.assign({},t,{content:i(t.content,[e])},t.ignoreAttributes?{}:function(e,t){return(t?Object.keys(P(Object.assign({},R,{plugins:t}))):k).reduce((function(t,n){var r=(e.getAttribute("data-tippy-"+n)||"").trim();if(!r)return t;if("content"===n)t[n]=r;else try{t[n]=JSON.parse(r)}catch(e){t[n]=r}return t}),{})}(e,t.plugins));return n.aria=Object.assign({},R.aria,n.aria),n.aria={expanded:"auto"===n.aria.expanded?t.interactive:n.aria.expanded,content:"auto"===n.aria.content?t.interactive?null:"describedby":n.aria.content},n}function M(e,t){e.innerHTML=t}function V(e){var t=d();return!0===e?t.className="tippy-arrow":(t.className="tippy-svg-arrow",v(e)?t.appendChild(e):M(t,e)),t}function I(e,t){v(t.content)?(M(e,""),e.appendChild(t.content)):"function"!=typeof t.content&&(t.allowHTML?M(e,t.content):e.textContent=t.content)}function S(e){var t=e.firstElementChild,n=f(t.children);return{box:t,content:n.find((function(e){return e.classList.contains("tippy-content")})),arrow:n.find((function(e){return e.classList.contains("tippy-arrow")||e.classList.contains("tippy-svg-arrow")})),backdrop:n.find((function(e){return e.classList.contains("tippy-backdrop")}))}}function N(e){var t=d(),n=d();n.className="tippy-box",n.setAttribute("data-state","hidden"),n.setAttribute("tabindex","-1");var r=d();function o(n,r){var o=S(t),i=o.box,a=o.content,s=o.arrow;r.theme?i.setAttribute("data-theme",r.theme):i.removeAttribute("data-theme"),"string"==typeof r.animation?i.setAttribute("data-animation",r.animation):i.removeAttribute("data-animation"),r.inertia?i.setAttribute("data-inertia",""):i.removeAttribute("data-inertia"),i.style.maxWidth="number"==typeof r.maxWidth?r.maxWidth+"px":r.maxWidth,r.role?i.setAttribute("role",r.role):i.removeAttribute("role"),n.content===r.content&&n.allowHTML===r.allowHTML||I(a,e.props),r.arrow?s?n.arrow!==r.arrow&&(i.removeChild(s),i.appendChild(V(r.arrow))):i.appendChild(V(r.arrow)):s&&i.removeChild(s)}return 
r.className="tippy-content",r.setAttribute("data-state","hidden"),I(r,e.props),t.appendChild(n),n.appendChild(r),o(e.props,e.props),{popper:t,onUpdate:o}}N.$$tippy=!0;var B=1,H=[],U=[];function _(o,s){var v,g,h,C,T,A,L,k,M=j(o,Object.assign({},R,P(l(s)))),V=!1,I=!1,N=!1,_=!1,F=[],W=a(we,M.interactiveDebounce),X=B++,Y=(k=M.plugins).filter((function(e,t){return k.indexOf(e)===t})),$={id:X,reference:o,popper:d(),popperInstance:null,props:M,state:{isEnabled:!0,isVisible:!1,isDestroyed:!1,isMounted:!1,isShown:!1},plugins:Y,clearDelayTimeouts:function(){clearTimeout(v),clearTimeout(g),cancelAnimationFrame(h)},setProps:function(e){if($.state.isDestroyed)return;ae("onBeforeUpdate",[$,e]),be();var t=$.props,n=j(o,Object.assign({},t,l(e),{ignoreAttributes:!0}));$.props=n,he(),t.interactiveDebounce!==n.interactiveDebounce&&(ce(),W=a(we,n.interactiveDebounce));t.triggerTarget&&!n.triggerTarget?u(t.triggerTarget).forEach((function(e){e.removeAttribute("aria-expanded")})):n.triggerTarget&&o.removeAttribute("aria-expanded");ue(),ie(),J&&J(t,n);$.popperInstance&&(Ce(),Ae().forEach((function(e){requestAnimationFrame(e._tippy.popperInstance.forceUpdate)})));ae("onAfterUpdate",[$,e])},setContent:function(e){$.setProps({content:e})},show:function(){var e=$.state.isVisible,t=$.state.isDestroyed,o=!$.state.isEnabled,a=x.isTouch&&!$.props.touch,s=r($.props.duration,0,R.duration);if(e||t||o||a)return;if(te().hasAttribute("disabled"))return;if(ae("onShow",[$],!1),!1===$.props.onShow($))return;$.state.isVisible=!0,ee()&&(z.style.visibility="visible");ie(),de(),$.state.isMounted||(z.style.transition="none");if(ee()){var u=re(),p=u.box,f=u.content;b([p,f],0)}A=function(){var e;if($.state.isVisible&&!_){if(_=!0,z.offsetHeight,z.style.transition=$.props.moveTransition,ee()&&$.props.animation){var 
t=re(),n=t.box,r=t.content;b([n,r],s),y([n,r],"visible")}se(),ue(),c(U,$),null==(e=$.popperInstance)||e.forceUpdate(),ae("onMount",[$]),$.props.animation&&ee()&&function(e,t){me(e,t)}(s,(function(){$.state.isShown=!0,ae("onShown",[$])}))}},function(){var e,t=$.props.appendTo,r=te();e=$.props.interactive&&t===n||"parent"===t?r.parentNode:i(t,[r]);e.contains(z)||e.appendChild(z);$.state.isMounted=!0,Ce()}()},hide:function(){var e=!$.state.isVisible,t=$.state.isDestroyed,n=!$.state.isEnabled,o=r($.props.duration,1,R.duration);if(e||t||n)return;if(ae("onHide",[$],!1),!1===$.props.onHide($))return;$.state.isVisible=!1,$.state.isShown=!1,_=!1,V=!1,ee()&&(z.style.visibility="hidden");if(ce(),ve(),ie(!0),ee()){var i=re(),a=i.box,s=i.content;$.props.animation&&(b([a,s],o),y([a,s],"hidden"))}se(),ue(),$.props.animation?ee()&&function(e,t){me(e,(function(){!$.state.isVisible&&z.parentNode&&z.parentNode.contains(z)&&t()}))}(o,$.unmount):$.unmount()},hideWithInteractivity:function(e){ne().addEventListener("mousemove",W),c(H,W),W(e)},enable:function(){$.state.isEnabled=!0},disable:function(){$.hide(),$.state.isEnabled=!1},unmount:function(){$.state.isVisible&&$.hide();if(!$.state.isMounted)return;Te(),Ae().forEach((function(e){e._tippy.unmount()})),z.parentNode&&z.parentNode.removeChild(z);U=U.filter((function(e){return e!==$})),$.state.isMounted=!1,ae("onHidden",[$])},destroy:function(){if($.state.isDestroyed)return;$.clearDelayTimeouts(),$.unmount(),be(),delete o._tippy,$.state.isDestroyed=!0,ae("onDestroy",[$])}};if(!M.render)return $;var q=M.render($),z=q.popper,J=q.onUpdate;z.setAttribute("data-tippy-root",""),z.id="tippy-"+$.id,$.popper=z,o._tippy=$,z._tippy=$;var G=Y.map((function(e){return e.fn($)})),K=o.hasAttribute("aria-expanded");return 
he(),ue(),ie(),ae("onCreate",[$]),M.showOnCreate&&Le(),z.addEventListener("mouseenter",(function(){$.props.interactive&&$.state.isVisible&&$.clearDelayTimeouts()})),z.addEventListener("mouseleave",(function(){$.props.interactive&&$.props.trigger.indexOf("mouseenter")>=0&&ne().addEventListener("mousemove",W)})),$;function Q(){var e=$.props.touch;return Array.isArray(e)?e:[e,0]}function Z(){return"hold"===Q()[0]}function ee(){var e;return!(null==(e=$.props.render)||!e.$$tippy)}function te(){return L||o}function ne(){var e=te().parentNode;return e?w(e):document}function re(){return S(z)}function oe(e){return $.state.isMounted&&!$.state.isVisible||x.isTouch||C&&"focus"===C.type?0:r($.props.delay,e?0:1,R.delay)}function ie(e){void 0===e&&(e=!1),z.style.pointerEvents=$.props.interactive&&!e?"":"none",z.style.zIndex=""+$.props.zIndex}function ae(e,t,n){var r;(void 0===n&&(n=!0),G.forEach((function(n){n[e]&&n[e].apply(n,t)})),n)&&(r=$.props)[e].apply(r,t)}function se(){var e=$.props.aria;if(e.content){var t="aria-"+e.content,n=z.id;u($.props.triggerTarget||o).forEach((function(e){var r=e.getAttribute(t);if($.state.isVisible)e.setAttribute(t,r?r+" "+n:n);else{var o=r&&r.replace(n,"").trim();o?e.setAttribute(t,o):e.removeAttribute(t)}}))}}function ue(){!K&&$.props.aria.expanded&&u($.props.triggerTarget||o).forEach((function(e){$.props.interactive?e.setAttribute("aria-expanded",$.state.isVisible&&e===te()?"true":"false"):e.removeAttribute("aria-expanded")}))}function ce(){ne().removeEventListener("mousemove",W),H=H.filter((function(e){return e!==W}))}function pe(e){if(!x.isTouch||!N&&"mousedown"!==e.type){var t=e.composedPath&&e.composedPath()[0]||e.target;if(!$.props.interactive||!O(z,t)){if(u($.props.triggerTarget||o).some((function(e){return O(e,t)}))){if(x.isTouch)return;if($.state.isVisible&&$.props.trigger.indexOf("click")>=0)return}else 
ae("onClickOutside",[$,e]);!0===$.props.hideOnClick&&($.clearDelayTimeouts(),$.hide(),I=!0,setTimeout((function(){I=!1})),$.state.isMounted||ve())}}}function fe(){N=!0}function le(){N=!1}function de(){var e=ne();e.addEventListener("mousedown",pe,!0),e.addEventListener("touchend",pe,t),e.addEventListener("touchstart",le,t),e.addEventListener("touchmove",fe,t)}function ve(){var e=ne();e.removeEventListener("mousedown",pe,!0),e.removeEventListener("touchend",pe,t),e.removeEventListener("touchstart",le,t),e.removeEventListener("touchmove",fe,t)}function me(e,t){var n=re().box;function r(e){e.target===n&&(E(n,"remove",r),t())}if(0===e)return t();E(n,"remove",T),E(n,"add",r),T=r}function ge(e,t,n){void 0===n&&(n=!1),u($.props.triggerTarget||o).forEach((function(r){r.addEventListener(e,t,n),F.push({node:r,eventType:e,handler:t,options:n})}))}function he(){var e;Z()&&(ge("touchstart",ye,{passive:!0}),ge("touchend",Ee,{passive:!0})),(e=$.props.trigger,e.split(/\s+/).filter(Boolean)).forEach((function(e){if("manual"!==e)switch(ge(e,ye),e){case"mouseenter":ge("mouseleave",Ee);break;case"focus":ge(D?"focusout":"blur",Oe);break;case"focusin":ge("focusout",Oe)}}))}function be(){F.forEach((function(e){var t=e.node,n=e.eventType,r=e.handler,o=e.options;t.removeEventListener(n,r,o)})),F=[]}function ye(e){var t,n=!1;if($.state.isEnabled&&!xe(e)&&!I){var r="focus"===(null==(t=C)?void 0:t.type);C=e,L=e.currentTarget,ue(),!$.state.isVisible&&m(e)&&H.forEach((function(t){return t(e)})),"click"===e.type&&($.props.trigger.indexOf("mouseenter")<0||V)&&!1!==$.props.hideOnClick&&$.state.isVisible?n=!0:Le(e),"click"===e.type&&(V=!n),n&&!r&&De(e)}}function we(e){var t=e.target,n=te().contains(t)||z.contains(t);"mousemove"===e.type&&n||function(e,t){var n=t.clientX,r=t.clientY;return e.every((function(e){var t=e.popperRect,o=e.popperState,i=e.props.interactiveBorder,a=p(o.placement),s=o.modifiersData.offset;if(!s)return!0;var 
u="bottom"===a?s.top.y:0,c="top"===a?s.bottom.y:0,f="right"===a?s.left.x:0,l="left"===a?s.right.x:0,d=t.top-r+u>i,v=r-t.bottom-c>i,m=t.left-n+f>i,g=n-t.right-l>i;return d||v||m||g}))}(Ae().concat(z).map((function(e){var t,n=null==(t=e._tippy.popperInstance)?void 0:t.state;return n?{popperRect:e.getBoundingClientRect(),popperState:n,props:M}:null})).filter(Boolean),e)&&(ce(),De(e))}function Ee(e){xe(e)||$.props.trigger.indexOf("click")>=0&&V||($.props.interactive?$.hideWithInteractivity(e):De(e))}function Oe(e){$.props.trigger.indexOf("focusin")<0&&e.target!==te()||$.props.interactive&&e.relatedTarget&&z.contains(e.relatedTarget)||De(e)}function xe(e){return!!x.isTouch&&Z()!==e.type.indexOf("touch")>=0}function Ce(){Te();var t=$.props,n=t.popperOptions,r=t.placement,i=t.offset,a=t.getReferenceClientRect,s=t.moveTransition,u=ee()?S(z).arrow:null,c=a?{getBoundingClientRect:a,contextElement:a.contextElement||te()}:o,p=[{name:"offset",options:{offset:i}},{name:"preventOverflow",options:{padding:{top:2,bottom:2,left:5,right:5}}},{name:"flip",options:{padding:5}},{name:"computeStyles",options:{adaptive:!s}},{name:"$$tippy",enabled:!0,phase:"beforeWrite",requires:["computeStyles"],fn:function(e){var t=e.state;if(ee()){var n=re().box;["placement","reference-hidden","escaped"].forEach((function(e){"placement"===e?n.setAttribute("data-placement",t.placement):t.attributes.popper["data-popper-"+e]?n.setAttribute("data-"+e,""):n.removeAttribute("data-"+e)})),t.attributes.popper={}}}}];ee()&&u&&p.push({name:"arrow",options:{element:u,padding:3}}),p.push.apply(p,(null==n?void 0:n.modifiers)||[]),$.popperInstance=e.createPopper(c,z,Object.assign({},n,{placement:r,onFirstUpdate:A,modifiers:p}))}function Te(){$.popperInstance&&($.popperInstance.destroy(),$.popperInstance=null)}function Ae(){return f(z.querySelectorAll("[data-tippy-root]"))}function Le(e){$.clearDelayTimeouts(),e&&ae("onTrigger",[$,e]),de();var 
t=oe(!0),n=Q(),r=n[0],o=n[1];x.isTouch&&"hold"===r&&o&&(t=o),t?v=setTimeout((function(){$.show()}),t):$.show()}function De(e){if($.clearDelayTimeouts(),ae("onUntrigger",[$,e]),$.state.isVisible){if(!($.props.trigger.indexOf("mouseenter")>=0&&$.props.trigger.indexOf("click")>=0&&["mouseleave","mousemove"].indexOf(e.type)>=0&&V)){var t=oe(!1);t?g=setTimeout((function(){$.state.isVisible&&$.hide()}),t):h=requestAnimationFrame((function(){$.hide()}))}}else ve()}}function F(e,n){void 0===n&&(n={});var r=R.plugins.concat(n.plugins||[]);document.addEventListener("touchstart",T,t),window.addEventListener("blur",L);var o=Object.assign({},n,{plugins:r}),i=h(e).reduce((function(e,t){var n=t&&_(t,o);return n&&e.push(n),e}),[]);return v(e)?i[0]:i}F.defaultProps=R,F.setDefaultProps=function(e){Object.keys(e).forEach((function(t){R[t]=e[t]}))},F.currentInput=x;var W=Object.assign({},e.applyStyles,{effect:function(e){var t=e.state,n={popper:{position:t.options.strategy,left:"0",top:"0",margin:"0"},arrow:{position:"absolute"},reference:{}};Object.assign(t.elements.popper.style,n.popper),t.styles=n,t.elements.arrow&&Object.assign(t.elements.arrow.style,n.arrow)}}),X={mouseover:"mouseenter",focusin:"focus",click:"click"};var Y={name:"animateFill",defaultValue:!1,fn:function(e){var t;if(null==(t=e.props.render)||!t.$$tippy)return{};var n=S(e.popper),r=n.box,o=n.content,i=e.props.animateFill?function(){var e=d();return e.className="tippy-backdrop",y([e],"hidden"),e}():null;return{onCreate:function(){i&&(r.insertBefore(i,r.firstElementChild),r.setAttribute("data-animatefill",""),r.style.overflow="hidden",e.setProps({arrow:!1,animation:"shift-away"}))},onMount:function(){if(i){var e=r.style.transitionDuration,t=Number(e.replace("ms",""));o.style.transitionDelay=Math.round(t/10)+"ms",i.style.transitionDuration=e,y([i],"visible")}},onShow:function(){i&&(i.style.transitionDuration="0ms")},onHide:function(){i&&y([i],"hidden")}}}};var $={clientX:0,clientY:0},q=[];function z(e){var 
t=e.clientX,n=e.clientY;$={clientX:t,clientY:n}}var J={name:"followCursor",defaultValue:!1,fn:function(e){var t=e.reference,n=w(e.props.triggerTarget||t),r=!1,o=!1,i=!0,a=e.props;function s(){return"initial"===e.props.followCursor&&e.state.isVisible}function u(){n.addEventListener("mousemove",f)}function c(){n.removeEventListener("mousemove",f)}function p(){r=!0,e.setProps({getReferenceClientRect:null}),r=!1}function f(n){var r=!n.target||t.contains(n.target),o=e.props.followCursor,i=n.clientX,a=n.clientY,s=t.getBoundingClientRect(),u=i-s.left,c=a-s.top;!r&&e.props.interactive||e.setProps({getReferenceClientRect:function(){var e=t.getBoundingClientRect(),n=i,r=a;"initial"===o&&(n=e.left+u,r=e.top+c);var s="horizontal"===o?e.top:r,p="vertical"===o?e.right:n,f="horizontal"===o?e.bottom:r,l="vertical"===o?e.left:n;return{width:p-l,height:f-s,top:s,right:p,bottom:f,left:l}}})}function l(){e.props.followCursor&&(q.push({instance:e,doc:n}),function(e){e.addEventListener("mousemove",z)}(n))}function d(){0===(q=q.filter((function(t){return t.instance!==e}))).filter((function(e){return e.doc===n})).length&&function(e){e.removeEventListener("mousemove",z)}(n)}return{onCreate:l,onDestroy:d,onBeforeUpdate:function(){a=e.props},onAfterUpdate:function(t,n){var i=n.followCursor;r||void 0!==i&&a.followCursor!==i&&(d(),i?(l(),!e.state.isMounted||o||s()||u()):(c(),p()))},onMount:function(){e.props.followCursor&&!o&&(i&&(f($),i=!1),s()||u())},onTrigger:function(e,t){m(t)&&($={clientX:t.clientX,clientY:t.clientY}),o="focus"===t.type},onHidden:function(){e.props.followCursor&&(p(),c(),i=!0)}}}};var G={name:"inlinePositioning",defaultValue:!1,fn:function(e){var t,n=e.reference;var r=-1,o=!1,i=[],a={name:"tippyInlinePositioning",enabled:!0,phase:"afterWrite",fn:function(o){var a=o.state;e.props.inlinePositioning&&(-1!==i.indexOf(a.placement)&&(i=[]),t!==a.placement&&-1===i.indexOf(a.placement)&&(i.push(a.placement),e.setProps({getReferenceClientRect:function(){return function(e){return 
function(e,t,n,r){if(n.length<2||null===e)return t;if(2===n.length&&r>=0&&n[0].left>n[1].right)return n[r]||t;switch(e){case"top":case"bottom":var o=n[0],i=n[n.length-1],a="top"===e,s=o.top,u=i.bottom,c=a?o.left:i.left,p=a?o.right:i.right;return{top:s,bottom:u,left:c,right:p,width:p-c,height:u-s};case"left":case"right":var f=Math.min.apply(Math,n.map((function(e){return e.left}))),l=Math.max.apply(Math,n.map((function(e){return e.right}))),d=n.filter((function(t){return"left"===e?t.left===f:t.right===l})),v=d[0].top,m=d[d.length-1].bottom;return{top:v,bottom:m,left:f,right:l,width:l-f,height:m-v};default:return t}}(p(e),n.getBoundingClientRect(),f(n.getClientRects()),r)}(a.placement)}})),t=a.placement)}};function s(){var t;o||(t=function(e,t){var n;return{popperOptions:Object.assign({},e.popperOptions,{modifiers:[].concat(((null==(n=e.popperOptions)?void 0:n.modifiers)||[]).filter((function(e){return e.name!==t.name})),[t])})}}(e.props,a),o=!0,e.setProps(t),o=!1)}return{onCreate:s,onAfterUpdate:s,onTrigger:function(t,n){if(m(n)){var o=f(e.reference.getClientRects()),i=o.find((function(e){return e.left-2<=n.clientX&&e.right+2>=n.clientX&&e.top-2<=n.clientY&&e.bottom+2>=n.clientY})),a=o.indexOf(i);r=a>-1?a:r}},onHidden:function(){r=-1}}}};var K={name:"sticky",defaultValue:!1,fn:function(e){var t=e.reference,n=e.popper;function r(t){return!0===e.props.sticky||e.props.sticky===t}var o=null,i=null;function a(){var s=r("reference")?(e.popperInstance?e.popperInstance.state.elements.reference:t).getBoundingClientRect():null,u=r("popper")?n.getBoundingClientRect():null;(s&&Q(o,s)||u&&Q(i,u))&&e.popperInstance&&e.popperInstance.update(),o=s,i=u,e.state.isMounted&&requestAnimationFrame(a)}return{onMount:function(){e.props.sticky&&a()}}}};function Q(e,t){return!e||!t||(e.top!==t.top||e.right!==t.right||e.bottom!==t.bottom||e.left!==t.left)}return F.setDefaultProps({plugins:[Y,J,G,K],render:N}),F.createSingleton=function(e,t){var n;void 0===t&&(t={});var 
r,o=e,i=[],a=[],c=t.overrides,p=[],f=!1;function l(){a=o.map((function(e){return u(e.props.triggerTarget||e.reference)})).reduce((function(e,t){return e.concat(t)}),[])}function v(){i=o.map((function(e){return e.reference}))}function m(e){o.forEach((function(t){e?t.enable():t.disable()}))}function g(e){return o.map((function(t){var n=t.setProps;return t.setProps=function(o){n(o),t.reference===r&&e.setProps(o)},function(){t.setProps=n}}))}function h(e,t){var n=a.indexOf(t);if(t!==r){r=t;var s=(c||[]).concat("content").reduce((function(e,t){return e[t]=o[n].props[t],e}),{});e.setProps(Object.assign({},s,{getReferenceClientRect:"function"==typeof s.getReferenceClientRect?s.getReferenceClientRect:function(){var e;return null==(e=i[n])?void 0:e.getBoundingClientRect()}}))}}m(!1),v(),l();var b={fn:function(){return{onDestroy:function(){m(!0)},onHidden:function(){r=null},onClickOutside:function(e){e.props.showOnCreate&&!f&&(f=!0,r=null)},onShow:function(e){e.props.showOnCreate&&!f&&(f=!0,h(e,i[0]))},onTrigger:function(e,t){h(e,t.currentTarget)}}}},y=F(d(),Object.assign({},s(t,["overrides"]),{plugins:[b].concat(t.plugins||[]),triggerTarget:a,popperOptions:Object.assign({},t.popperOptions,{modifiers:[].concat((null==(n=t.popperOptions)?void 0:n.modifiers)||[],[W])})})),w=y.show;y.show=function(e){if(w(),!r&&null==e)return h(y,i[0]);if(!r||null!=e){if("number"==typeof e)return i[e]&&h(y,i[e]);if(o.indexOf(e)>=0){var t=e.reference;return h(y,t)}return i.indexOf(e)>=0?h(y,e):void 0}},y.showNext=function(){var e=i[0];if(!r)return y.show(0);var t=i.indexOf(r);y.show(i[t+1]||e)},y.showPrevious=function(){var e=i[i.length-1];if(!r)return y.show(e);var t=i.indexOf(r),n=i[t-1]||e;y.show(n)};var E=y.setProps;return y.setProps=function(e){c=e.overrides||c,E(e)},y.setInstances=function(e){m(!0),p.forEach((function(e){return e()})),o=e,m(!1),v(),l(),p=g(y),y.setProps({triggerTarget:a})},p=g(y),y},F.delegate=function(e,n){var 
r=[],o=[],i=!1,a=n.target,c=s(n,["target"]),p=Object.assign({},c,{trigger:"manual",touch:!1}),f=Object.assign({touch:R.touch},c,{showOnCreate:!0}),l=F(e,p);function d(e){if(e.target&&!i){var t=e.target.closest(a);if(t){var r=t.getAttribute("data-tippy-trigger")||n.trigger||R.trigger;if(!t._tippy&&!("touchstart"===e.type&&"boolean"==typeof f.touch||"touchstart"!==e.type&&r.indexOf(X[e.type])<0)){var s=F(t,f);s&&(o=o.concat(s))}}}}function v(e,t,n,o){void 0===o&&(o=!1),e.addEventListener(t,n,o),r.push({node:e,eventType:t,handler:n,options:o})}return u(l).forEach((function(e){var n=e.destroy,a=e.enable,s=e.disable;e.destroy=function(e){void 0===e&&(e=!0),e&&o.forEach((function(e){e.destroy()})),o=[],r.forEach((function(e){var t=e.node,n=e.eventType,r=e.handler,o=e.options;t.removeEventListener(n,r,o)})),r=[],n()},e.enable=function(){a(),o.forEach((function(e){return e.enable()})),i=!1},e.disable=function(){s(),o.forEach((function(e){return e.disable()})),i=!0},function(e){var n=e.reference;v(n,"touchstart",d,t),v(n,"mouseover",d),v(n,"focusin",d),v(n,"click",d)}(e)})),l},F.hideAll=function(e){var t=void 0===e?{}:e,n=t.exclude,r=t.duration;U.forEach((function(e){var t=!1;if(n&&(t=g(n)?e.reference===n:e.popper===n.popper),!t){var o=e.props.duration;e.setProps({duration:r}),e.hide(),e.state.isDestroyed||e.setProps({duration:o})}}))},F.roundArrow='',F})); 2 | 3 | -------------------------------------------------------------------------------- /configs/default_config.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "/home/user/data/" #"/home/malbrank/nas/protmamba2/data/" and "/nvme1/common/OpenProteinSet/" 2 | output_dir: "/home/user/results/" 3 | namedir: "test0" 4 | train_dataset_path: "encoded_MSAs_subset-100.pkl" 5 | eval_dataset_path: '...' 
6 | batch_size: 2 # mamba trained with total 0.5M tokens per batch 7 | d_model: 1024 8 | gradient_accumulation_steps: 1 9 | checkpoint_mixer: False 10 | learning_rate: 0.0006 # mamba default (x5), decrease to 0.0006 for sizes > 100M params and 0.0002 for sizes > 1B params 11 | weight_decay: 0.1 # mamba default 12 | beta1: 0.9 # mamba default 13 | beta2: 0.95 # mamba default 14 | max_grad_norm: 1. # mamba default 15 | warmup_steps: 500 # change accordingly to the number of epochs (500) 16 | scheduler: "constant" # "cosine" # "cosine-restarts" # "constant" 17 | num_cycles: 1 18 | n_layer: 16 19 | num_epochs: 100 20 | residual_in_fp32: False 21 | vocab_size: 38 22 | sample_sequences: False # set to false to have all sequences in the batch to be of the same length given by max_msa_len 23 | max_msa_len: 131072 #32768 #65536 #131072 #262144 #524288 #1048576 24 | seed_sequence_sampling: 42 25 | seed_datasets: 0 26 | save_steps: 250 27 | eval_steps: 50 28 | size_validation: 4 29 | logging_steps: 10 30 | eval_accumulation_steps: 10 31 | save_total_limit: 50 32 | dtype: "bfloat16" 33 | fim_strategy: "multiple_span" #["no-scramble", "one_span", "multiple_span"] 34 | compute_only_fim_loss: False 35 | always_mask: False 36 | max_position_embeddings: 2048 37 | max_seq_position_embeddings: 512 38 | add_position_ids: "1d" #["none", "1d", "2d"] 39 | reverse: False # when true trains the model in a bidirectional way 40 | early_stopping_metric: "eval_train_loss" # Name of metric to use for early stopping 41 | patience: 10 # how many evaluations to wait before restarting training after a loss spike 42 | loss_increase_factor: 1.005 # how much the loss can increase (for patience iterations) before restarting training 43 | -------------------------------------------------------------------------------- /configs/default_config_fim-finetuning.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "" 2 | output_dir: "results/" 3 | namedir: "test0" 4 
| train_dataset_path: "encoded_MSAs_train.pkl" 5 | eval_dataset_path: '...' 6 | batch_size: 4 # mamba trained with total 0.5M tokens per batch 7 | d_model: 1024 8 | gradient_accumulation_steps: 4 9 | learning_rate: 0.0001 # mamba default (x5), decrease to 0.0006 for sizes > 100M params and 0.0002 for sizes > 1B params 10 | weight_decay: 0.1 # mamba default 11 | beta1: 0.9 # mamba default 12 | beta2: 0.95 # mamba default 13 | max_grad_norm: 1. # mamba default 14 | warmup_steps: 500 # change accordingly to the number of epochs (500) 15 | scheduler: "constant" # "cosine" # "cosine-restarts" # "constant" 16 | num_cycles: 1 17 | n_layer: 16 18 | num_epochs: 100 19 | residual_in_fp32: False 20 | vocab_size: 38 21 | sample_sequences: False # set to false to have all sequences in the batch to be of the same length given by max_msa_len 22 | max_msa_len: 32768 23 | seed_sequence_sampling: 42 24 | seed_datasets: 0 25 | save_steps: 250 26 | eval_steps: 50 27 | size_validation: 192 28 | logging_steps: 10 29 | eval_accumulation_steps: 200 30 | save_total_limit: 50 31 | dtype: "bfloat16" 32 | fim_strategy: "multiple_span" #["no-scramble", "one_span", "multiple_span"] 33 | compute_only_fim_loss: True 34 | always_mask: True 35 | max_position_embeddings: 2048 36 | max_seq_position_embeddings: 512 37 | add_position_ids: "1d" #["none", "1d", "2d"] 38 | reverse: False # when true trains the model in a bidirectional way 39 | early_stopping_metric: "eval_train_loss" # Name of metric to use for early stopping 40 | patience: 10 # how many evaluations to wait before restarting training after a loss spike 41 | loss_increase_factor: 1.005 # how much the loss can increase (for patience iterations) before restarting training 42 | -------------------------------------------------------------------------------- /data/cluster_testing_set.txt: -------------------------------------------------------------------------------- 1 | A0A2H9MP70 2 | A0A135YUE9 3 | A0A0A0HZM8 4 | G4ZH78 5 | A0A091TDH7 6 | 
A0A2N1P554 7 | A0A1C5UJ41 8 | A0A1S7SAF9 9 | A0A194V424 10 | S7UZ45 11 | A0A1A8YWK1 12 | A0A1X0RUL2 13 | D3BPE1 14 | F2CV06 15 | A0A146ZGL6 16 | A0A1W9GDS7 17 | D8SD16 18 | A0A074SV30 19 | A0A139IN77 20 | G3PS96 21 | A0A1C6Q5J2 22 | I4B642 23 | A0A0D2P8K3 24 | A0A2X4BAY2 25 | A0A241VGM5 26 | A0A1S3G530 27 | A0A2V9BMA4 28 | A0A061CUF5 29 | A0A1Q6IFI4 30 | A0A2H3XLB1 31 | A0A2H1K104 32 | A0A0N8D6U2 33 | A0A2H0SY85 34 | A0A1Q8BKF6 35 | A0A011KIR2 36 | A0A2T6IE20 37 | A0A0U5GNP4 38 | G4Q7R9 39 | A0A015U815 40 | A0A1Z9SKV7 41 | M2V802 42 | A0A2N0GFX8 43 | A0A1M6W940 44 | A0A0D3E6J8 45 | A8LB25 46 | A0A1I8GHI4 47 | A0A139DH55 48 | A0A2E2E1S9 49 | A0A1W4XJQ4 50 | A0A1X6N5E6 51 | A0A2P7X7N0 52 | A0A259P0W6 53 | A0A100ZK86 54 | A0A0J6YQT2 55 | A0A2T7I8M1 56 | A0A183FD52 57 | A0A034U5W5 58 | A0A0D3EBJ3 59 | V4TTQ9 60 | A0A2G8KN10 61 | A0A1C5Y417 62 | D7E4W2 63 | A0A140DUU2 64 | A0A2E7R4T3 65 | A0A0Q8G9Y1 66 | A0A151W2C9 67 | A0A1S6RA81 68 | A0A2X2V9W3 69 | G4UA16 70 | A0A1B1TTV6 71 | A0A1J4VTS3 72 | F2QXK6 73 | D4TP75 74 | F7K4I2 75 | A0A146UKP3 76 | I9HJ82 77 | A0A1C2B5Y2 78 | R6LWQ2 79 | G2XEX7 80 | G0RV91 81 | A0A1E1LS16 82 | D0RHH1 83 | A0A218U7P4 84 | A0A2T2XQJ3 85 | A0A146WDL2 86 | I0H2N4 87 | A0A2E5HYH0 88 | A0A2H9R4W5 89 | U6K2X1 90 | A0A0E0PQF6 91 | A0A090D524 92 | A0A062WSU5 93 | A0A2P4KZX6 94 | A0A1Q8RJR0 95 | A0A0F5EXM8 96 | F5SFN2 97 | A0A164Y326 98 | A0A183MWN1 99 | A0A1B6I906 100 | A0A174GT43 101 | A0A2E5WIK0 102 | A0A2T3JKA7 103 | A0A1V8V259 104 | D8LRZ7 105 | A0A1Z5KJJ0 106 | A0A218ZS09 107 | A0A2V0PDA7 108 | A0A2B0DDL2 109 | A0A016T3D8 110 | A0A0W1GJH6 111 | A0A161PNY9 112 | A0A0U5JV97 113 | A0A1F7FT28 114 | A0A229FXK9 115 | A0A177E5N8 116 | A0A0E0FT38 117 | A0A1A8NCQ5 118 | A0A177VRQ4 119 | M7W498 120 | A0A2S5UT71 121 | A0A0R3TC99 122 | G3KJ89 123 | A0A0G0UUM0 124 | A0A2G9PQG2 125 | A0A1E3JV34 126 | A0A0L9T670 127 | A0A0P5XSR0 128 | X5PQZ4 129 | A0A133UV38 130 | U6KSW1 131 | A0A250XFE0 132 | V9E066 133 | A0A1I8IAA2 134 | A0A2S0I6F1 135 | A0A1Y3MT81 136 | 
A0C7Z7 137 | A0A117LYF9 138 | A0A0R0LYW2 139 | A0A087Y2J8 140 | A0A0L9UQX2 141 | R6X4Z8 142 | A0A067D152 143 | A0A2X0NAD5 144 | A0A1Z9LZ39 145 | F8SJW2 146 | A0A0H2YQY1 147 | A0A0K1Q686 148 | A0A0D2JXS7 149 | H2XLS8 150 | A0A0G2HN58 151 | A0A1Y6LBW9 152 | A0A1Z8A0W0 153 | F0XWE8 154 | A0A1M6LQ61 155 | A0A1I7KZL4 156 | C6HLX2 157 | A0A0G1V3T1 158 | A0A0W4ZUN9 159 | A0A2U0X2E7 160 | A0A1A7X1L9 161 | G3TV23 162 | A0A094GDL8 163 | A0A1H8FEC2 164 | A0A117JSE9 165 | A0A1C4QES5 166 | A0A1X4NHW8 167 | A0A1G2VBY4 168 | A0A086ZSM7 169 | A0A0A1NEM0 170 | A0A1X4IEC0 171 | A0A0D0P8L0 172 | Q7UQT2 173 | A8IHJ2 174 | A0A0D0VSE8 175 | A0A0P5YT10 176 | A0A1I6ILZ0 177 | A0A1G1QDR2 178 | A0A1H7IBB8 179 | A0A1Y4LBX6 180 | A0A2K5R315 181 | A0A1E1M9N5 182 | D3TKQ3 183 | A0A084SXL2 184 | A0A1D2MM58 185 | A0A131Y5V2 186 | A0A2T4ZF09 187 | A0A1V3ZYM8 188 | A0A1Y2A551 189 | I3EL87 190 | A0A2M8FD39 191 | A0A2D5FE50 192 | A0A161M099 193 | A0A1I5FW45 194 | R6EHP3 195 | A0A2D5TUR6 196 | A0A078KS57 197 | A0A1B9HC20 198 | A0A267C2G1 199 | B9GCB8 200 | R5DHY1 201 | A0A0J7YM56 202 | R5J7I3 203 | A0A1G4H1G6 204 | A0A1F7W895 205 | X0QGS9 206 | A0A2H4PHK3 207 | F1VX56 208 | B8PAF7 209 | A0A081AFH5 210 | D8R8M8 211 | A0A2M7WXI8 212 | A0A1G4IBP5 213 | A0A0P7UCZ3 214 | A0C7M5 215 | A0A2G5CEQ2 216 | R6BV15 217 | A0A177T9N4 218 | A0A0Q7AYE5 219 | A0A1H6K6C9 220 | G8DIF3 221 | S9PQ80 222 | A0A0F2JH67 223 | K0SQA1 224 | A0A0L0NZU6 225 | A0A177EK06 226 | A0A2P2H883 227 | F4QF76 228 | A0A0D8XHY8 229 | Q86AH3 230 | Q1YY24 231 | A0A199VAQ7 232 | U6L8G8 233 | A0A0V0QSF5 234 | F7A9G8 235 | A0A2N5QTU0 236 | A8IR63 237 | A0A0R2CJ20 238 | F2TU83 239 | A8J2V5 240 | R4M0P2 241 | S8AJS2 242 | A0A1F7PX91 243 | J0FG66 244 | A0A0B8PC95 245 | A0A2I1YCH7 246 | A0A2C9W9U0 247 | A0A1B6FK08 248 | A0A2G6EAR2 249 | A0A1I3F8T1 250 | A0A1Q9RTE7 251 | A0A0P0V5W8 252 | A0A1V0SAR2 253 | A0A1I2BZ18 254 | A0A2E3UF32 255 | A0A246Q562 256 | A0A269PK80 257 | A0A1G3B6C5 258 | A0A1Y3NGB3 259 | A0A2E5AKD3 260 | A0A2D5V006 261 | A0A2G5THJ0 262 
| A0A1J4XS40 263 | A0A2H0KQH7 264 | A0A0L0D2J7 265 | A0A1V6IN80 266 | A0A1S0YL97 267 | A0A0F4XBA6 268 | A0A2E1H1V3 269 | A0A2A9JXQ4 270 | R6VTZ4 271 | A0A1H7Q3C2 272 | I9NVR4 273 | A0A2N0MVG7 274 | A0A0Q8IT99 275 | A0A146ZG79 276 | A0A1I7Z642 277 | A0A2U1KXA0 278 | A0A0G4ARW6 279 | A0A2B7YHU6 280 | A0A0D0WQM2 281 | C4TAU8 282 | E6H3H0 283 | A0A1Z5KG41 284 | A0A1D8EQA4 285 | A0A1B7TJ72 286 | A0A2D0PI02 287 | A0A2M7UVW4 288 | A0A0Q8EKA8 289 | A0A1V3WHB7 290 | A0A0E0R2T7 291 | A0A1Z9JVZ8 292 | A0A081AN11 293 | A0A165JU86 294 | A0A0A8YMT6 295 | A0A2S6A3G9 296 | A0A2W7QA23 297 | A0A2E4GD30 298 | A0A287HDC6 299 | A0A2V2WA98 300 | A0A2V6UJR7 301 | A0A1A9VNJ2 302 | C9SPP8 303 | A0A1C0ANX6 304 | A0A2N5QK37 305 | A0A2N0CCW7 306 | A0A1I6CNT2 307 | A0A1A9XQ05 308 | A0A2E9F4L3 309 | A0A1V5V240 310 | A0L9U3 311 | H1BFP4 312 | A1L2B3 313 | A0A0G2ZJV3 314 | J9Z942 315 | A0A2P2D9K3 316 | C4N3X9 317 | A0A094W8K2 318 | I3UFV1 319 | H2KUH3 320 | A3FQ58 321 | A0A061RIL5 322 | A0A2N7S696 323 | A0A1S3RUH0 324 | A0A260DS07 325 | A4SKL9 326 | A0A1G1IN71 327 | A0A2W1PC49 328 | A0A139A9J1 329 | A0A2E2JYJ1 330 | A0A2R6HRX3 331 | A0A183W649 332 | A0A2P8YRR6 333 | A0A231NZC7 334 | A0A1R1XAL6 335 | A0A0J6DNK1 336 | A0A1Y3PLN6 337 | A0A0G0RHK2 338 | A0A1E5NF60 339 | A0A2E2CPP0 340 | G2ZU27 341 | A0A0D6W3I0 342 | A0A2K2DML4 343 | Q7G6C7 344 | A0A084QH61 345 | G3JWZ8 346 | A0A084EP39 347 | A0A0S2W5N8 348 | A0A0V0W6J4 349 | A0A2E4PDR8 350 | R5NLJ1 351 | A0A0P5DT33 352 | A0A2N0X0R7 353 | H9YTY1 354 | A0A060DWD2 355 | A0A1Z5KMJ9 356 | A0A0J6VGC2 357 | A0A1H0PHT4 358 | V7ZH71 359 | T1YB19 360 | A0A1R3FLW4 361 | A0A063YAQ3 362 | A0A1H3TRJ9 363 | A0A0D6YLC5 364 | K7H947 365 | A0A2H9PIS1 366 | A0A2I0Q506 367 | A0A286Y095 368 | A0A0W0YUV2 369 | M1I5J2 370 | A0A2E4DPY6 371 | R6SJ41 372 | R1GSM3 373 | A0A0S4JHC7 374 | A0A0Q8QUQ9 375 | A0A182KZ14 376 | M2XUY8 377 | A0A061QKP3 378 | A0A133L101 379 | A0A1Q7M6I5 380 | A0A146XUL4 381 | A0A1R4LTT6 382 | A0A1Q6QLJ3 383 | C4Y8P3 384 | A0A127QZ96 385 | A0A1B7IIT3 386 
| A0A151ZB98 387 | A0A160DHP5 388 | W9DF34 389 | A0A1Q6Q3G6 390 | A0A2W1S5F0 391 | A0A2T7LX00 392 | F2SYC7 393 | A0A0A1D706 394 | A0A2X0Z8U1 395 | A0A2S7AZJ2 396 | R9I0Y7 397 | A0A1W0WJL4 398 | A0A2A4GU25 399 | A0A2A4PKH6 400 | A0A2D7NK98 401 | D8PI39 402 | A0A2N6K632 403 | A0A1Z9WT56 404 | A0A2N5AKR1 405 | A0A1W1HT35 406 | E9RI77 407 | A0A1A9Y1U5 408 | A0A0P5M8M0 409 | H1WEP7 410 | J1I4D2 411 | A0A0G0DXD2 412 | A0A1G2XVN6 413 | A0A2M7SX09 414 | T2NAV1 415 | A0A2B4S0U5 416 | A0A1D6CKH0 417 | A0A212CFE3 418 | A0A224YVB0 419 | A0A246J319 420 | A0A1C6AWA5 421 | A0A2N8KT16 422 | A0A1Q9F2X9 423 | A0A0H5S9S5 424 | A0A1V6YYP9 425 | A0A2G2WNH7 426 | A0A1I0H4D8 427 | C9RE15 428 | A0A226BG84 429 | A8I5J8 430 | A0A177V663 431 | A0A0D2KAE6 432 | A0A2G3A525 433 | U6K9L1 434 | A4HYW0 435 | G9QSI5 436 | A0A2V9DLZ0 437 | A0A1Q6KFE5 438 | A0A2K3K8S5 439 | A0A2T7HXY0 440 | U5GZ84 441 | A0A0C9XXW7 442 | A6C913 443 | R8F7I6 444 | A0A1C5JQ62 445 | T1WCQ3 446 | A0A131Z479 447 | A0A0M9IJ34 448 | A0A2P1CNG4 449 | A0A146FZQ4 450 | A0A1C0TV81 451 | S4XXA3 452 | A0A1C0VB49 453 | A0A2N6PBQ2 454 | A0A0P4YV14 455 | A0A1Y2C9D7 456 | A0A2D8TKL5 457 | A0A2N1PRQ9 458 | A0A258R635 459 | A0A1X9SID3 460 | X7XVC5 461 | A0A1S8VDG2 462 | A0A1A8XDV6 463 | A0A1M6VHQ8 464 | C4T9N1 465 | A0A2S4UHR8 466 | A0A2E0T667 467 | A0A2H5PSR6 468 | A0A1H9YA48 469 | A0A1A9ZW06 470 | C8C8F1 471 | A0A2H0BGN3 472 | A0A202DFP8 473 | A0A183AEG7 474 | A0A1D6ABN7 475 | A0A2A2L697 476 | A0A0C9T914 477 | A0A1B1DZS6 478 | A0A2X4HES6 479 | A0A1N6R7T0 480 | A0A1H0NM20 481 | A0A1E7IWB0 482 | A0A2V2C1E0 483 | B4HWW2 484 | A0A165BZI2 485 | A0A088FW42 486 | A0A2K8Y942 487 | G4N1M6 488 | A0A2W5Z9K1 489 | A0A2T2UKC3 490 | H3N8X8 491 | A0A1I8G1K9 492 | A0A2G2FSW5 493 | F2QZ14 494 | A0A1Y1R922 495 | A0A0A9AYF5 496 | A0A0P5XD61 497 | J4CCD8 498 | A0A1V0A0Z2 499 | A0A0R2TA27 500 | A0A0P5DYX6 501 | -------------------------------------------------------------------------------- /data/example_pretrained_model/config.json: 
-------------------------------------------------------------------------------- 1 | {"d_model": 128, "n_layer": 8, "vocab_size": 33, "ssm_cfg": {}, "rms_norm": true, "residual_in_fp32": false, "fused_add_norm": true, "pad_vocab_size_multiple": 8} -------------------------------------------------------------------------------- /data/example_pretrained_model/optimizer.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bitbol-Lab/ProtMamba-ssm/4dfb7e199542a88babbc4f6b724b09dad8eeacbb/data/example_pretrained_model/optimizer.pt -------------------------------------------------------------------------------- /data/example_pretrained_model/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bitbol-Lab/ProtMamba-ssm/4dfb7e199542a88babbc4f6b724b09dad8eeacbb/data/example_pretrained_model/pytorch_model.bin -------------------------------------------------------------------------------- /data/example_pretrained_model/rng_state.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bitbol-Lab/ProtMamba-ssm/4dfb7e199542a88babbc4f6b724b09dad8eeacbb/data/example_pretrained_model/rng_state.pth -------------------------------------------------------------------------------- /data/example_pretrained_model/scheduler.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bitbol-Lab/ProtMamba-ssm/4dfb7e199542a88babbc4f6b724b09dad8eeacbb/data/example_pretrained_model/scheduler.pt -------------------------------------------------------------------------------- /data/logo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bitbol-Lab/ProtMamba-ssm/4dfb7e199542a88babbc4f6b724b09dad8eeacbb/data/logo.jpeg 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.4 2 | torch==2.3.1 3 | tqdm==4.64.1 4 | pandas==1.4.4 5 | pathlib==1.0.1 6 | datasets==2.10.1 7 | scipy==1.9.1 8 | requests==2.31.0 9 | argparse==1.4.0 10 | scikit-learn==1.3.0 11 | transformers[torch]==4.44.2 12 | matplotlib==3.5.2 13 | biopython==1.78 14 | tensorboard==2.17.1 15 | dataclasses==0.6 16 | mamba-ssm==2.1.0 17 | wandb==0.17.0 -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | # All sections below are required unless otherwise specified. 3 | # See https://github.com/fastai/nbdev/blob/master/settings.ini for examples. 4 | 5 | ### Python library ### 6 | repo = ProtMamba-ssm 7 | lib_name = %(repo)s 8 | version = 0.0.1 9 | min_python = 3.7 10 | license = apache2 11 | black_formatting = False 12 | 13 | ### nbdev ### 14 | doc_path = _docs 15 | lib_path = ProtMamba_ssm 16 | nbs_path = nbs 17 | recursive = True 18 | tst_flags = notest 19 | put_version_in_init = True 20 | 21 | ### Docs ### 22 | branch = main 23 | custom_sidebar = False 24 | doc_host = https://%(user)s.github.io 25 | doc_baseurl = /%(repo)s 26 | git_url = https://github.com/%(user)s/%(repo)s 27 | title = %(lib_name)s 28 | 29 | ### PyPI ### 30 | audience = Developers 31 | author = damiano-sg 32 | author_email = damianosg96@gmail.com 33 | copyright = 2024 onwards, %(author)s 34 | description = Code to run ProtMamba locally 35 | keywords = nbdev jupyter notebook python 36 | language = English 37 | status = 3 38 | user = Bitbol-Lab 39 | 40 | ### Optional ### 41 | requirements = numpy torch biopython transformers matplotlib tensorboard pandas mamba-ssm==1.1.4 42 | # dev_requirements = 43 | # console_scripts = 44 | 
"""nbdev-generated setup script: all project metadata is read from settings.ini."""
import setuptools
import shlex
from configparser import ConfigParser

from pkg_resources import parse_version

# The features used below require a reasonably modern setuptools.
assert parse_version(setuptools.__version__) >= parse_version('36.2')

# note: all settings are in settings.ini; edit there, not here
_parser = ConfigParser(delimiters=['='])
_parser.read('settings.ini', encoding='utf-8')
cfg = _parser['DEFAULT']

# Keys copied verbatim into the setup() call, plus keys that must be present.
cfg_keys = 'version description keywords author author_email'.split()
expected = cfg_keys + "lib_name user branch license status min_python audience language".split()
for key in expected:
    assert key in cfg, "missing expected setting: {}".format(key)
setup_cfg = {key: cfg[key] for key in cfg_keys}

# Map the short license id from settings.ini to (license name, trove classifier).
licenses = {
    'apache2': ('Apache Software License 2.0', 'OSI Approved :: Apache Software License'),
    'mit': ('MIT License', 'OSI Approved :: MIT License'),
    'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'),
    'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'),
    'bsd3': ('BSD License', 'OSI Approved :: BSD License'),
}
statuses = ['1 - Planning', '2 - Pre-Alpha', '3 - Alpha',
            '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive']
py_versions = '3.6 3.7 3.8 3.9 3.10'.split()

requirements = shlex.split(cfg.get('requirements', ''))
if cfg.get('pip_requirements'):
    requirements += shlex.split(cfg.get('pip_requirements', ''))
min_python = cfg['min_python']
# Unknown license ids fall back to the raw string with no classifier.
lic = licenses.get(cfg['license'].lower(), (cfg['license'], None))
dev_requirements = (cfg.get('dev_requirements') or '').split()

# Build the classifier list step by step instead of one inline expression.
classifiers = [
    'Development Status :: ' + statuses[int(cfg['status'])],
    'Intended Audience :: ' + cfg['audience'].title(),
    'Natural Language :: ' + cfg['language'].title(),
]
classifiers += ['Programming Language :: Python :: ' + v
                for v in py_versions[py_versions.index(min_python):]]
if lic[1]:
    classifiers.append('License :: ' + lic[1])

with open('README.md', encoding='utf-8') as readme:
    long_description = readme.read()

setuptools.setup(
    name=cfg['lib_name'],
    license=lic[0],
    classifiers=classifiers,
    url=cfg['git_url'],
    packages=setuptools.find_packages(),
    include_package_data=True,
    install_requires=requirements,
    extras_require={'dev': dev_requirements},
    dependency_links=cfg.get('dep_links', '').split(),
    python_requires='>=' + cfg['min_python'],
    long_description=long_description,
    long_description_content_type='text/markdown',
    zip_safe=False,
    entry_points={
        'console_scripts': cfg.get('console_scripts', '').split(),
        'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d'],
    },
    **setup_cfg)
model. Given a dataset, a list of families (their indexes in the dataset)\n", 35 | " and a set of generating parameters, it generates `n_samples_per_family` sequences for each family and each parameter set.\n", 36 | " The function returns a dictionary with the following structure:\n", 37 | " gen_seqs = {family_idx: {parameters: {sequence: perplexity}}}\n", 38 | " The parameters are in a list of tuples with the following structure: \n", 39 | " parameters_list = [(nr_seqs_ctx, temperature, top_k, top_p)]\n", 40 | " \"\"\" \n", 41 | " gen_seqs = {}\n", 42 | " for j in family_idxs:\n", 43 | " gen_seqs[j] = {}\n", 44 | " print(\"Sampling sequences for family {}\".format(j))\n", 45 | " for params in tqdm(parameters_list):\n", 46 | " gen_seqs[j][params] = {}\n", 47 | " n_seqs_ctx , temperature, top_k, top_p = params\n", 48 | " for _ in range(n_samples_per_family):\n", 49 | " # Sample the dataset to get the input\n", 50 | " data = dataset[j]\n", 51 | " tokens = data[\"input_ids\"][None,:].to(\"cuda\")\n", 52 | " pos_ids = data[\"position_ids\"][None,:].to(\"cuda\")\n", 53 | " start_seqs = torch.argwhere(tokens[0]==0)[:,0].cpu().numpy()\n", 54 | " if fim_generation:\n", 55 | " n_seqs_ctx = len(start_seqs)-1 if len(start_seqs) < n_seqs_ctx+1 else n_seqs_ctx\n", 56 | " L = start_seqs[n_seqs_ctx+1] if n_seqs_ctx>0 else start_seqs[n_seqs_ctx]\n", 57 | " context_tokens, context_pos_ids, tokens_fim, pos_ids_fim, is_fim_dict = prepare_dataset_for_fim_generation(tokens[:,:L], pos_ids[:,:L])\n", 58 | " is_fim = is_fim_dict\n", 59 | " else:\n", 60 | " n_seqs_ctx = len(start_seqs) if len(start_seqs) < n_seqs_ctx else n_seqs_ctx\n", 61 | " L = start_seqs[n_seqs_ctx]+1\n", 62 | " context_tokens = tokens[:,:L]\n", 63 | " context_pos_ids = pos_ids[:,:L]\n", 64 | " is_fim=False\n", 65 | " # Generate the new sequence \n", 66 | " output = generate_sequence(model,\n", 67 | " context_tokens,\n", 68 | " position_ids=context_pos_ids,\n", 69 | " is_fim=is_fim,\n", 70 | " 
max_length=(L+max_length),\n", 71 | " temperature=temperature,\n", 72 | " top_k=top_k,\n", 73 | " top_p=top_p,\n", 74 | " return_dict_in_generate=True,\n", 75 | " output_scores=True,\n", 76 | " eos_token_id=torch.tensor([AA_TO_ID[\"\"]]).to(\"cuda\"),\n", 77 | " device=\"cuda\")\n", 78 | " # Get the perplexity of the generated sequence\n", 79 | " output_seq = output[\"generated\"] \n", 80 | " loss = torch.nn.functional.cross_entropy(torch.from_numpy(output[\"scores\"]).permute(0, 2, 1),\n", 81 | " torch.from_numpy(output[\"generated_tokens\"][0][None,:]))\n", 82 | " # save only sequences with length < max_length\n", 83 | " if len(output_seq[0]) < max_length:\n", 84 | " if fim_generation:\n", 85 | " original_input = output[\"input\"][0].split(\"\")[-1]\n", 86 | " original_input_continuation = decode_sequence(tokens_fim[0].cpu().numpy())+\"\"\n", 87 | " generated_input_continuation = output_seq[0]\n", 88 | " if len(original_input_continuation) == len(generated_input_continuation):\n", 89 | " outp_str = reorder_masked_sequence(original_input + generated_input_continuation)\n", 90 | " gen_seqs[j][params][outp_str] = {\"original_input\": original_input,\n", 91 | " \"original_input_fim\": original_input_continuation,\n", 92 | " \"generated_input_fim\": generated_input_continuation,\n", 93 | " \"perplexity\": torch.exp(loss).item()}\n", 94 | " else:\n", 95 | " print(\"Lengths of original and generated FIM do not match. 
{} vs {}\".format(original_input_continuation, generated_input_continuation))\n", 96 | " else:\n", 97 | " gen_seqs[j][params][output_seq[0]] = {\"perplexity\": torch.exp(loss).item()}\n", 98 | " if save_path is not None:\n", 99 | " with open(save_path, \"wb\") as f:\n", 100 | " pickle.dump(gen_seqs, f)\n", 101 | " return gen_seqs" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# Load the dataset used for training\n", 111 | "dataset_name = \"encoded_MSAs_test.pkl\"\n", 112 | "fim_strategy = \"multiple_span\"\n", 113 | "mask_fraction = 0.2\n", 114 | "dataset = Uniclust30_Dataset(dataset_name,\n", 115 | " filepath=\"/data1/common/OpenProteinSet/\",\n", 116 | " sample=False,\n", 117 | " max_msa_len=-1,\n", 118 | " max_patches=5,\n", 119 | " mask_fraction=mask_fraction,\n", 120 | " fim_strategy=fim_strategy,\n", 121 | " max_position_embeddings=2048,\n", 122 | " add_position_ids=\"1d\")\n", 123 | " \n", 124 | "# Load pretrained model\n", 125 | "checkpoint = \"../../nbs/results/train_100M_FIM_restart-spikes_merged/checkpoint_131k-3750\"\n", 126 | "model = load_model(checkpoint,\n", 127 | " model_class=MambaLMHeadModelwithPosids,\n", 128 | " device=\"cuda\",\n", 129 | " dtype=torch.bfloat16,\n", 130 | " checkpoint_mixer=False\n", 131 | " )\n", 132 | "model = model.eval()" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## Sample from different families using different generation methods" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "# family_idxs = [0, 1, 2, 3, 4, 5, 6, 8, 9, 10]\n", 149 | "family_idxs = [11, 13, 14, 16, 18] # [20, 21, 23, 24, 25] # \n", 150 | "\n", 151 | "# # parameters: (nr_seqs_ctx, temperature, top_k, top_p)\n", 152 | "# parameters_list = [(100,1.,10,0.), (-1,0.9,10,0.95)]\n", 153 | 
"parameters_list = [(10,1.,10,0.), (10,1.,15,0.), (10,1.,10,0.95), (10,0.9,10,0.95), (10,0.8,10,0.9),\n", 154 | " (100,1.,10,0.), (100,1.,15,0.), (100,1.,10,0.95), (100,0.9,10,0.95), (100,0.8,10,0.9),\n", 155 | " (500,1.,10,0.), (500,1.,15,0.), (500,1.,10,0.95), (500,0.9,10,0.95), (500,0.8,10,0.9),\n", 156 | " (1000,1.,10,0.), (1000,1.,15,0.), (1000,1.,10,0.95), (1000,0.9,10,0.95), (1000,0.8,10,0.9),\n", 157 | " (-1,1.,10,0.), (-1,1.,15,0.), (-1,1.,10,0.95), (-1,0.9,10,0.95), (-1,0.8,10,0.9)]\n", 158 | "n_samples_per_family = 100\n", 159 | "generate_fim = False" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "for i in family_idxs:\n", 169 | " data = dataset[i]\n", 170 | " tokens = data[\"input_ids\"][None,:].to(\"cuda\")\n", 171 | " start_seqs = torch.argwhere(tokens[0]==0)[:,0].cpu().numpy()\n", 172 | " inds = start_seqs < 131072\n", 173 | " print(f\"Family name: {dataset.cluster_names[i]}\\t\", \"\\tNumber of sequences: \", len(start_seqs), \"\\tNum sequences < 131072: \", len(start_seqs[inds]))" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "end_str = \"_fim\" if generate_fim else \"_full\"\n", 183 | "save_path = f\"figures/generated_sequences/check-131k(11-18)_gen_seqs{end_str}.pkl\"\n", 184 | "gen_seqs = sample_sequences(dataset,\n", 185 | " model,\n", 186 | " n_samples_per_family=n_samples_per_family,\n", 187 | " max_length=1000,\n", 188 | " family_idxs=family_idxs,\n", 189 | " parameters_list=parameters_list,\n", 190 | " fim_generation=generate_fim,\n", 191 | " save_path=save_path\n", 192 | " )\n", 193 | "with open(save_path, \"wb\") as f:\n", 194 | " pickle.dump(gen_seqs, f)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [] 203 | } 204 | ], 205 | "metadata": { 
206 | "kernelspec": { 207 | "display_name": "ProtMamba", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.9.18" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 2 226 | } 227 | -------------------------------------------------------------------------------- /tests/proteingym/mmseqs_search.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | import random 5 | import shutil 6 | import subprocess 7 | import tarfile 8 | import time 9 | from typing import List 10 | 11 | import requests 12 | from tqdm import tqdm 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def run_mmseqs2_api(x, prefix, use_env=True, use_filter=True, filter=None, host_url="https://api.colabfold.com", 18 | user_agent: str = "") -> List[str]: 19 | submission_endpoint = "ticket/msa" 20 | 21 | headers = {} 22 | if user_agent != "": 23 | headers['User-Agent'] = user_agent 24 | else: 25 | logger.warning( 26 | "No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. 
This warning will become an error in the future.") 27 | 28 | def submit(seqs, mode, N=101): 29 | n, query = N, "" 30 | for seq in seqs: 31 | query += f">{n}\n{seq}\n" 32 | n += 1 33 | 34 | while True: 35 | error_count = 0 36 | try: 37 | # https://requests.readthedocs.io/en/latest/user/advanced/#advanced 38 | # "good practice to set connect timeouts to slightly larger than a multiple of 3" 39 | res = requests.post(f'{host_url}/{submission_endpoint}', data={'q': query, 'mode': mode}, timeout=6.02, 40 | headers=headers) 41 | except requests.exceptions.Timeout: 42 | logger.warning("Timeout while submitting to MSA server. Retrying...") 43 | continue 44 | except Exception as e: 45 | error_count += 1 46 | logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)") 47 | logger.warning(f"Error: {e}") 48 | time.sleep(5) 49 | if error_count > 5: 50 | raise 51 | continue 52 | break 53 | 54 | try: 55 | out = res.json() 56 | except ValueError: 57 | logger.error(f"Server didn't reply with json: {res.text}") 58 | out = {"status": "ERROR"} 59 | return out 60 | 61 | def status(ID): 62 | while True: 63 | error_count = 0 64 | try: 65 | res = requests.get(f'{host_url}/ticket/{ID}', timeout=6.02, headers=headers) 66 | except requests.exceptions.Timeout: 67 | logger.warning("Timeout while fetching status from MSA server. Retrying...") 68 | continue 69 | except Exception as e: 70 | error_count += 1 71 | logger.warning(f"Error while fetching result from MSA server. Retrying... 
({error_count}/5)") 72 | logger.warning(f"Error: {e}") 73 | time.sleep(5) 74 | if error_count > 5: 75 | raise 76 | continue 77 | break 78 | try: 79 | out = res.json() 80 | except ValueError: 81 | logger.error(f"Server didn't reply with json: {res.text}") 82 | out = {"status": "ERROR"} 83 | return out 84 | 85 | def download(ID, path): 86 | error_count = 0 87 | while True: 88 | try: 89 | res = requests.get(f'{host_url}/result/download/{ID}', timeout=6.02, headers=headers) 90 | except requests.exceptions.Timeout: 91 | logger.warning("Timeout while fetching result from MSA server. Retrying...") 92 | continue 93 | except Exception as e: 94 | error_count += 1 95 | logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)") 96 | logger.warning(f"Error: {e}") 97 | time.sleep(5) 98 | if error_count > 5: 99 | raise 100 | continue 101 | break 102 | with open(path, "wb") as out: 103 | out.write(res.content) 104 | 105 | # process input x 106 | seqs = [x] if isinstance(x, str) else x 107 | 108 | # compatibility to old option 109 | if filter is not None: 110 | use_filter = filter 111 | 112 | # setup mode 113 | if use_filter: 114 | mode = "env" if use_env else "all" 115 | else: 116 | mode = "env-nofilter" if use_env else "nofilter" 117 | 118 | # define path 119 | path = f"{prefix}_{mode}" 120 | if not os.path.isdir(path): 121 | os.mkdir(path) 122 | 123 | # call mmseqs2 api 124 | tar_gz_file = f'{path}/out.tar.gz' 125 | N, REDO = 101, True 126 | 127 | # deduplicate and keep track of order 128 | seqs_unique = [] 129 | [seqs_unique.append(x) for x in seqs if x not in seqs_unique] 130 | Ms = [N + seqs_unique.index(seq) for seq in seqs] 131 | # lets do it! 
132 | if not os.path.isfile(tar_gz_file): 133 | TIME_ESTIMATE = 150 * len(seqs_unique) 134 | with tqdm(total=TIME_ESTIMATE) as pbar: 135 | while REDO: 136 | pbar.set_description("SUBMIT") 137 | out = submit(seqs_unique, mode, N) 138 | while out["status"] in ["UNKNOWN", "RATELIMIT"]: 139 | sleep_time = 5 + random.randint(0, 5) 140 | logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}") 141 | time.sleep(sleep_time) 142 | out = submit(seqs_unique, mode, N) 143 | 144 | if out["status"] == "ERROR": 145 | raise Exception( 146 | f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.') 147 | if out["status"] == "MAINTENANCE": 148 | raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.') 149 | 150 | # wait for job to finish 151 | ID, TIME = out["id"], 0 152 | pbar.set_description(out["status"]) 153 | while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]: 154 | t = 5 + random.randint(0, 5) 155 | logger.error(f"Sleeping for {t}s. Reason: {out['status']}") 156 | time.sleep(t) 157 | out = status(ID) 158 | pbar.set_description(out["status"]) 159 | if out["status"] == "RUNNING": 160 | TIME += t 161 | pbar.update(n=t) 162 | 163 | if out["status"] == "COMPLETE": 164 | if TIME < TIME_ESTIMATE: 165 | pbar.update(n=(TIME_ESTIMATE - TIME)) 166 | REDO = False 167 | 168 | if out["status"] == "ERROR": 169 | raise Exception( 170 | f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. 
def full_results_to_individuals(fasta_file, result_file, outdir: str):
    r"""Split the combined MMseqs2 a3m results into one a3m file per query.

    The combined ``result_file`` contains one alignment block per *unique*
    query sequence, each introduced by a numeric header (``>101``, ``>102``,
    ...; the MMseqs2 API numbers deduplicated queries starting at 101).
    The reference ``fasta_file`` maps those unique sequences back to the
    (possibly duplicated) original entries, whose labels become the output
    file names.

    Args:
        fasta_file (str): Path to reference fasta file (the query sequences).
        result_file (str): Path to the combined a3m result file.
        outdir (str): Output directory; one ``<label>.a3m`` is written per
            query label.
    """
    # Parse the reference fasta: label = first two "_"-separated tokens of
    # the header; sequences may span multiple lines.
    labels = []
    sequences = []
    with open(fasta_file, "r") as f:
        for line in f:
            if line.startswith(">"):
                labels.append("_".join(line[1:].strip().split("_")[:2]))
                sequences.append("")
            else:
                sequences[-1] += line.strip()

    # Deduplicate sequences preserving first-seen order, mirroring the
    # deduplication done before submission (queries numbered 101, 102, ...).
    seqs_unique = []
    for seq in sequences:
        if seq not in seqs_unique:
            seqs_unique.append(seq)
    # Map every original entry to the index of its unique sequence.
    # Fix: the original `[seqs_unique.index(sequences) for sequences in seqs]`
    # shadowed `sequences` and referenced `seqs`, which is undefined in this
    # function's scope.
    Ms = [seqs_unique.index(seq) for seq in sequences]
    seqid_to_labels = {i: [] for i in range(len(seqs_unique))}
    for i, label in zip(Ms, labels):
        seqid_to_labels[i].append(label)

    is_first_seq = True
    filenames = []
    with open(result_file, "r") as f:
        for line in f:
            if line.startswith(">"):
                label = line[1:].strip()
                if len(label) == 3:
                    # Numeric query header (e.g. "101"): (re)start the output
                    # file(s) for every original entry with this sequence.
                    idx = int(label) - 101
                    filenames = seqid_to_labels[idx]
                    first_seq = seqs_unique[idx]
                    # Use a distinct handle name: the original shadowed the
                    # result-file handle `f` here.
                    for filename in filenames:
                        with open(f"{outdir}/{filename}.a3m", "w") as out:
                            out.write(f">{filename}\n")
                            out.write(first_seq + "\n")
                    is_first_seq = True
                else:
                    # Hit header: append to every associated output file.
                    for filename in filenames:
                        with open(f"{outdir}/{filename}.a3m", "a") as out:
                            out.write(line)
            else:
                if is_first_seq:
                    # Skip the query sequence itself: it was already written
                    # from the deduplicated reference above.
                    is_first_seq = False
                else:
                    for filename in filenames:
                        with open(f"{outdir}/{filename}.a3m", "a") as out:
                            out.write(line)


if __name__ == "__main__":

    user_agent = ...  # str: User agent to use for the API requests
    fasta_file = ...  # str: Path to the fasta file that contains the sequences to query
    result_file = ...  # str: Path to the result file that will contain the a3m results
    temp_folder = ...  # str: Path to the temporary folder to store the intermediate files
    seqs = []
    with open(fasta_file, "r") as f:
        for line in f:
            if line.startswith(">"):
                seqs.append("")
            else:
                seqs[-1] += line.strip()
    a3m_lines = run_mmseqs2_api(seqs, temp_folder, use_env=True, use_filter=True,
                                filter=None, host_url="https://api.colabfold.com", user_agent=user_agent)
    with open(result_file, "w") as f:
        for line in a3m_lines:
            f.write(line)
            f.write("\n")
    output_dir = "/data2/malbrank/protein_gym/mmseqs_colabfold_protocol"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    full_results_to_individuals(fasta_file, result_file, output_dir)
class Uniclust30_Dataset(Dataset):
    """
    Dataset class used to import the Uniclust30 folders.

    If `filename` = "encoded_MSAs.pkl", it will load the full dataset.
    If `filename` = "encoded_MSAs_subset.pkl", it will load a small subset of the dataset.
    If `sample` = True, it will sample a random number of sequences from each cluster.
    If `sample` = False, it will load all the sequences from each cluster (and shuffle them).
    To limit the length of the MSAs, set `max_msa_len` to a positive integer.
    If `reverse` = True, it will reverse the sequences with probability 0.5 and move the last token to the front.
    `fim_strategy` selects the fill-in-the-middle scrambling: "no-scramble" simply
    concatenates the sequences, "one_span" and "multiple_span" scramble them using up to
    `max_patches` patches of which a fraction `mask_fraction` is masked.
    """
    # Mapping from strategy name to the FIM transform applied to each cluster.
    _FIM = {"no-scramble": NoFIM, "one_span": SingleSpanFIM, "multiple_span": MultipleSpanFIM}
    # Supported kinds of position ids.
    _POSIDS = {"none", "1d", "2d"}

    def __init__(self, filename="encoded_MSAs_train.pkl",
                 filepath="/nvme1/common/OpenProteinSet/",
                 sample=False,
                 max_msa_len=-1,
                 reverse=False,
                 seed=42,
                 troubleshoot=False,
                 fim_strategy="no-scramble",
                 max_patches=5,
                 mask_fraction=0.2,
                 always_mask=False,
                 max_position_embeddings=2048,
                 max_seq_position_embeddings=512,
                 add_position_ids="none", ):
        np.random.seed(seed)  # NOTE: seeds the *global* numpy RNG
        self.path = filepath
        if filename:
            self.dataset = pickle.load(open(self.path + filename, "rb"))
            self.cluster_names = list(self.dataset.keys())
        else:
            # No file: empty dataset (useful for testing / manual population).
            self.dataset = None
            self.cluster_names = []
        self.sample = sample
        self.max_msa_len = max_msa_len
        self.reverse = reverse
        self.fim_strategy = fim_strategy
        if fim_strategy in Uniclust30_Dataset._FIM:
            self.fim = Uniclust30_Dataset._FIM[fim_strategy](max_patches=max_patches,
                                                             mask_fraction=mask_fraction,
                                                             always_mask=always_mask,
                                                             add_position_ids=add_position_ids != "none",
                                                             troubleshoot=troubleshoot)
        else:
            raise ValueError(f'Fill in the middle strategy "{fim_strategy}" not recognized.')
        self.max_position_embeddings = max_position_embeddings
        self.max_seq_position_embeddings = max_seq_position_embeddings
        self.add_position_ids = add_position_ids

        self.troubleshoot = troubleshoot

    def __len__(self):
        return len(self.cluster_names)

    def __getitem__(self, idx):
        """Build one training sample (input_ids/labels plus optional position ids) from cluster `idx`."""
        # get all the sequences in the cluster
        sequences = self.get_sequences(idx)
        # get total number of sequences in the cluster and choose how many to sample
        orig_num_sequences = len(self.get_index_start_of_sequences(sequences))
        num_sequences = np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences
        # sample the sequences
        sequences, position_ids = self.sample_sequences(sequences, num_sequences)
        # With probability 0.5, reverse the sequences and move the last token to the front.
        # (Fix: the original `a, b = f(...) if cond else a, b` parsed as
        # `a, b = (f(...) if cond else a), b`, so on the reverse branch `sequences`
        # received the whole returned tuple and `position_ids` was never reversed.)
        if self.reverse and np.random.rand() > 0.5:
            sequences, position_ids = self.reverse_sequences(sequences, position_ids)
        # limit the length of the MSA
        if self.max_msa_len > 0:
            sequences = sequences[:self.max_msa_len]
            if self.add_position_ids != "none":
                position_ids = position_ids[:self.max_msa_len]
        # convert to tensor
        sequences = torch.asarray(sequences, dtype=torch.int64)
        position_ids = torch.asarray(position_ids, dtype=torch.int64).clamp(
            0, self.max_position_embeddings - 1) if self.add_position_ids != "none" else None

        if self.troubleshoot:
            print(
                f"Cluster {idx} has {orig_num_sequences} sequences, of which {num_sequences} sampled now. Total MSA length: {len(sequences)}")
        if self.add_position_ids == "1d":
            return dict(input_ids=sequences, position_ids=position_ids, labels=sequences)
        if self.add_position_ids == "2d":
            # Sequence index of every token: count start-of-sequence tokens seen so far.
            # NOTE(review): the dump was corrupted here; "<cls>" is the ProtMamba
            # sequence separator (id 0, consistent with get_index_start_of_sequences).
            seq_position_ids = (sequences == AA_TO_ID["<cls>"]).int().cumsum(-1).clamp(
                0, self.max_seq_position_embeddings - 1).contiguous()
            return dict(input_ids=sequences, position_ids=position_ids, seq_position_ids=seq_position_ids,
                        labels=sequences)
        return dict(input_ids=sequences, labels=sequences)

    def get_sequences(self, idx):
        """Get the (encoded, concatenated) sequences in the cluster with index `idx`."""
        cluster_name = self.cluster_names[idx]
        sequences = self.dataset[cluster_name]
        return sequences

    def get_index_start_of_sequences(self, sequences):
        """Get the positions of the start-of-sequence token (id 0) in the cluster array."""
        return np.where(sequences == 0)[0]

    def reverse_sequences(self, sequence, position_ids=None):
        """Reverse the token array and move the last token of the reversed array to the front."""
        sequence = sequence[::-1]
        if position_ids is not None:
            position_ids = position_ids[::-1]
        return np.concatenate([sequence[-1:], sequence[:-1]]), (
            np.concatenate([position_ids[-1:], position_ids[:-1]]) if position_ids is not None else None)

    def sample_sequences(self, sequences, num_sequences, shuffle=True,
                         which_seqs: Union[np.ndarray, None] = None):
        """Sample `num_sequences` sequences from the cluster and apply the FIM transform.

        Returns the concatenated token array and (optionally) its position ids.
        """
        L = len(sequences)
        # indexes of the start of each sequence
        inds = self.get_index_start_of_sequences(sequences)
        # check that there are sequences in the cluster and that there are enough of them
        assert len(inds) > 0, "No sequences found in cluster."
        assert len(inds) >= num_sequences, "Not enough sequences in cluster."
        # sample num_sequences randomly, or take the last ones in order
        if which_seqs is None:
            if shuffle:
                which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
            else:
                which_seqs = np.arange(len(inds))[-num_sequences:]
        # (start, end) index pairs of the selected sequences
        tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs]
        if self.troubleshoot:
            print(f"Sampled sequences: {tuples}")
        # concatenate (and possibly scramble) the selected sequences
        sequences, position_ids = self.fim.apply(sequences, tuples)
        return sequences, position_ids


def compute_hamming_csim_torch(
        seqs: torch.Tensor,
        ungapped_msa: torch.Tensor,
        gap_token: int,
        gap_token_mask: int,
) -> torch.Tensor:
    """Unnormalized Hamming similarity of every row of `seqs` against every row of `ungapped_msa`.

    `gap_token`/`gap_token_mask` are accepted for interface compatibility but unused here:
    the caller pre-masks gaps in `ungapped_msa` so they can never match.
    """
    return (seqs.unsqueeze(1) == ungapped_msa).sum(dim=2)


def _compute_homology_weights(
        ungapped_msa: np.ndarray,
        gap_token: int,
        gap_token_mask: int,
        theta: float,
        hamming_csim_func: Callable,
        max_memory: int = 20,
        can_use_torch: bool = True,
) -> np.ndarray:
    """Count, for every sequence in the MSA, its neighbors within normalized distance `theta`."""
    use_torch = can_use_torch and torch.cuda.is_available()
    if use_torch:
        # On GPU the torch kernel is always used, regardless of the argument.
        hamming_csim_func = compute_hamming_csim_torch
    # Batch size heuristic: keep one similarity block within ~`max_memory` GB.
    batch_size = math.floor(
        2
        * 1024
        * 1024
        * 1024
        / (ungapped_msa.shape[0] * ungapped_msa.shape[1])
        * max_memory
        / 40
    )
    batch_size = 1 if batch_size == 0 else batch_size

    neighbors = []
    if not use_torch:
        masked_ungapped_msa = ungapped_msa.copy()
    else:
        ungapped_msa = torch.from_numpy(ungapped_msa).byte().cuda()
        masked_ungapped_msa = ungapped_msa.clone()
    # Replace gaps with a sentinel so gap positions never count as matches.
    masked_ungapped_msa[masked_ungapped_msa == gap_token] = gap_token_mask
    for b_start in range(0, len(ungapped_msa), batch_size):
        b_end = b_start + batch_size
        seqs = ungapped_msa[b_start:b_end]

        sim = hamming_csim_func(
            seqs=seqs,
            ungapped_msa=masked_ungapped_msa,
            gap_token=gap_token,
            gap_token_mask=gap_token_mask,
        )
        if not use_torch:
            # Normalize by the number of non-gap positions of each query sequence.
            sim = sim / (seqs != gap_token).sum(axis=1, keepdims=True)
            d = 1 - sim
            # Fix: numpy arrays have no `.clamp`; the original raised AttributeError here.
            d = np.clip(d, 0, 1)
            this_neighbors = (d <= theta).sum(axis=1)
        else:
            sim = sim / (seqs != gap_token).sum(dim=1, keepdim=True)
            d = 1 - sim
            # NaNs (e.g. from a 0/0 division on gap-only rows) are treated as distance 0.
            d[torch.isnan(d)] = 0
            d = d.clamp(0, 1)
            this_neighbors = (d <= theta).sum(dim=1).cpu()
        neighbors.append(this_neighbors)
    return np.concatenate(neighbors)
def compute_homology_weights(
        ungapped_msa: np.ndarray,
        theta: float = 0.2,
        gap_token: int = AA_TO_ID["-"],
        gap_token_mask: int = 255,
        hamming_csim_func: Callable = compute_hamming_csim_torch,
) -> tuple[int, np.ndarray]:
    """
    Calculate the effective number of sequences and sampling probability for the NEIGHBORS and NEIGHBORS_NO_LIMIT sampling methods using numpy.

    Parameters:

        ungapped_msa (np.ndarray): The MSA (from .fa).
        theta (float, optional): A parameter used to determine the similarity between sequences. Default is 0.2.
        gap_token (int, optional): The token representing gaps in the encoded MSA.
            Defaults to ``AA_TO_ID["-"]``.  (Fix: the docstring previously claimed 20.)
        gap_token_mask (int): token for masking gaps. should be a token not representing any other value.
        hamming_csim_func (Callable): function computing pairwise (unnormalized) Hamming similarity.

    Returns:

        tuple[int, np.ndarray]: A tuple containing the effective number of sequences and the sampling probability for each sequence in the MSA.
    """
    neighbors = _compute_homology_weights(
        ungapped_msa=ungapped_msa,
        gap_token=gap_token,
        gap_token_mask=gap_token_mask,
        theta=theta,
        hamming_csim_func=hamming_csim_func,
    )
    # Effective number of sequences: sum of inverse neighborhood sizes.
    n_eff = np.sum(1 / neighbors)

    # Sampling probability: normalized inverse neighborhood size.
    p = 1 / neighbors
    p /= np.sum(p)
    return n_eff, p


class MSASampler:
    """Samples sequences from an MSA: filters by (dis)similarity to the query sequence
    and draws with probabilities given by homology weights."""

    def __init__(self, max_similarity, max_dissimilarity, force_include_first=True):
        # Sequences more similar than `max_similarity` to the query, or more dissimilar
        # than `max_dissimilarity`, are filtered out by `_get_sim_filtered_idxs`.
        self.max_similarity = max_similarity
        self.max_dissimilarity = max_dissimilarity
        self.force_include_first = force_include_first
        self.theta = 0.2  # neighborhood threshold used for the homology weights

    def _get_sim_filtered_idxs(self, msa: np.ndarray) -> np.ndarray:
        """Indices of sequences within the allowed (dis)similarity band w.r.t. the first (query) row."""
        nonnormalized_sim = (msa == msa[[0]]).sum(axis=1)
        normfactor = msa.shape[1]
        norm_sim = nonnormalized_sim / normfactor

        assert (norm_sim.min() >= 0) and (norm_sim.max() <= 1)
        dsim = 1 - norm_sim

        max_sim_filter = norm_sim <= self.max_similarity
        max_dissim_filter = dsim <= self.max_dissimilarity
        return np.where(max_sim_filter & max_dissim_filter)[0]

    def get_weights(
            self, msa: np.ndarray,
    ) -> tuple[Optional[float], Optional[np.ndarray]]:
        """Effective number of sequences and per-sequence sampling weights for `msa`."""
        return compute_homology_weights(
            ungapped_msa=msa,
            theta=self.theta,
            gap_token_mask=255,
        )

    def get_sample_idxs(
            self,
            msa: np.ndarray,
            name: Optional[str] = None,
            size: int = 1,
            random=False,
            result_cache_dir: Optional[str] = "/data2/malbrank/protein_gym/cache",
    ) -> np.ndarray:
        """Indices of `size` sequences sampled from `msa` (list of strings).

        With `random=True`: uniform sampling without replacement. Otherwise: homology-weighted
        sampling, with the weights cached under `result_cache_dir/<name>.npy` when `name` is given.
        """
        if random:
            return np.random.choice(len(msa), replace=False, size=size) if len(msa) >= size else np.arange(len(msa))
        # Encode the (upper-cased) sequences, truncating every row to the query length.
        msa = np.array([[AA_TO_ID[aa] for aa in seq.upper()][:len(msa[0])] for seq in msa], dtype=np.uint8)

        if name is not None and result_cache_dir is not None:
            # Robustness fix: make sure the cache directory exists before listing/saving.
            os.makedirs(result_cache_dir, exist_ok=True)
        if name is not None and name + ".npy" in os.listdir(result_cache_dir):
            weights = np.load(os.path.join(result_cache_dir, name + ".npy"))
        elif name is not None:
            _, weights = self.get_weights(
                msa=msa,
            )
            np.save(os.path.join(result_cache_dir, name), weights)  # np.save appends ".npy"
        else:
            _, weights = self.get_weights(
                msa=msa,
            )

        original_msa_sample_idxs = np.arange(len(msa))
        sample_idxs = self._get_sim_filtered_idxs(msa)
        original_msa_sample_idxs = original_msa_sample_idxs[sample_idxs]

        if self.force_include_first:
            original_msa_sample_idxs = np.concatenate(
                [[0], original_msa_sample_idxs[original_msa_sample_idxs != 0]]
            )
        # NOTE(review): when len(msa) >= size, the weighted draw below samples over the FULL
        # MSA and ignores the similarity-filtered `original_msa_sample_idxs` computed above;
        # the filtered indices are only returned in the small-MSA fallback. This matches the
        # original code, but the intent is worth confirming.
        return np.random.choice(len(msa), replace=False, size=size, p=weights / weights.sum()) \
            if len(msa) >= size else original_msa_sample_idxs
# Modules
from ProtMamba_ssm.modules import *
# Trainer
from ProtMamba_ssm.trainer import *
# Dataloaders
from ProtMamba_ssm.dataloaders import *
# Utils
from ProtMamba_ssm.utils import *
from transformers import TrainingArguments, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup, \
    get_cosine_with_hard_restarts_schedule_with_warmup

import os  # Fix: previously only available implicitly through the star imports
import yaml
import torch
import numpy as np
from transformers.integrations import TensorBoardCallback
from torch.optim import AdamW

# List of available models, keyed by the kind of position ids they consume.
_mamba_model = {"none": MambaLMHeadModelSafe, "1d": MambaLMHeadModelwithPosids, "2d": MambaLMHeadModelwith2DPosids}


def run(config, namedir=None, finetune_model_path=None, restart_optimizer_and_scheduler=False):
    r"""Run the training/finetuning loop.

    Args:
        config (dict): dictionary with the configuration parameters.
        namedir (str, optional): name of the directory where the model will be saved. If None, the name will be taken from the config file.
        finetune_model_path (str, optional): path to the model to be finetuned. If None, a new model will be created.
        restart_optimizer_and_scheduler (bool): if True, do not reload the optimizer/scheduler
            states from `finetune_model_path`.

    Returns:
        MambaTrainer: the trainer after the loop terminates.
    """
    if namedir is None:
        namedir = config["namedir"]
    # Load Dataset
    full_dataset = Uniclust30_Dataset(filename=config["train_dataset_path"],
                                      filepath=config["data_dir"],
                                      sample=config["sample_sequences"],
                                      max_msa_len=config["max_msa_len"],
                                      reverse=config["reverse"],
                                      seed=config["seed_sequence_sampling"],
                                      troubleshoot=False,
                                      fim_strategy=config["fim_strategy"],
                                      always_mask=config["always_mask"],
                                      max_position_embeddings=config["max_position_embeddings"],
                                      max_seq_position_embeddings=config["max_seq_position_embeddings"],
                                      add_position_ids=config["add_position_ids"])

    assert len(AA_TO_ID) == config[
        "vocab_size"], f"Vocab size in the config file does not match the one in the code. It should be {len(AA_TO_ID)}"

    # Split dataset in train, validation and test
    gen = torch.Generator()
    gen.manual_seed(config["seed_datasets"])
    np_gen = np.random.default_rng(seed=config["seed_datasets"])
    len_full_dataset = len(full_dataset)
    len_val = config["size_validation"]
    len_train = len_full_dataset - len_val
    assert len_val < len_full_dataset, "Validation set is larger than the full dataset"
    assert len_val % config["batch_size"] == 0, "Validation set size must be a multiple of the batch size"
    print(f"Train: {len_train} samples, Validation: {len_val} samples")
    train_dataset, eval_dataset = torch.utils.data.random_split(full_dataset, [len_train, len_val], generator=gen)
    # A same-sized random subset of the training set, used to monitor train-set metrics.
    eval_train_dataset = torch.utils.data.Subset(train_dataset,
                                                 np_gen.choice(len(train_dataset), len(eval_dataset)))
    # Create data collator for batched training
    data_collator = DataCollatorForUniclust30Dataset()

    # Configure Mamba model
    Mamba = _mamba_model[config["add_position_ids"]]
    if config["dtype"] == "float32":
        dtype = torch.float32
        print("Using float32")
    elif config["dtype"] == "bfloat16":
        dtype = torch.bfloat16
        print("Using bfloat16")
    else:
        raise ValueError("dtype must be either float32 or bfloat16")
    if finetune_model_path is not None:
        # Load model for finetuning
        model = load_model(finetune_model_path, device="cuda", model_class=Mamba, dtype=dtype,
                           checkpoint_mixer=config["checkpoint_mixer"])
    else:
        # Create model for training
        mamba_config = MambaConfig(d_model=config["d_model"],
                                   n_layer=config["n_layer"],
                                   vocab_size=config["vocab_size"],
                                   residual_in_fp32=config["residual_in_fp32"])
        model = Mamba(mamba_config, dtype=dtype, checkpoint_mixer=config["checkpoint_mixer"])

    # Print model information
    print_number_of_parameters(model)
    print(f"Epochs: {config['num_epochs']}")
    print(f"Batch size: {config['batch_size']}")
    print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
    eff_batch_size = config['batch_size'] * config['gradient_accumulation_steps']
    nr_gpus = torch.cuda.device_count()
    eff_batch_size *= nr_gpus
    print(f"Effective batch size: {eff_batch_size}")
    print(
        f"Steps per training epoch: {len(train_dataset) // config['batch_size']}, eff. steps: {len(train_dataset) // eff_batch_size}")
    print(f"Steps per evaluation epoch: {len(eval_dataset) // config['batch_size']}")
    print(f"Max MSA length: {config['max_msa_len']}")
    ev_epochs = round(config['eval_steps'] * config["batch_size"] / len(train_dataset), 3)
    print(
        f"Evaluation every {config['eval_steps']} steps, i.e. {ev_epochs} epochs. Effectively every {config['eval_steps'] * config['gradient_accumulation_steps']} steps, i.e. {ev_epochs * config['gradient_accumulation_steps']} epochs.")

    # Training callbacks
    es_callback = EarlyStoppingCallback(train_path=config["output_dir"] + namedir, config=config)
    callbacks = [TensorBoardCallback(), es_callback]

    # Optimizer and Schedulers
    optimizer = AdamW(model.parameters(),
                      lr=config["learning_rate"],
                      betas=(config["beta1"], config["beta2"]),
                      weight_decay=config["weight_decay"])
    # Fix: the original used two separate `if` chains, so a valid "cosine" config fell
    # through to the second chain's `else` and raised ValueError after having created
    # the scheduler. A single if/elif/else chain handles all three options.
    if config["scheduler"] == "cosine":
        print("Using cosine scheduler")
        scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=config["warmup_steps"],
                                                    num_training_steps=config["num_epochs"] * len(
                                                        train_dataset) // eff_batch_size)
    elif config["scheduler"] == "cosine-restarts":
        print("Using cosine scheduler with hard restarts")
        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer,
                                                                      num_warmup_steps=config["warmup_steps"],
                                                                      num_training_steps=config["num_epochs"] * len(
                                                                          train_dataset) // eff_batch_size,
                                                                      num_cycles=config["num_cycles"])
    elif config["scheduler"] == "constant":
        print("Using constant scheduler with warmup")
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=config["warmup_steps"])
    else:
        raise ValueError("Scheduler must be either cosine, cosine-restarts or constant")

    if finetune_model_path is not None:
        if not restart_optimizer_and_scheduler:
            optimizer.load_state_dict(torch.load(finetune_model_path + "/optimizer.pt"))
            scheduler.load_state_dict(torch.load(finetune_model_path + "/scheduler.pt"))

    # Find checkpoint if available
    last_checkpoint = None
    if finetune_model_path is None:
        if os.path.exists(config["output_dir"] + namedir):
            last_checkpoint = get_last_checkpoint(config["output_dir"] + namedir)
            if last_checkpoint is None:
                print("No checkpoint found, starting training from scratch.")
            else:
                print(f"Resuming training from the last checkpoint: {last_checkpoint}")
    if config["checkpoint_mixer"]:
        print("Using gradient checkpointing")
    # Create model trainer
    trainer = MambaTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset={"valid": eval_dataset, "train": eval_train_dataset},
        optimizers=(optimizer, scheduler),
        args=TrainingArguments(
            learning_rate=config["learning_rate"],
            num_train_epochs=config["num_epochs"],
            per_device_train_batch_size=config["batch_size"],
            per_device_eval_batch_size=config["batch_size"],
            gradient_accumulation_steps=config["gradient_accumulation_steps"],
            eval_accumulation_steps=config["eval_accumulation_steps"],
            eval_strategy="steps",
            max_grad_norm=config["max_grad_norm"],
            bf16=config["dtype"] == "bfloat16",
            dataloader_num_workers=1,
            logging_steps=config["logging_steps"],
            eval_steps=config["eval_steps"],
            save_steps=config["save_steps"],
            output_dir=config["output_dir"] + namedir,
            logging_dir=config["output_dir"] + namedir,
            overwrite_output_dir=False,
            push_to_hub=False,
            label_names=["labels"],
        ),
        compute_only_fim_loss=config["compute_only_fim_loss"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=callbacks)
    assert trainer.args._n_gpu == nr_gpus, "Number of gpus used is not the same as the number of gpus available"
    if trainer.compute_only_fim_loss:
        print("Computing only FIM loss for training")
    # Train model; restart from the early-stopping checkpoint after a loss spike.
    while True:
        if last_checkpoint is None and trainer.state.global_step == 0:
            eval_results = trainer.evaluate()
            print(f">>> Initial Perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}")
        else:
            print(f"Resuming training from the last checkpoint: {last_checkpoint}")
        # Train
        trainer.train(resume_from_checkpoint=last_checkpoint)
        # Break when training finished normally or the epoch budget is exhausted
        if not es_callback.should_restart or trainer.state.epoch >= config["num_epochs"]:
            eval_results = trainer.evaluate()
            print(f">>> Final Perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}")
            break
        # If the training was interrupted because of a loss spike, restart from the last checkpoint
        last_checkpoint = es_callback.checkpoint_path

    return trainer


if __name__ == "__main__":
    use_one_gpu = "0"
    if int(use_one_gpu) >= 0:
        print(f"Using gpu {use_one_gpu}")
    print("Number of gpus used: ", torch.cuda.device_count())

    with open("/home/user/scripts/configs/default_config.yaml", "r") as file:
        defaultconfig = yaml.safe_load(file)

    namedir = "test"
    trainer = run(defaultconfig, namedir)
    # Fix: removed a leftover duplicate `run(defaultconfig, ...)` call that would have
    # immediately launched a second full training after the first one finished.