├── .gitignore ├── LICENSE ├── README.md ├── assets └── chat_example.png ├── examples ├── protein │ ├── generate.py │ ├── p2_sampling_demo.ipynb │ └── utils.py └── text │ └── LLaDA │ ├── LICENSE │ ├── README.md │ ├── eval_lm_harness.py │ ├── generate.py │ └── scripts │ ├── gsm8k.sh │ ├── gsm8k_instruct.sh │ ├── humaneval.sh │ └── test_LLaDA_code.sh ├── setup.py └── src └── path_planning ├── __init__.py ├── p2.py ├── scheduler.py ├── score_function.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | 177 | 178 | examples/text/LLaDA/results 179 | examples/text/LLaDA/results-cp -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # P2 Sampling 2 | 3 | A Python package implementing [P2 (Path Planning)](https://arxiv.org/pdf/2502.03540), a masked diffusion model sampling method for sequence generation. This repository provides a flexible implementation that can be applied to various domains, with example implementations for protein sequence generation and text generation. 
4 | 5 | ## Overview 6 | 7 | P2 sampling is a drop-in masked diffusion model sampler. 8 | 9 | Key advantages of P2: 10 | - Simple implementation. The core code is less than 100 LOC. 11 | - Modular Components for plug-and-play experimentation. 12 | - Applicable to various sequence domains (protein, text, etc.) 13 | 14 | ## Installation 15 | 16 | ### Basic Installation 17 | 18 | ```bash 19 | # Clone the repository 20 | git clone git@github.com:pengzhangzhi/path_planning.git 21 | cd path_planning 22 | 23 | # Install the package 24 | pip install -e . 25 | ``` 26 | 27 | ## Examples 28 | 29 | This repository includes example implementations for two domains: 30 | 31 | ### 1. Protein Sequence Generation 32 | 33 | The protein example demonstrates how to generate novel protein sequences using P2 sampling with ESM-2 models and evaluate their quality using ESMFold. 34 | 35 | #### Running the Protein Example 36 | 37 | ```bash 38 | # Basic generation 39 | python examples/protein/generate.py --num_seqs 10 --seq_len 128 40 | 41 | # With ESMFold evaluation 42 | python examples/protein/generate.py --num_seqs 10 --seq_len 128 --esmfold_eval --save_dir results/test_run 43 | ``` 44 | 45 | #### Jupyter Notebook 46 | 47 | For an interactive demonstration, you can also use the Jupyter notebook: 48 | 49 | ```bash 50 | examples/protein/p2_sampling_demo.ipynb 51 | ``` 52 | 53 | ### 2. Text Generation (LLaDA) 54 | 55 | The text example implements [LLaDA](https://arxiv.org/abs/2502.09992), a diffusion-based text generation approach using language models. 56 | 57 | #### Running the Text Example 58 | 59 | ```bash 60 | # Navigate to the text example directory 61 | cd examples/text/LLaDA 62 | 63 | # Run the generation script 64 | python generate.py 65 | ``` 66 | 67 | #### Chat Example 68 | 69 | ```bash 70 | cd examples/text/LLaDA 71 | python chat.py 72 | ``` 73 | Here is an example of my chat history: 74 | 75 | ![alt text](assets/chat_example.png) 76 | 77 | ## API Usage 78 | 79 | You can use the P2 sampling functionality programmatically in your own projects: 80 | 81 | ```python 82 | from path_planning import p2_sampling, seed_everything 83 | from path_planning.score_function import logP 84 | 85 | # Set random seed for reproducibility 86 | seed_everything(42) 87 | 88 | # Create a model decorator that makes the model return logits 89 | ModelWrapper = lambda model: lambda x: model(x).logits 90 | 91 | model_wrapper = ModelWrapper(your_model) 92 | 93 | # Use P2 sampling in your code 94 | sampled_sequence = p2_sampling( 95 | xt=initial_masked_sequence, 96 | model=model_wrapper, 97 | mask_id=your_mask_token_id, 98 | num_steps=128, 99 | tau=1.0, 100 | eta=1.0, 101 | score_fn=logP 102 | ) 103 | ``` 104 | 105 | ## Minimal P2-self-Planning Implementation 106 | ```python 107 | 108 | import torch 109 | from tqdm import tqdm 110 | from typing import Callable, Tuple, Any 111 | 112 | 113 | def topk_masking(scores: torch.Tensor, cutoff_len: torch.Tensor, mode: str = "lowest") -> torch.Tensor: 114 | """Generate a mask selecting the top-k lowest or highest elements per row.""" 115 | sorted_scores = scores.sort(dim=-1, descending=(mode == "highest")).values 116 | cutoff = sorted_scores.gather(dim=-1, index=cutoff_len) 117 | return (scores >= cutoff) if mode == "highest" else (scores < cutoff) 118 | 119 | 120 | def sample_categorical( 121 | logits: torch.Tensor, temperature: float = 1.0, noise_scale: float = 1.0 122 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 123 | """ 124 | Sample from a categorical distribution with optional 
Gumbel noise. 125 | Returns sampled tokens, their scores, and the noised logits. 126 | """ 127 | logits = logits.to(torch.float64) 128 | if temperature > 0: 129 | gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8) 130 | logits = logits / temperature + noise_scale * gumbel_noise 131 | log_probs = logits.log_softmax(dim=-1) 132 | scores, tokens = log_probs.max(dim=-1) 133 | return tokens, scores.to(logits.dtype), logits.to(logits.dtype) 134 | 135 | 136 | @torch.inference_mode() 137 | @torch.amp.autocast(device_type="cuda", dtype=torch.float16) 138 | def p2_sampling( 139 | xt: torch.Tensor, 140 | model: Any, 141 | mask_id: int, 142 | num_steps: int, 143 | tau: float = 1.0, 144 | kappa_fn: Callable[[float], float] = lambda t: t, 145 | eta: float = 1.0, 146 | **kwargs 147 | ) -> torch.Tensor: 148 | """ 149 | P2 Sampling implementation for discrete diffusion models. 150 | Reference: https://arxiv.org/pdf/2502.03540 151 | """ 152 | dt = 1 / num_steps 153 | fix_mask = (xt != mask_id) 154 | 155 | for i in tqdm(range(1, num_steps + 1)): 156 | t = i * dt 157 | kappa_t = kappa_fn(t) 158 | 159 | logits = model(xt).double() 160 | last_mask = (xt == mask_id) 161 | unmask_t = ~last_mask & ~fix_mask 162 | 163 | x0, score, _ = sample_categorical(logits, temperature=tau) 164 | score = score.masked_fill(fix_mask, float("inf")) 165 | score[unmask_t] *= eta 166 | 167 | num_to_mask = ((~fix_mask).sum(dim=1, keepdim=True).float() * (1 - kappa_t)).long() 168 | to_mask = topk_masking(score, num_to_mask, mode="lowest") 169 | 170 | xt[to_mask] = mask_id 171 | mask_2_x0 = last_mask & ~to_mask 172 | xt[mask_2_x0] = x0[mask_2_x0] 173 | 174 | xt[xt == mask_id] = x0[xt == mask_id] 175 | return xt 176 | 177 | ``` 178 | 179 | ## Appreciation 180 | 181 | The code is based on the following repository: 182 | 183 | - [DPLM](https://github.com/bytedance/dplm) 184 | - [LLaDA](https://github.com/ML-GSAI/LLaDA) 185 | 186 | 187 | ## Citation 188 | 189 | ```bibtex 190 | @misc{peng2025pathplanningmaskeddiffusion, 191 | title={Path Planning for Masked Diffusion Model Sampling}, 192 | author={Fred Zhangzhi Peng and Zachary Bezemek and Sawan Patel and Jarrid Rector-Brooks and Sherwood Yao and Alexander Tong and Pranam Chatterjee}, 193 | year={2025}, 194 | eprint={2502.03540}, 195 | archivePrefix={arXiv}, 196 | primaryClass={cs.LG}, 197 | url={https://arxiv.org/abs/2502.03540}, 198 | } 199 | ``` 200 | 201 | -------------------------------------------------------------------------------- /assets/chat_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhangzhi/Path-Planning/68675ea2d9dc5d62f4d81d10cb75845b1bbbc9fb/assets/chat_example.png -------------------------------------------------------------------------------- /examples/protein/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | from pathlib import Path 5 | from pprint import pprint 6 | import math 7 | import torch 8 | from transformers import AutoTokenizer, EsmForMaskedLM, AutoModel 9 | from path_planning.p2 import p2_sampling 10 | from path_planning.scheduler import sine_scheduler 11 | from path_planning.score_function import diff_top2 12 | from path_planning.utils import seed_everything 13 | from utils import run_esmfold_eval 14 | 15 | def ignore_special_tokens_logits(logits, tokenizer): 16 | """ 17 | Masks out the logits of special tokens to prevent them from being sampled. 
18 | 19 | Args: 20 | logits (Tensor): Logits output from the model of shape [B, L, V]. 21 | tokenizer: The tokenizer to access special token IDs. 22 | 23 | Returns: 24 | Tensor: Modified logits with special tokens masked out. 25 | """ 26 | logits[..., tokenizer.mask_token_id] = -math.inf 27 | logits[..., tokenizer._token_to_id["X"]] = -math.inf 28 | logits[..., tokenizer.pad_token_id] = -math.inf 29 | logits[..., tokenizer.cls_token_id] = -math.inf 30 | logits[..., tokenizer.eos_token_id] = -math.inf 31 | return logits 32 | 33 | class ModelWrapper: 34 | """Wrapper for the ESM model to handle logits processing.""" 35 | def __init__(self, model, tokenizer): 36 | self.model = model 37 | self.tokenizer = tokenizer 38 | 39 | def __call__(self, x): 40 | outputs = self.model(x) 41 | logits = outputs.logits 42 | return ignore_special_tokens_logits(logits.float(), self.tokenizer) 43 | 44 | def create_masked_sequence(sequence_length: int, tokenizer, batch_size: int = 1, device: str = 'cuda'): 45 | """Create a fully masked sequence for generation.""" 46 | seq = [tokenizer.mask_token] * sequence_length 47 | sequences = [''.join(seq)] * batch_size 48 | 49 | encoded = tokenizer( 50 | sequences, 51 | add_special_tokens=True, 52 | padding=True, 53 | return_tensors='pt' 54 | ) 55 | return encoded['input_ids'].to(device) 56 | 57 | def save_sequences_to_fasta(sequences: list, seq_len: int, save_path: str): 58 | """Save generated sequences to FASTA format.""" 59 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 60 | with open(save_path, 'w') as fp: 61 | for idx, seq in enumerate(sequences): 62 | fp.write(f">SEQUENCE_{idx}_L={seq_len}\n") 63 | fp.write(f"{seq}\n") 64 | 65 | def generate_sequences( 66 | model_name: str = "facebook/esm2_t6_8M_UR50D", 67 | planner_name: str = None, 68 | num_seqs: int = 100, 69 | seq_len: int = 128, 70 | num_steps: int = 128, 71 | temperature: float = 1.0, 72 | eta: float = 1.0, 73 | seed: int = None, 74 | device: str = 'cuda', 75 | save_dir: str = 'generation-results', 76 | ) -> tuple[list, float]: 77 | """Generate protein sequences using P2 sampling.""" 78 | seed_everything(seed) 79 | 80 | print(f"Loading model {model_name}...") 81 | tokenizer = AutoTokenizer.from_pretrained(model_name) 82 | model = EsmForMaskedLM.from_pretrained(model_name) 83 | model = model.eval().to(device) 84 | 85 | model_wrapper = ModelWrapper(model, tokenizer) 86 | 87 | # Load planner if specified 88 | planner = None 89 | if planner_name: 90 | print(f"Loading planner model {planner_name}...") 91 | planner_tokenizer = AutoTokenizer.from_pretrained(planner_name) 92 | planner_model = AutoModel.from_pretrained(planner_name) 93 | planner_model = planner_model.eval().to(device) 94 | planner = ModelWrapper(planner_model, planner_tokenizer) 95 | 96 | print("Creating initial sequence...") 97 | xt = create_masked_sequence( 98 | sequence_length=seq_len, 99 | tokenizer=tokenizer, 100 | batch_size=num_seqs, 101 | device=device 102 | ) 103 | print(f"Initial sequence shape: {xt.shape}") 104 | 105 | print("Starting P2 sampling...") 106 | start_time = time.time() 107 | sampled_xt = p2_sampling( 108 | xt=xt, 109 | model=model_wrapper, 110 | mask_id=tokenizer.mask_token_id, 111 | num_steps=num_steps, 112 | tau=temperature, 113 | eta=eta, 114 | planner=planner, 115 | ) 116 | 117 | elapsed_time = time.time() - start_time 118 | 119 | decoded_seqs = tokenizer.batch_decode(sampled_xt, skip_special_tokens=True) 120 | decoded_seqs = [''.join(seq.split()) for seq in decoded_seqs] 121 | 122 | # Save sequences 123 | save_path 
= os.path.join(save_dir, f"L_{seq_len}.fasta") 124 | save_sequences_to_fasta(decoded_seqs, seq_len, save_path) 125 | print(f"Saved sequences to {save_path}") 126 | 127 | return decoded_seqs, elapsed_time 128 | 129 | def parse_args(): 130 | parser = argparse.ArgumentParser(description="Protein Sequence Generation using P2 Sampling") 131 | parser.add_argument('--model_name', type=str, default="airkingbd/dplm_650m",) 132 | parser.add_argument('--planner_name', type=str, default=None,) 133 | parser.add_argument('--num_seqs', type=int, default=200,) 134 | parser.add_argument('--seq_len', type=int, default=200,) 135 | parser.add_argument('--num_steps', type=int, default=200,) 136 | parser.add_argument('--temperature', type=float, default=1.0, 137 | help="Sampling temperature") 138 | parser.add_argument('--eta', type=float, default=1.0, 139 | help="Stochasticity strength (0: deterministic, 1: default, >1: more stochastic)") 140 | parser.add_argument('--seed', type=int, default=42, 141 | help="Random seed for reproducibility") 142 | parser.add_argument('--save_dir', type=str, default='generation-results', 143 | help="Directory to save generated sequences") 144 | parser.add_argument('--esmfold_eval', action='store_true', default=False, 145 | help="Run ESMFold evaluation") 146 | parser.add_argument('--max_tokens_per_batch', type=int, default=1024, 147 | help="Maximum tokens per batch for ESMFold evaluation") 148 | parser.add_argument('--num_recycles', type=int, default=None, 149 | help="Number of recycles for ESMFold") 150 | parser.add_argument('--cpu_only', action='store_true', 151 | help="Use CPU only for ESMFold") 152 | parser.add_argument('--cpu_offload', action='store_true', 153 | help="Enable CPU offloading for ESMFold") 154 | return parser.parse_args() 155 | 156 | def main(): 157 | args = parse_args() 158 | 159 | print("\nProtein Sequence Generation Parameters:") 160 | print(f"Model: {args.model_name}") 161 | print(f"Planner: {args.planner_name if args.planner_name else 'None'}") 162 | print(f"Number of Sequences: {args.num_seqs}") 163 | print(f"Sequence Length: {args.seq_len}") 164 | print(f"Number of Steps: {args.num_steps}") 165 | print(f"Temperature: {args.temperature}") 166 | print(f"Eta: {args.eta}") 167 | print(f"Seed: {args.seed}") 168 | print(f"Save Directory: {args.save_dir}") 169 | print(f"ESMFold Evaluation: {args.esmfold_eval}") 170 | 171 | sequences, elapsed_time = generate_sequences( 172 | model_name=args.model_name, 173 | planner_name=args.planner_name, 174 | num_seqs=args.num_seqs, 175 | seq_len=args.seq_len, 176 | num_steps=args.num_steps, 177 | temperature=args.temperature, 178 | eta=args.eta, 179 | seed=args.seed, 180 | save_dir=args.save_dir 181 | ) 182 | 183 | print(f"\nGeneration completed in {elapsed_time:.2f} seconds") 184 | print(f"Tokens/second: {args.num_seqs * args.seq_len / elapsed_time:.2f}") 185 | 186 | if args.esmfold_eval: 187 | print("\nRunning ESMFold evaluation...") 188 | save_dir = Path(args.save_dir) 189 | df = run_esmfold_eval( 190 | fasta_dir=save_dir, 191 | output_dir=save_dir / "esmfold_pdb", 192 | num_recycles=args.num_recycles, 193 | max_tokens_per_batch=args.max_tokens_per_batch, 194 | cpu_only=args.cpu_only, 195 | cpu_offload=args.cpu_offload 196 | ) 197 | 198 | if not df.empty: 199 | # Add generation metadata 200 | df['Model'] = args.model_name 201 | df['Planner'] = args.planner_name if args.planner_name else "None" 202 | df['Temperature'] = args.temperature 203 | df['Eta'] = args.eta 204 | df['Steps'] = args.num_steps 205 | df['Generation 
Time'] = elapsed_time 206 | 207 | # Save results 208 | results_path = save_dir / "esmfold_results.csv" 209 | df.to_csv(results_path, index=False) 210 | print(f"\nSaved ESMFold results to {results_path}") 211 | 212 | # Calculate foldability 213 | foldable_count = df[ 214 | (df['pLDDT'] > 80) & (df['pTM'] > 0.7) & (df['pAE'] < 10) 215 | ].shape[0] 216 | foldability = (foldable_count / len(df)) * 100 217 | print(f"Foldability: {foldability:.2f}%") 218 | 219 | print("\nSample sequences:") 220 | for i, seq in enumerate(sequences[:5]): 221 | print(f"Sequence {i+1}: {seq}") 222 | 223 | if __name__ == '__main__': 224 | main() -------------------------------------------------------------------------------- /examples/protein/p2_sampling_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# P2 Sampling Demo for Protein Sequence Generation\n", 8 | "\n", 9 | "This notebook demonstrates how to use P2 (Path Planning) sampling to generate protein sequences." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Setup\n", 17 | "\n", 18 | "First, let's import the necessary libraries:" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import torch\n", 28 | "import math\n", 29 | "import time\n", 30 | "from transformers import AutoTokenizer, EsmForMaskedLM\n", 31 | "from path_planning.p2 import p2_sampling\n", 32 | "from path_planning.utils import seed_everything\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## Helper Functions\n", 40 | "\n", 41 | "Let's define some helper functions for our demo:" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "def ignore_special_tokens_logits(logits, tokenizer):\n", 51 | " \"\"\"Masks out the logits of special tokens to prevent them from being sampled.\"\"\"\n", 52 | " logits[..., tokenizer.mask_token_id] = -math.inf\n", 53 | " logits[..., tokenizer._token_to_id[\"X\"]] = -math.inf\n", 54 | " logits[..., tokenizer.pad_token_id] = -math.inf\n", 55 | " logits[..., tokenizer.cls_token_id] = -math.inf\n", 56 | " logits[..., tokenizer.eos_token_id] = -math.inf\n", 57 | " return logits\n", 58 | "\n", 59 | "class ModelWrapper:\n", 60 | " \"\"\"Wrapper for the ESM model to handle logits processing.\"\"\"\n", 61 | " def __init__(self, model, tokenizer):\n", 62 | " self.model = model\n", 63 | " self.tokenizer = tokenizer\n", 64 | " \n", 65 | " def __call__(self, x):\n", 66 | " outputs = self.model(x)\n", 67 | " logits = outputs.logits\n", 68 | " return ignore_special_tokens_logits(logits.float(), self.tokenizer)\n", 69 | "\n", 70 | "def create_masked_sequence(sequence_length, tokenizer, batch_size=1, device='cuda'):\n", 71 | " \"\"\"Create a fully masked sequence for generation.\"\"\"\n", 72 | " seq = [tokenizer.mask_token] * sequence_length\n", 73 | " sequences = [''.join(seq)] * batch_size\n", 74 | " \n", 75 | " encoded = tokenizer(\n", 76 | " sequences,\n", 77 | " add_special_tokens=True,\n", 78 | " padding=True,\n", 79 | " return_tensors='pt'\n", 80 | " )\n", 81 | " return encoded['input_ids'].to(device)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "## Configuration\n", 89 | "\n", 90 | "Set the parameters for protein 
sequence generation:" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# Configuration\n", 100 | "model_name = \"airkingbd/dplm_650m\" # You can also try \"zhangzhi/EvoFlow-150M-fs\"\n", 101 | "num_seqs = 5 # Number of sequences to generate\n", 102 | "seq_len = 100 # Length of sequences\n", 103 | "num_steps = 100 # Number of P2 sampling steps\n", 104 | "temperature = 1.0 # Sampling temperature\n", 105 | "eta = 1.0 # Stochasticity strength\n", 106 | "seed = 42 # Random seed\n", 107 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 108 | "\n", 109 | "# Set random seed for reproducibility\n", 110 | "seed_everything(seed)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "## Load Model\n", 118 | "\n", 119 | "Load the protein language model:" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "print(f\"Loading model {model_name}...\")\n", 129 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 130 | "model = EsmForMaskedLM.from_pretrained(model_name)\n", 131 | "model = model.eval().to(device)\n", 132 | "\n", 133 | "# Wrap the model\n", 134 | "model_wrapper = ModelWrapper(model, tokenizer)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "## Create Initial Sequence\n", 142 | "\n", 143 | "Create a fully masked sequence as the starting point:" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "print(\"Creating initial sequence...\")\n", 153 | "xt = create_masked_sequence(\n", 154 | " sequence_length=seq_len,\n", 155 | " tokenizer=tokenizer,\n", 156 | " batch_size=num_seqs,\n", 157 | " device=device\n", 158 | ")\n", 159 | "print(f\"Initial sequence shape: {xt.shape}\")" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "## Run P2 Sampling\n", 167 | "\n", 168 | "Generate protein sequences using P2 sampling:" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "print(\"Starting P2 sampling...\")\n", 178 | "start_time = time.time()\n", 179 | "# check out p2_sampling to see the full parameters\n", 180 | "sampled_xt = p2_sampling(\n", 181 | " xt=xt,\n", 182 | " model=model_wrapper,\n", 183 | " mask_id=tokenizer.mask_token_id,\n", 184 | " num_steps=num_steps,\n", 185 | " tau=temperature,\n", 186 | " eta=eta\n", 187 | ")\n", 188 | "\n", 189 | "elapsed_time = time.time() - start_time\n", 190 | "print(f\"Generation completed in {elapsed_time:.2f} seconds\")\n", 191 | "print(f\"Tokens/second: {num_seqs * seq_len / elapsed_time:.2f}\")" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "## Decode and Display Results\n", 199 | "\n", 200 | "Decode the generated sequences and display them:" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "# Decode sequences\n", 210 | "decoded_seqs = tokenizer.batch_decode(sampled_xt, skip_special_tokens=True)\n", 211 | "decoded_seqs = [''.join(seq.split()) for seq in decoded_seqs]\n", 212 | "\n", 213 | "# Display generated sequences\n", 214 | "print(\"\\nGenerated 
Protein Sequences:\")\n", 215 | "for i, seq in enumerate(decoded_seqs):\n", 216 | " print(f\"Sequence {i+1} (length {len(seq)}):\")\n", 217 | " print(seq)\n", 218 | " print()" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "## Save Sequences (Optional)\n", 226 | "\n", 227 | "Save the generated sequences to a FASTA file:" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "def save_sequences_to_fasta(sequences, seq_len, save_path):\n", 237 | " \"\"\"Save generated sequences to FASTA format.\"\"\"\n", 238 | " import os\n", 239 | " os.makedirs(os.path.dirname(save_path), exist_ok=True)\n", 240 | " with open(save_path, 'w') as fp:\n", 241 | " for idx, seq in enumerate(sequences):\n", 242 | " fp.write(f\">SEQUENCE_{idx}_L={seq_len}\\n\")\n", 243 | " fp.write(f\"{seq}\\n\")\n", 244 | "\n", 245 | "# Uncomment to save sequences\n", 246 | "# save_path = \"generated_sequences.fasta\"\n", 247 | "# save_sequences_to_fasta(decoded_seqs, seq_len, save_path)\n", 248 | "# print(f\"Saved sequences to {save_path}\")" 249 | ] 250 | } 251 | ], 252 | "metadata": { 253 | "kernelspec": { 254 | "display_name": "Python 3", 255 | "language": "python", 256 | "name": "python3" 257 | }, 258 | "language_info": { 259 | "codemirror_mode": { 260 | "name": "ipython", 261 | "version": 3 262 | }, 263 | "file_extension": ".py", 264 | "mimetype": "text/x-python", 265 | "name": "python", 266 | "nbconvert_exporter": "python", 267 | "pygments_lexer": "ipython3", 268 | "version": "3.10.0" 269 | } 270 | }, 271 | "nbformat": 4, 272 | "nbformat_minor": 4 273 | } 274 | -------------------------------------------------------------------------------- /examples/protein/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import typing as T 4 | from pathlib import Path 5 | import re 6 | import torch 7 | import math 8 | from collections import Counter 9 | import pandas as pd 10 | from transformers import EsmForProteinFolding 11 | from timeit import default_timer as timer 12 | import numpy as np 13 | from transformers.models.esm.openfold_utils.loss import compute_tm 14 | # Set up logging 15 | logger = logging.getLogger() 16 | logger.setLevel(logging.INFO) 17 | formatter = logging.Formatter( 18 | "%(asctime)s | %(levelname)s | %(name)s | %(message)s", 19 | datefmt="%y/%m/%d %H:%M:%S", 20 | ) 21 | console_handler = logging.StreamHandler(sys.stdout) 22 | console_handler.setLevel(logging.INFO) 23 | console_handler.setFormatter(formatter) 24 | logger.addHandler(console_handler) 25 | 26 | def calculate_entropy(sequence: str) -> float: 27 | """Calculate Shannon entropy of a sequence.""" 28 | amino_acid_counts = Counter(sequence) 29 | total_amino_acids = len(sequence) 30 | probabilities = (count / total_amino_acids for count in amino_acid_counts.values()) 31 | return -sum(p * math.log2(p) for p in probabilities if p > 0) 32 | 33 | def read_fasta(path: str, keep_gaps: bool = True, keep_insertions: bool = True, to_upper: bool = False): 34 | """Read sequences from a FASTA file.""" 35 | with open(path, "r") as f: 36 | for result in read_alignment_lines( 37 | f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper 38 | ): 39 | yield result 40 | 41 | def read_alignment_lines(lines, keep_gaps=True, keep_insertions=True, to_upper=False): 42 | """Parse alignment lines from a FASTA file.""" 43 | seq = desc = None 
44 | 45 | def parse(s): 46 | if not keep_gaps: 47 | s = re.sub("-", "", s) 48 | if not keep_insertions: 49 | s = re.sub("[a-z]", "", s) 50 | return s.upper() if to_upper else s 51 | 52 | for line in lines: 53 | if len(line) > 0 and line[0] == ">": 54 | if seq is not None and 'X' not in seq: 55 | yield desc, parse(seq) 56 | desc = line.strip().lstrip(">") 57 | seq = "" 58 | else: 59 | assert isinstance(seq, str) 60 | seq += line.strip() 61 | if seq is not None and 'X' not in seq: 62 | yield desc, parse(seq) 63 | 64 | def enable_cpu_offloading(model): 65 | """Enable CPU offloading for the model.""" 66 | from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel 67 | from torch.distributed.fsdp.wrap import enable_wrap, wrap 68 | 69 | torch.distributed.init_process_group( 70 | backend="nccl", init_method="tcp://localhost:9999", world_size=1, rank=0 71 | ) 72 | 73 | wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True)) 74 | with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs): 75 | for layer_name, layer in model.layers.named_children(): 76 | wrapped_layer = wrap(layer) 77 | setattr(model.layers, layer_name, wrapped_layer) 78 | model = wrap(model) 79 | return model 80 | 81 | def init_model_on_gpu_with_cpu_offloading(model): 82 | """Initialize model with CPU offloading.""" 83 | model = model.eval() 84 | model_esm = enable_cpu_offloading(model.esm) 85 | del model.esm 86 | model.cuda() 87 | model.esm = model_esm 88 | return model 89 | 90 | def create_batched_sequence_dataset( 91 | sequences: T.List[T.Tuple[str, str]], 92 | max_tokens_per_batch: int = 1024 93 | ) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]: 94 | """Create batched sequences for efficient processing.""" 95 | batch_headers, batch_sequences, num_tokens = [], [], 0 96 | for header, seq in sequences: 97 | if (len(seq) + num_tokens > max_tokens_per_batch) and num_tokens > 0: 98 | yield batch_headers, batch_sequences 99 | batch_headers, batch_sequences, num_tokens = [], [], 0 100 | batch_headers.append(header) 101 | batch_sequences.append(seq) 102 | num_tokens += len(seq) 103 | if batch_headers: 104 | yield batch_headers, batch_sequences 105 | 106 | def run_esmfold_eval( 107 | fasta_dir: Path, 108 | output_dir: Path, 109 | num_recycles: int = None, 110 | max_tokens_per_batch: int = 1024, 111 | chunk_size: int = None, 112 | cpu_only: bool = False, 113 | cpu_offload: bool = False, 114 | ) -> pd.DataFrame: 115 | """ 116 | Run ESMFold evaluation on generated sequences. 
117 | 118 | Args: 119 | fasta_dir: Directory containing FASTA files 120 | output_dir: Directory to save PDB files 121 | num_recycles: Number of recycles for ESMFold 122 | max_tokens_per_batch: Maximum tokens per batch 123 | chunk_size: Chunk size for axial attention 124 | cpu_only: Whether to use CPU only 125 | cpu_offload: Whether to enable CPU offloading 126 | 127 | Returns: 128 | DataFrame with evaluation results 129 | """ 130 | output_dir.mkdir(parents=True, exist_ok=True) 131 | logger.info(f"Output directory: {output_dir}") 132 | 133 | # Load ESMFold model 134 | logger.info("Loading ESMFold model...") 135 | model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").eval() 136 | model = model.eval() 137 | # model.set_chunk_size(chunk_size) 138 | 139 | # Set device 140 | if cpu_only: 141 | model.esm.float() 142 | model.cpu() 143 | elif cpu_offload: 144 | model = init_model_on_gpu_with_cpu_offloading(model) 145 | else: 146 | model.cuda() 147 | 148 | # Process FASTA files 149 | data_records = [] 150 | fasta_files = list(fasta_dir.glob("*.fasta")) 151 | logger.info(f"Found {len(fasta_files)} FASTA files") 152 | 153 | for fasta_path in fasta_files: 154 | logger.info(f"Processing {fasta_path}") 155 | pdb_dir = output_dir / fasta_path.stem 156 | pdb_dir.mkdir(exist_ok=True) 157 | 158 | # Read sequences 159 | sequences = sorted(read_fasta(str(fasta_path)), key=lambda x: len(x[1])) 160 | if not sequences: 161 | continue 162 | 163 | # Process batches 164 | for headers, seqs in create_batched_sequence_dataset(sequences, max_tokens_per_batch): 165 | try: 166 | start_time = timer() 167 | output = model.infer(seqs,) 168 | output = {k: v.cpu() for k, v in output.items()} 169 | ptm = torch.stack( 170 | [ 171 | compute_tm( 172 | batch_ptm_logits[None, :sl, :sl], 173 | max_bins=31, 174 | no_bins=64, 175 | ) 176 | for batch_ptm_logits, sl in zip(output["ptm_logits"], (len(seq) for seq in seqs)) 177 | ] 178 | ) 179 | output["ptm"] = ptm 180 | 181 | # Generate PDBs and metrics 182 | pdbs = model.output_to_pdb(output) 183 | paes = (output["aligned_confidence_probs"].numpy() * 184 | np.arange(64).reshape(1, 1, 1, 64)).mean(-1) * 31 185 | paes = paes.mean(-1).mean(-1) 186 | output["mean_plddt"] = 100 * (output["plddt"] * output["atom37_atom_exists"]).sum(dim=(1, 2)) / output["atom37_atom_exists"].sum(dim=(1, 2)) 187 | # Save results 188 | for header, seq, pdb_str, plddt, ptm, pae in zip( 189 | headers, seqs, pdbs, 190 | output["mean_plddt"], output["ptm"], paes 191 | ): 192 | pdb_file = pdb_dir / f"{header}_plddt_{plddt.mean().item():.1f}_ptm_{ptm.item():.3f}_pae_{pae.item():.3f}.pdb" 193 | pdb_file.write_text(pdb_str) 194 | 195 | data_records.append({ 196 | 'FASTA_file': fasta_path.name, 197 | 'PDB_path': str(pdb_file), 198 | 'sequence': seq, 199 | 'Length': len(seq), 200 | 'pLDDT': plddt.mean().item(), 201 | 'pTM': ptm.item(), 202 | 'pAE': pae.item(), 203 | 'Entropy': calculate_entropy(seq) 204 | }) 205 | 206 | except RuntimeError as e: 207 | if "CUDA out of memory" in str(e): 208 | logger.warning(f"CUDA OOM for batch size {len(seqs)}") 209 | continue 210 | raise e 211 | 212 | # Create DataFrame 213 | df = pd.DataFrame(data_records) 214 | return df -------------------------------------------------------------------------------- /examples/text/LLaDA/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 NieShenRuc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and 
associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/text/LLaDA/README.md: -------------------------------------------------------------------------------- 1 | # Large Language Diffusion Models 2 | [![arXiv](https://img.shields.io/badge/arXiv-2502.09992-red.svg)](https://arxiv.org/abs/2502.09992) 3 | [![deploy](https://img.shields.io/badge/Huggingface%20-LLaDA_Base%20-FFEB3B)](https://huggingface.co/GSAI-ML/LLaDA-8B-Base) 4 | [![deploy](https://img.shields.io/badge/Huggingface%20-LLaDA_Instruct%20-FFEB3B)](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct) 5 | [![deploy](https://img.shields.io/badge/Zhihu-知乎-blue)](https://zhuanlan.zhihu.com/p/24214732238) 6 | 7 | The code is based on [LLaDA](https://github.com/GSAI-ML/LLaDA). 8 | 9 | ## Citation 10 | 11 | ```bibtex 12 | @article{nie2025large, 13 | title={Large Language Diffusion Models}, 14 | author={Nie, Shen and Zhu, Fengqi and You, Zebin and Zhang, Xiaolu and Ou, Jingyang and Hu, Jun and Zhou, Jun and Lin, Yankai and Wen, Ji-Rong and Li, Chongxuan}, 15 | journal={arXiv preprint arXiv:2502.09992}, 16 | year={2025} 17 | } 18 | ``` 19 | 20 | -------------------------------------------------------------------------------- /examples/text/LLaDA/eval_lm_harness.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is inspired by the code from 3 | - https://github.com/ML-GSAI/SMDM 4 | - https://github.com/ML-GSAI/LLaDA/blob/main/evaluation/eval_llada.py 5 | ''' 6 | import accelerate 7 | import torch 8 | import re 9 | from pathlib import Path 10 | import random 11 | import numpy as np 12 | import torch.nn.functional as F 13 | from datasets import Dataset 14 | from lm_eval.__main__ import cli_evaluate 15 | from lm_eval.api.instance import Instance 16 | from lm_eval.api.model import LM 17 | from lm_eval.api.registry import register_model 18 | from tqdm import tqdm 19 | from path_planning.p2 import p2_sampling 20 | 21 | from transformers import AutoTokenizer, AutoModel 22 | 23 | 24 | def set_seed(seed): 25 | torch.manual_seed(seed) 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | 32 | 33 | @register_model("llada_dist") 34 | class LLaDAEvalHarness(LM): 35 | def __init__( 36 | self, 37 | model_path='', 38 | mask_id=126336, 39 | batch_size=32, 40 | mc_num=128, 41 | is_check_greedy=True, 42 | cfg=0., 43 | device="cuda", 44 | max_length=None, 45 | num_steps=None, 46 | tau=1.0, 47 | 
eta=1.0, 48 | ): 49 | ''' 50 | Args: 51 | model_path: LLaDA-8B-Base model path. 52 | mask_id: The token id of [MASK] is 126336. 53 | max_length: the max sequence length. 54 | batch_size: mini batch size. 55 | mc_num: Monte Carlo estimation iterations 56 | is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer 57 | is generated through greedy sampling conditioned on the prompt (note that this differs from conditional 58 | generation). We implement this verification through the suffix_greedy_prediction() function, which 59 | returns a True/False judgment used for accuracy calculation. 60 | When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function. 61 | However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality, 62 | we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False 63 | by default, significantly accelerating the evaluation process. 64 | cfg_scale: Unsupervised classifier-free guidance scale. 65 | num_steps: The number of steps for the diffusion process. 66 | tau: The temperature for the diffusion process. 67 | eta: The eta for the diffusion process. 68 | ''' 69 | super().__init__() 70 | 71 | accelerator = accelerate.Accelerator() 72 | if accelerator.num_processes > 1: 73 | self.accelerator = accelerator 74 | else: 75 | self.accelerator = None 76 | 77 | model_kwargs = {} 78 | if self.accelerator is not None: 79 | model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}}) 80 | 81 | self.is_instruct_ft = 'Instruct' in model_path 82 | self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, **model_kwargs) 83 | self.model.eval() 84 | 85 | self.device = torch.device(device) 86 | if self.accelerator is not None: 87 | self.model = self.accelerator.prepare(self.model) 88 | self.device = torch.device(f'{self.accelerator.device}') 89 | self._rank = self.accelerator.local_process_index 90 | self._world_size = self.accelerator.num_processes 91 | else: 92 | self.model = self.model.to(device) 93 | self._rank = 0 94 | self._world_size = 1 95 | 96 | self.mask_id = mask_id 97 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 98 | 99 | self.mc_num = mc_num 100 | self.batch_size = int(batch_size) 101 | assert mc_num % self.batch_size == 0 102 | self.sampling_eps = 0. 
103 | self.max_length = max_length 104 | self.is_check_greedy = is_check_greedy 105 | self.num_steps = num_steps 106 | self.tau = tau 107 | self.cfg = cfg 108 | self.eta = eta 109 | print(f'model: {model_path}') 110 | print(f'Is check greedy: {is_check_greedy}') 111 | print(f'cfg: {cfg}') 112 | print(f'num_steps: {num_steps}') 113 | print(f'tau: {tau}') 114 | print(f'eta: {eta}') 115 | if self.accelerator is not None: 116 | print(f'Running with accelerate on {self.accelerator.num_processes} processes') 117 | print(f'Local process index: {self._rank}') 118 | 119 | @property 120 | def rank(self): 121 | return self._rank 122 | 123 | @property 124 | def world_size(self): 125 | return self._world_size 126 | 127 | def _forward_process(self, batch, prompt_index): 128 | b, l = batch.shape 129 | 130 | target_len = (l - prompt_index.sum()).item() 131 | k = torch.randint(1, target_len + 1, (), device=batch.device) 132 | 133 | x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long() 134 | x = ((x - 1) % target_len) + 1 135 | assert x.min() >= 1 and x.max() <= target_len 136 | 137 | indices = torch.arange(target_len, device=batch.device).repeat(b, 1) 138 | is_mask = indices < x.unsqueeze(1) 139 | 140 | for i in range(b): 141 | is_mask[i] = is_mask[i][torch.randperm(target_len)] 142 | 143 | is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1) 144 | 145 | noisy_batch = torch.where(is_mask, self.mask_id, batch) 146 | 147 | return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l) 148 | 149 | @torch.no_grad() 150 | def get_logits(self, batch, prompt_index): 151 | if self.cfg > 0.: 152 | assert len(prompt_index) == batch.shape[1] 153 | prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1) 154 | un_batch = batch.clone() 155 | un_batch[prompt_index] = self.mask_id 156 | batch = torch.cat([batch, un_batch]) 157 | 158 | logits = self.model(batch).logits 159 | 160 | if self.cfg > 0.: 161 | logits, un_logits = torch.chunk(logits, 2, dim=0) 162 | logits = un_logits + (self.cfg + 1) * (logits - un_logits) 163 | return logits[:, :batch.shape[1]] 164 | 165 | @torch.no_grad() 166 | def get_loglikelihood(self, prefix, target): 167 | seq = torch.concatenate([prefix, target])[None, :] 168 | seq = seq.repeat((self.batch_size, 1)).to(self.device) 169 | 170 | prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) 171 | 172 | loss_acc = [] 173 | for _ in range(self.mc_num // self.batch_size): 174 | perturbed_seq, p_mask = self._forward_process(seq, prompt_index) 175 | 176 | mask_indices = perturbed_seq == self.mask_id 177 | 178 | logits = self.get_logits(perturbed_seq, prompt_index) 179 | 180 | loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices] 181 | loss = loss.sum() / self.batch_size 182 | loss_acc.append(loss.item()) 183 | 184 | return - sum(loss_acc) / len(loss_acc) 185 | 186 | @torch.no_grad() 187 | def suffix_greedy_prediction(self, prefix, target): 188 | if not self.is_check_greedy: 189 | return False 190 | 191 | seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device) 192 | prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) 193 | prefix, target = prefix.to(self.device), target.to(self.device) 194 | seq[0, :len(prefix)] = prefix 195 | 196 | for i in range(len(target)): 197 | mask_index = (seq == self.mask_id) 198 | logits = self.get_logits(seq, 
prompt_index)[mask_index] 199 | x0 = torch.argmax(logits, dim=-1) 200 | 201 | p = torch.softmax(logits.to(torch.float32), dim=-1) 202 | confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1) 203 | _, index = torch.sort(confidence, descending=True) 204 | x0[index[1:]] = self.mask_id 205 | seq[mask_index] = x0.clone() 206 | correct = target == seq[0, len(prefix):] 207 | correct = torch.all(correct) 208 | return correct 209 | 210 | def _encode_pair(self, context, continuation): 211 | n_spaces = len(context) - len(context.rstrip()) 212 | if n_spaces > 0: 213 | continuation = context[-n_spaces:] + continuation 214 | context = context[:-n_spaces] 215 | 216 | whole_enc = self.tokenizer(context + continuation)["input_ids"] 217 | context_enc = self.tokenizer(context)["input_ids"] 218 | 219 | context_enc_len = len(context_enc) 220 | continuation_enc = whole_enc[context_enc_len:] 221 | 222 | return context_enc, continuation_enc 223 | 224 | def loglikelihood(self, requests): 225 | def _tokenize(e): 226 | prefix, target = self._encode_pair(e["prefix"], e["target"]) 227 | return { 228 | "prefix_text": e["prefix"], 229 | "target_text": e["target"], 230 | "prefix": prefix, 231 | "target": target, 232 | } 233 | 234 | ds = [] 235 | ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests] 236 | ds = Dataset.from_list(ds) 237 | ds = ds.map(_tokenize) 238 | ds = ds.with_format("torch") 239 | prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds] 240 | 241 | assert max(prompt_len) <= 4096 242 | 243 | out = [] 244 | with torch.no_grad(): 245 | for elem in tqdm(ds, desc="Computing likelihood..."): 246 | prefix = elem["prefix"] 247 | target = elem["target"] 248 | 249 | ll = self.get_loglikelihood(prefix, target) 250 | 251 | is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target) 252 | 253 | out.append((ll, 1.0 if is_target_greedy_dec else 0.0)) 254 | print('=' * 20) 255 | print('prefix: ', elem['prefix_text']) 256 | print('target: ', elem['target_text']) 257 | print(ll, is_target_greedy_dec) 258 | print('=' * 20, end='\n\n') 259 | torch.cuda.empty_cache() 260 | return out 261 | 262 | def loglikelihood_rolling(self, requests): 263 | raise NotImplementedError 264 | def generate_until(self, requests,): 265 | """ 266 | Generates text continuations for a list of requests using the masked diffusion approach, 267 | processing requests in batches of size `small_batch_size`. 268 | 269 | Each request should have: 270 | - args[0]: A prompt string. 271 | - args[1]: A dictionary of generation parameters (e.g., {"max_gen_toks": 128, "until": ["stopword1", ...]}). 272 | 273 | Returns: 274 | A list of generated text strings. 
275 | """ 276 | batch_size = self.batch_size 277 | outputs = [] 278 | 279 | # Process requests in batches of small_batch_size 280 | for batch_start in range(0, len(requests), batch_size): 281 | batch_requests = requests[batch_start : batch_start + batch_size] 282 | 283 | prompts = [] 284 | stop_words_list = [] 285 | max_gen_toks_list = [] 286 | 287 | # Process each request in the current batch to extract prompt and per-request parameters 288 | for req in batch_requests: 289 | prompt_text = req.args[0] 290 | gen_params = req.args[1] if len(req.args) > 1 else {} 291 | max_gen_toks = self.max_length if self.max_length is not None else gen_params.get("max_gen_toks", self.max_length) 292 | stop_words = gen_params.get("until", []) 293 | 294 | # For instruct finetuned models, apply the chat template 295 | if self.is_instruct_ft: 296 | m = [{"role": "user", "content": prompt_text}] 297 | prompt_text = self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) 298 | 299 | prompts.append(prompt_text) 300 | stop_words_list.append(stop_words) 301 | max_gen_toks_list.append(max_gen_toks) 302 | 303 | # Tokenize all prompts together; use padding to create a uniform tensor shape 304 | tokenized = self.tokenizer(prompts, padding=True, return_tensors="pt") 305 | input_ids = tokenized["input_ids"].to(self.device) # shape: (batch_size, max_prompt_length) 306 | batch_size, max_prompt_length = input_ids.shape 307 | prompt_lengths = tokenized["attention_mask"].sum(dim=1) # shape: (batch_size,) 308 | 309 | # Determine the maximum generation tokens for this batch 310 | max_gen_toks_batch = max(max_gen_toks_list) 311 | 312 | # Create a batched tensor filled with mask_id with room for prompt and generated tokens 313 | xt = torch.full((batch_size, max_prompt_length + max_gen_toks_batch), 314 | self.mask_id, dtype=torch.long, device=self.device) 315 | 316 | # Copy each prompt into its corresponding row in the batched tensor 317 | for i in range(batch_size): 318 | prompt_len = prompt_lengths[i].item() 319 | xt[i, :prompt_len] = input_ids[i, :prompt_len] 320 | 321 | # Use a common number of steps (using self.num_steps if set, else use max_gen_toks_batch) 322 | num_steps = self.num_steps if self.num_steps is not None else max_gen_toks_batch 323 | tau = self.tau 324 | eta = self.eta 325 | 326 | # Define a batched model wrapper that returns logits for the entire batch 327 | def model_wrapper(x): 328 | outputs = self.model(x) 329 | return outputs.logits 330 | 331 | # Run the diffusion-based sampling on the current batch 332 | sampled_xt = p2_sampling( 333 | xt=xt, 334 | model=model_wrapper, 335 | mask_id=self.mask_id, 336 | num_steps=num_steps, 337 | tau=tau, 338 | eta=eta, 339 | ) 340 | 341 | # Process each sample in the current batch: extract generated tokens and decode them 342 | for i in range(batch_size): 343 | prompt_len = prompt_lengths[i].item() 344 | gen_limit = max_gen_toks_list[i] 345 | gen_tokens = sampled_xt[i, prompt_len:prompt_len + gen_limit] 346 | generated_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True) 347 | 348 | # Truncate generated text by any specified stop words 349 | for word in stop_words_list[i]: 350 | if word in generated_text: 351 | generated_text = generated_text[:generated_text.index(word)] 352 | 353 | print(generated_text) 354 | print('=' * 20, end='\n\n') 355 | outputs.append(generated_text) 356 | 357 | return outputs 358 | 359 | 360 | if __name__ == "__main__": 361 | set_seed(1234) 362 | cli_evaluate() 363 | 
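For readers skimming get_logits above: the unsupervised classifier-free guidance step evaluates the model on the batch and on a copy whose prompt tokens are replaced by [MASK], then extrapolates between the two sets of logits. A self-contained sketch of just that combination, on hypothetical tensor shapes, looks like this:

    import torch

    cfg = 0.5                                # guidance scale; 0 disables guidance
    cond_logits = torch.randn(2, 5, 11)      # logits for the original (conditional) batch
    uncond_logits = torch.randn(2, 5, 11)    # logits with the prompt fully masked

    # Same combination as in get_logits: guided = uncond + (cfg + 1) * (cond - uncond)
    guided = uncond_logits + (cfg + 1) * (cond_logits - uncond_logits)
    print(guided.shape)                      # torch.Size([2, 5, 11])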
-------------------------------------------------------------------------------- /examples/text/LLaDA/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | from transformers import AutoTokenizer, AutoModel 6 | from path_planning.p2 import * 7 | from path_planning.utils import seed_everything 8 | from LLaDA_hf.modeling_llada import LLaDAModelLM 9 | 10 | class ModelWrapper: 11 | """Wrapper for the model to handle logits processing.""" 12 | def __init__(self, model): 13 | self.model = model 14 | 15 | def __call__(self, x): 16 | outputs = self.model(x) 17 | return outputs.logits 18 | 19 | 20 | def main(): 21 | device = 'cuda' 22 | seed_everything(42) # For reproducibility 23 | 24 | model = LLaDAModelLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', torch_dtype=torch.bfloat16).to(device).eval() 25 | tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True) 26 | model_wrapper = ModelWrapper(model) 27 | 28 | prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?" 29 | 30 | # Add special tokens for the Instruct model. The Base model does not require the following two lines. 31 | m = [{"role": "user", "content": prompt}, ] 32 | prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) 33 | 34 | input_ids = tokenizer(prompt)['input_ids'] 35 | input_ids = torch.tensor(input_ids).to(device).unsqueeze(0) 36 | 37 | gen_length = 128 38 | mask_id = 126336 39 | xt = torch.full((1, input_ids.shape[1] + gen_length), mask_id, dtype=torch.long).to(device) 40 | xt[:, :input_ids.shape[1]] = input_ids.clone() 41 | 42 | 43 | print(f"Input shape: {input_ids.shape}, Full sequence shape: {xt.shape}") 44 | 45 | 46 | sampled_xt = p2_plus_sampling( 47 | xt=xt, 48 | model=model_wrapper, 49 | mask_id=mask_id, 50 | num_steps=128, 51 | tau=0., 52 | ) 53 | print(f'prompt: {prompt}') 54 | print(f'generated: {tokenizer.batch_decode(sampled_xt[:, input_ids.shape[1]:], skip_special_tokens=True)[0]}') 55 | 56 | 57 | if __name__ == '__main__': 58 | main() -------------------------------------------------------------------------------- /examples/text/LLaDA/scripts/gsm8k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_ALLOW_CODE_EVAL=1 4 | export CUDA_VISIBLE_DEVICES=1 5 | 6 | MODEL_NAME="llada_dist" 7 | MODEL_PATH="GSAI-ML/LLaDA-8B-Base" 8 | TASK="gsm8k" 9 | BATCH_SIZE=12 10 | NUM_STEPS=256 11 | TAU=0.0 12 | MAX_LENGTH=256 13 | 14 | for eta in $(seq 0 0.2 2.0); do 15 | eta_str=$(printf "%.1f" $eta) 16 | OUTPUT_PATH="./results/gsm8k/" 17 | echo "Running with eta=${eta_str}, saving to ${OUTPUT_PATH}" 18 | 19 | python eval_lm_harness.py \ 20 | --tasks ${TASK} \ 21 | --model ${MODEL_NAME} \ 22 | --confirm_run_unsafe_code \ 23 | --batch_size ${BATCH_SIZE} \ 24 | --output_path ${OUTPUT_PATH} \ 25 | --model_args model_path="${MODEL_PATH}",mc_num=${BATCH_SIZE},num_steps=${NUM_STEPS},tau=${TAU},max_length=${MAX_LENGTH},eta=${eta_str} 26 | done -------------------------------------------------------------------------------- /examples/text/LLaDA/scripts/gsm8k_instruct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_ALLOW_CODE_EVAL=1 4 | export CUDA_VISIBLE_DEVICES=2 5 | 6 | MODEL_NAME="llada_dist" 7 | 
MODEL_PATH="GSAI-ML/LLaDA-8B-Instruct" 8 | TASK="gsm8k" 9 | BATCH_SIZE=12 10 | NUM_STEPS=256 11 | TAU=0.0 12 | MAX_LENGTH=256 13 | 14 | for eta in $(seq 0 0.2 2.0); do 15 | eta_str=$(printf "%.1f" $eta) 16 | OUTPUT_PATH="./results/gsm8k_instruct" 17 | echo "Running with eta=${eta_str}, saving to ${OUTPUT_PATH}" 18 | 19 | python eval_lm_harness.py \ 20 | --tasks ${TASK} \ 21 | --model ${MODEL_NAME} \ 22 | --confirm_run_unsafe_code \ 23 | --batch_size ${BATCH_SIZE} \ 24 | --output_path ${OUTPUT_PATH} \ 25 | --model_args model_path="${MODEL_PATH}",mc_num=${BATCH_SIZE},num_steps=${NUM_STEPS},tau=${TAU},max_length=${MAX_LENGTH},eta=${eta_str} 26 | done -------------------------------------------------------------------------------- /examples/text/LLaDA/scripts/humaneval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_ALLOW_CODE_EVAL=1 4 | 5 | MODEL_NAME="llada_dist" 6 | MODEL_PATH="GSAI-ML/LLaDA-8B-Base" 7 | TASK="humaneval" 8 | BATCH_SIZE=16 9 | MC_NUM=$BATCH_SIZE 10 | NUM_STEPS=100 11 | TAU=1.0 12 | MAX_LENGTH=300 13 | 14 | GPUS=(0 3) 15 | 16 | i=0 17 | for eta in $(seq 0 0.2 2.0); do 18 | eta_str=$(printf "%.1f" "$eta") 19 | GPU=${GPUS[$(( i % 2 ))]} 20 | OUTPUT_PATH="./results/humaneval/" 21 | 22 | echo "Running eta=${eta_str} on GPU ${GPU}, saving to ${OUTPUT_PATH}" 23 | CUDA_VISIBLE_DEVICES="${GPU}" python eval_lm_harness.py \ 24 | --tasks "${TASK}" \ 25 | --model "${MODEL_NAME}" \ 26 | --confirm_run_unsafe_code \ 27 | --batch_size "${BATCH_SIZE}" \ 28 | --output_path "${OUTPUT_PATH}" \ 29 | --model_args model_path="${MODEL_PATH}",mc_num="${MC_NUM}",num_steps="${NUM_STEPS}",tau="${TAU}",max_length="${MAX_LENGTH}",eta="${eta_str}" & 30 | 31 | # After launching a job on GPU 3 (odd index), wait for both GPUs to free up 32 | if (( i % 2 == 1 )); then 33 | wait 34 | fi 35 | 36 | ((i++)) 37 | done 38 | 39 | # Wait for any remaining background jobs 40 | wait 41 | 42 | echo "All evaluations complete." 
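The scripts above sweep eta over `seq 0 0.2 2.0`, i.e. eleven values from 0.0 to 2.0. If you later aggregate the per-eta result files in Python, the same grid can be reproduced with a small illustrative snippet:

    etas = [round(0.2 * i, 1) for i in range(11)]   # [0.0, 0.2, ..., 2.0]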
-------------------------------------------------------------------------------- /examples/text/LLaDA/scripts/test_LLaDA_code.sh: -------------------------------------------------------------------------------- 1 | export HF_ALLOW_CODE_EVAL=1 2 | 3 | #Args for NVIDIA A100-PCIE-40GB 4 | 5 | #humaneval 6 | CUDA_VISIBLE_DEVICES=0,1,2,7 \ 7 | HF_ALLOW_CODE_EVAL=1 \ 8 | accelerate launch eval_lm_harness.py \ 9 | --tasks humaneval \ 10 | --model llada_dist \ 11 | --confirm_run_unsafe_code \ 12 | --batch_size 4 \ 13 | --output_path /data/shuibai/LLaDA/results/humaneval/ \ 14 | --log_samples \ 15 | --model_args model_path="GSAI-ML/LLaDA-8B-Base",mc_num=12,num_steps=512,tau=0.0,max_length=512,eta=1.0 \ 16 | # --limit 8 17 | 18 | 19 | #mbpp 20 | CUDA_VISIBLE_DEVICES=0,1,2,7 \ 21 | HF_ALLOW_CODE_EVAL=1 \ 22 | accelerate launch eval_lm_harness.py \ 23 | --tasks mbpp \ 24 | --model llada_dist \ 25 | --confirm_run_unsafe_code \ 26 | --batch_size 4 \ 27 | --output_path /data/shuibai/LLaDA/results/mbpp/ \ 28 | --log_samples \ 29 | --model_args model_path="GSAI-ML/LLaDA-8B-Base",mc_num=12,num_steps=512,tau=0.0,max_length=512,eta=1.0 \ 30 | # --limit 8 31 | 32 | 33 | 34 | cp -r /data/shuibai/LLaDA/results /u/s/h/shuibai/Path-Planning/examples/text/LLaDA/results-cp 35 | 36 | 37 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="path_planning", 5 | version="0.1.0", 6 | packages=find_packages(where="src"), 7 | package_dir={"": "src"}, 8 | install_requires=[ 9 | "torch", 10 | "tqdm", 11 | "transformers", 12 | "numpy", 13 | "pandas", 14 | ], 15 | author="Fred Zhangzhi Peng", 16 | author_email="zp70@duke.edu", 17 | description="P2 (Path Planning) sampling implementation for sequence generation", 18 | long_description=open("README.md").read(), 19 | long_description_content_type="text/markdown", 20 | url="https://github.com/pengzhangzhi/path_planning", 21 | classifiers=[ 22 | "Development Status :: 4 - Beta", 23 | "Intended Audience :: Science/Research", 24 | "License :: OSI Approved :: MIT License", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 30 | ], 31 | keywords="diffusion, sequence generation, masked language model, path planning", 32 | python_requires=">=3.8", 33 | ) -------------------------------------------------------------------------------- /src/path_planning/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Path Planning (P2) Sampling 3 | 4 | A Python package implementing P2 (Path Planning) sampling, a guided diffusion method for sequence generation. 
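A minimal usage sketch (illustrative only; `xt`, `model_fn`, and `mask_id` are placeholders for a
partially masked token tensor, a caller-supplied model wrapper returning logits, and the mask token id):

    from path_planning import p2_sampling

    sampled = p2_sampling(xt=xt, model=model_fn, mask_id=mask_id, num_steps=128, tau=1.0)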
5 | """ 6 | 7 | # Core sampling functions 8 | from .p2 import ( 9 | p2_sampling, 10 | ancestral_sampling, 11 | greedy_ancestral_sampling, 12 | dfm_sampling 13 | ) 14 | 15 | # Utility functions 16 | from .utils import ( 17 | seed_everything, 18 | topk_lowest_masking, 19 | topk_highest_masking, 20 | stochastic_sample_from_categorical 21 | ) 22 | 23 | # Score functions 24 | from .score_function import ( 25 | logP, 26 | random_score, 27 | diff_top2 28 | ) 29 | 30 | # Scheduler functions 31 | from .scheduler import ( 32 | linear_scheduler, 33 | sine_scheduler, 34 | geometric_scheduler, 35 | log_scheduler, 36 | poly2_scheduler, 37 | poly05_scheduler 38 | ) 39 | 40 | # Define what should be available directly upon import 41 | __all__ = [ 42 | # Core sampling 43 | "p2_sampling", 44 | "ancestral_sampling", 45 | "greedy_ancestral_sampling", 46 | "dfm_sampling", 47 | 48 | # Utilities 49 | "seed_everything", 50 | "topk_lowest_masking", 51 | "topk_highest_masking", 52 | "stochastic_sample_from_categorical", 53 | 54 | # Score functions 55 | "logP", 56 | "random_score", 57 | "diff_top2", 58 | 59 | # Schedulers 60 | "linear_scheduler", 61 | "sine_scheduler", 62 | "geometric_scheduler", 63 | "log_scheduler", 64 | "poly2_scheduler", 65 | "poly05_scheduler" 66 | ] 67 | -------------------------------------------------------------------------------- /src/path_planning/p2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from typing import Optional, Any, Callable, Union 4 | from path_planning.utils import topk_lowest_masking, stochastic_sample_from_categorical 5 | from path_planning.score_function import logP, random_score 6 | 7 | 8 | @torch.inference_mode() 9 | @torch.cuda.amp.autocast() 10 | def p2_sampling( 11 | xt: torch.Tensor, 12 | model: Any, 13 | mask_id: int, 14 | num_steps: int, 15 | tau: float = 1.0, 16 | kappa_fn: Callable[[float], float] = lambda t: t, 17 | eta: float = 1.0, 18 | planner: Optional[Any] = None, 19 | score_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = logP 20 | ) -> torch.Tensor: 21 | """ 22 | P2 (Path Planning) sampling implementation. 23 | 24 | This function implements the P2 sampling algorithm, a guided diffusion method 25 | for sequence generation. It starts with a fully or partially masked sequence and 26 | progressively unmasks tokens based on model confidence and a scheduling function. 27 | 28 | Algorithm: 29 | 30 | 1. Start with a masked sequence x_T 31 | 2. For each step t from T to 1: 32 | - Forward pass: compute logits = model(x_t) 33 | - Sample: x_0 ~ softmax(logits/τ) 34 | - Compute scores using score_fn 35 | - Modify scores for tokens that were previously unmasked (multiplied by η) 36 | - Compute κ_t = kappa_fn(t/T) to determine the fraction of tokens to keep unmasked 37 | - Mask the lowest-scoring tokens to maintain κ_t fraction unmasked 38 | - Replace masked tokens with sampled ones 39 | 3. Return the final sequence 40 | 41 | Args: 42 | xt: Input tensor with masked tokens, shape (batch_size, seq_len) 43 | model: Main model for generating logits. Should return logits when called with a sequence 44 | mask_id: ID of the mask token 45 | num_steps: Number of sampling steps 46 | tau: Temperature parameter for sampling. Higher values increase diversity 47 | kappa_fn: Function to compute kappa at each timestep t ∈ [0,1]. Determines unmasking schedule 48 | eta: stochasticity strength, higher values make the sampling more stochastic. 
49 | planner: Optional planner model for guided generation 50 | score_fn: Scoring function for token selection (e.g., logP, random_score, diff_top2) 51 | 52 | Returns: 53 | Sampled sequence tensor of shape (batch_size, seq_len) 54 | 55 | References: 56 | - P2 Sampling: https://arxiv.org/pdf/2502.03540 57 | """ 58 | dt = 1/num_steps 59 | fix_mask = xt != mask_id 60 | 61 | for i in tqdm(range(1, num_steps+1)): 62 | kappa_t = kappa_fn(i*dt) 63 | logits = model(xt).double() 64 | last_mask = xt == mask_id 65 | unmask_t = ~last_mask & ~fix_mask 66 | 67 | x0, logp, logits = stochastic_sample_from_categorical(logits, temperature=tau) 68 | if planner is not None: 69 | planner_logits = planner(x0).double() 70 | planner_logp = planner_logits.log_softmax(dim=-1).gather(-1, x0.unsqueeze(-1)).squeeze(-1) 71 | logits[unmask_t] = planner_logits[unmask_t] 72 | logp[unmask_t] = planner_logp[unmask_t] 73 | score = score_fn(logits, x0) 74 | score = score.masked_fill(fix_mask, float('inf')) 75 | 76 | score[unmask_t] = score[unmask_t] * eta 77 | 78 | num_to_mask = ((~fix_mask).sum(1, keepdim=True).float() * (1 - kappa_t)).long() 79 | lowest_k_mask = topk_lowest_masking(score, num_to_mask) 80 | to_mask = lowest_k_mask 81 | 82 | xt[to_mask] = mask_id 83 | mask_2_x0 = last_mask & ~lowest_k_mask 84 | xt[mask_2_x0] = x0[mask_2_x0] 85 | # Fill any remaining masks 86 | xt[xt == mask_id] = x0[xt == mask_id] 87 | return xt 88 | 89 | import torch 90 | from typing import Any 91 | 92 | import torch 93 | from typing import Any 94 | import math 95 | 96 | @torch.inference_mode() 97 | @torch.cuda.amp.autocast() 98 | def p2_plus_sampling( 99 | xt: torch.Tensor, 100 | model: Any, 101 | mask_id: int, 102 | num_steps: int | None = None, 103 | tau: float = 1.0, 104 | **kwargs: Any 105 | ) -> torch.Tensor: 106 | """ 107 | P2+ sampling with fixed decoding order. 108 | Assumes all samples in the batch have: 109 | - identical initial xt 110 | - the same number of masked tokens 111 | - the same decoding order 112 | 113 | Args: 114 | xt: Tensor of shape (B, L) with masked tokens. 115 | model: Callable that outputs logits from xt. 116 | mask_id: Token ID used for masking. 117 | num_steps: Number of decoding iterations. If None, decode one token per step. 118 | tau: Temperature for sampling. 119 | 120 | Returns: 121 | xt: Fully decoded tensor of shape (B, L). 122 | """ 123 | B, L = xt.shape 124 | device = xt.device 125 | 126 | # 1. Identify masked positions — shared across all samples 127 | mask_positions = (xt[0] == mask_id) # (L,) 128 | num_masks = mask_positions.sum().item() 129 | if num_masks == 0: 130 | return xt 131 | 132 | # 2. Validate or infer number of decoding steps 133 | if num_steps is None: 134 | num_steps = num_masks 135 | elif not (1 <= num_steps <= num_masks): 136 | raise ValueError( 137 | f"num_steps must be in [1, {num_masks}], got {num_steps}" 138 | ) 139 | 140 | # 3. Initial forward pass and decoding order 141 | logits = model(xt).double() # (B, L, V) 142 | x0, scores, _ = stochastic_sample_from_categorical(logits, temperature=tau) 143 | scores = scores[0] # (L,) 144 | masked_scores = torch.where(mask_positions, scores, torch.tensor(-float('inf'), device=device)) 145 | decoding_order = torch.argsort(masked_scores, descending=True) # (L,) 146 | masked_indices = decoding_order[:num_masks] # Only the masked ones, sorted 147 | 148 | # 4. 
Compute per-step split 149 | step_size = math.ceil(num_masks / num_steps) 150 | step_ranges = [ 151 | masked_indices[i * step_size : min((i + 1) * step_size, num_masks)] 152 | for i in range(num_steps) 153 | ] 154 | 155 | # 5. Iterative decoding (shared across batch) 156 | batch_idx = torch.arange(B, device=device).unsqueeze(1) # (B, 1) 157 | for pos_ids in step_ranges: 158 | if len(pos_ids) == 0: 159 | continue 160 | logits = model(xt).double() 161 | x0, _, _ = stochastic_sample_from_categorical(logits, temperature=tau) 162 | pos_ids = pos_ids.to(device) 163 | pos_ids_expand = pos_ids.unsqueeze(0).expand(B, -1) # (B, K) 164 | xt[batch_idx.expand_as(pos_ids_expand), pos_ids_expand] = x0[batch_idx.expand_as(pos_ids_expand), pos_ids_expand] 165 | 166 | return xt 167 | 168 | from functools import partial 169 | 170 | ancestral_sampling = partial( 171 | p2_sampling, 172 | planner=None, 173 | score_fn=random_score, 174 | eta=0 175 | ) 176 | ancestral_sampling.__doc__ = """ 177 | Ancestral sampling using the P2 framework. 178 | 179 | This is a specialized version of P2 sampling that uses random scores and no 180 | eta parameter, resulting in a pure diffusion sampling approach where token 181 | selection is random rather than based on model confidence. 182 | 183 | Args: 184 | Same as p2_sampling, except: 185 | - planner is fixed to None 186 | - score_fn is fixed to random_score 187 | - eta is fixed to 0 188 | 189 | Returns: 190 | Sampled sequence tensor of shape (batch_size, seq_len) 191 | """ 192 | 193 | 194 | greedy_ancestral_sampling = partial( 195 | p2_sampling, 196 | planner=None, 197 | score_fn=logP, 198 | eta=1, 199 | ) 200 | greedy_ancestral_sampling.__doc__ = """ 201 | Greedy ancestral sampling using the P2 framework. 202 | 203 | This variant uses log probabilities as scores but has no planner model. 204 | It selects tokens based on model confidence, making it more deterministic 205 | than pure ancestral sampling. 206 | 207 | Args: 208 | Same as p2_sampling, except: 209 | - planner is fixed to None 210 | - score_fn is fixed to logP 211 | - eta is fixed to 1 212 | 213 | Returns: 214 | Sampled sequence tensor of shape (batch_size, seq_len) 215 | """ 216 | 217 | 218 | dfm_sampling = partial( 219 | p2_sampling, 220 | planner=None, 221 | score_fn=random_score, 222 | ) 223 | dfm_sampling.__doc__ = """ 224 | Diffusion Masked Language Model (DFM) sampling. 225 | 226 | This is a variant of P2 sampling aligned with the DFM approach, 227 | using random scores but no planner model. 228 | 229 | Args: 230 | Same as p2_sampling, except: 231 | - planner is fixed to None 232 | - score_fn is fixed to random_score 233 | 234 | Returns: 235 | Sampled sequence tensor of shape (batch_size, seq_len) 236 | """ 237 | 238 | 239 | 240 | ancestral_sampling = partial( 241 | dfm_sampling, 242 | eta=0, 243 | ) 244 | -------------------------------------------------------------------------------- /src/path_planning/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Time-dependent schedulers for P2 sampling. 3 | 4 | The scheduler is a non-decreasing function that takes a time step t [0, 1] and returns a scalar [0, 1]. 5 | f(0) = 0, f(1) = 1 6 | 7 | These schedulers control the rate at which tokens are unmasked during P2 sampling. Different 8 | schedulers produce different dynamics in the sampling process, potentially affecting the quality 9 | of the generated sequences. 
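For example (an illustrative sketch; `xt`, `model_fn`, and `mask_id` come from the caller), a scheduler
is plugged into P2 sampling through the `kappa_fn` argument:

    from path_planning import p2_sampling, sine_scheduler

    out = p2_sampling(xt=xt, model=model_fn, mask_id=mask_id, num_steps=64, kappa_fn=sine_scheduler)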
10 | """ 11 | import math 12 | import torch 13 | import matplotlib.pyplot as plt 14 | from typing import List, Callable, Union, Optional 15 | 16 | 17 | def linear_scheduler(t: float) -> float: 18 | """ 19 | Linear scheduler that increases proportionally with time. 20 | 21 | This is the simplest scheduler, where the proportion of tokens unmasked 22 | increases linearly with time. 23 | 24 | Mathematical formulation: 25 | f(t) = t 26 | 27 | Args: 28 | t: Time step in [0, 1] 29 | 30 | Returns: 31 | Float in [0, 1] representing progress of the unmasking process 32 | """ 33 | return t 34 | 35 | 36 | def sine_scheduler(t: float) -> float: 37 | """ 38 | Sine scheduler that maps t from [0,1] to [0,1] with a smooth S-curve. 39 | 40 | This scheduler starts slow, accelerates in the middle, and then slows down 41 | at the end, following a sine curve. It provides a smooth transition between 42 | the fully masked and fully unmasked states. 43 | 44 | Mathematical formulation: 45 | f(t) = sin(t * π/2) 46 | 47 | Args: 48 | t: Time step in [0, 1] 49 | 50 | Returns: 51 | Float in [0, 1] representing progress of the unmasking process 52 | """ 53 | return math.sin(t * math.pi / 2) 54 | 55 | 56 | def geometric_scheduler(t: float) -> float: 57 | """ 58 | Geometric scheduler that maps t from [0,1] to [0,1] with accelerating progress. 59 | 60 | This scheduler starts slow and accelerates, following a quadratic curve. 61 | It tends to keep more tokens masked at the beginning compared to linear. 62 | 63 | Mathematical formulation: 64 | f(t) = 1-(1-t)² 65 | 66 | Args: 67 | t: Time step in [0, 1] 68 | 69 | Returns: 70 | Float in [0, 1] representing progress of the unmasking process 71 | """ 72 | return 1-(1-t)**2 73 | 74 | 75 | def log_scheduler(t: float) -> float: 76 | """ 77 | Logarithmic scheduler that maps t from [0,1] to [0,1]. 78 | 79 | This scheduler progresses quickly at the beginning and slows down toward the end. 80 | It tends to unmask more tokens early in the process. 81 | 82 | Mathematical formulation: 83 | f(t) = log(t+1)/log(2) 84 | 85 | Args: 86 | t: Time step in [0, 1] 87 | 88 | Returns: 89 | Float in [0, 1] representing progress of the unmasking process 90 | """ 91 | return math.log(t+1)/math.log(2) 92 | 93 | 94 | def poly2_scheduler(t: float) -> float: 95 | """ 96 | Polynomial (quadratic) scheduler that maps t from [0,1] to [0,1]. 97 | 98 | This scheduler starts slow and accelerates, following a quadratic curve. 99 | It's similar to the geometric scheduler but with different dynamics. 100 | 101 | Mathematical formulation: 102 | f(t) = t² 103 | 104 | Args: 105 | t: Time step in [0, 1] 106 | 107 | Returns: 108 | Float in [0, 1] representing progress of the unmasking process 109 | """ 110 | return t**2 111 | 112 | 113 | def poly05_scheduler(t: float) -> float: 114 | """ 115 | Polynomial (square root) scheduler that maps t from [0,1] to [0,1]. 116 | 117 | This scheduler starts fast and slows down, following a square root curve. 118 | It tends to unmask more tokens early in the process. 119 | 120 | Mathematical formulation: 121 | f(t) = t^0.5 122 | 123 | Args: 124 | t: Time step in [0, 1] 125 | 126 | Returns: 127 | Float in [0, 1] representing progress of the unmasking process 128 | """ 129 | return t**0.5 130 | 131 | 132 | if __name__ == "__main__": 133 | def plot_schedulers(t_max: int = 100) -> None: 134 | """ 135 | Plot all scheduler functions for comparison. 
136 | 137 | Args: 138 | t_max: Number of time steps to plot 139 | """ 140 | t = torch.linspace(0, 1, t_max) 141 | plt.plot(t, [linear_scheduler(float(i)) for i in t], label='linear') 142 | plt.plot(t, [sine_scheduler(float(i)) for i in t], label='sine') 143 | plt.plot(t, [geometric_scheduler(float(i)) for i in t], label='geometric') 144 | plt.plot(t, [poly2_scheduler(float(i)) for i in t], label='poly2') 145 | plt.plot(t, [poly05_scheduler(float(i)) for i in t], label='poly0.5') 146 | plt.plot(t, [log_scheduler(float(i)) for i in t], label='log') 147 | plt.legend() 148 | plt.savefig('schedulers.png') 149 | 150 | plot_schedulers() 151 | -------------------------------------------------------------------------------- /src/path_planning/score_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Callable, Tuple 3 | 4 | 5 | def logP(logits: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: 6 | """ 7 | Compute the log probabilities of predicted tokens from model logits. 8 | 9 | This function calculates the log probability of each token in x0 according to the model's 10 | predicted distribution. It's a common scoring function used in P2 sampling to determine 11 | the confidence of the model in each position. 12 | 13 | Mathematical formulation: 14 | logP(x_0) = log(p(x_0 | logits)) = log_softmax(logits)[x_0] 15 | 16 | Args: 17 | logits: Tensor of shape (batch_size, seq_len, vocab_size) - Raw logits from model 18 | x0: Tensor of shape (batch_size, seq_len) - Token indices to compute probabilities for 19 | 20 | Returns: 21 | scores: Tensor of shape (batch_size, seq_len) - Log probabilities for each token 22 | """ 23 | logits = logits.double() 24 | logits = logits.log_softmax(dim=-1) 25 | scores = logits.gather(dim=-1, index=x0.unsqueeze(-1)).squeeze(-1) 26 | return scores 27 | 28 | 29 | def random_score(logits: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: 30 | """ 31 | Return a random score for each token. 32 | 33 | This function generates random scores irrespective of the logits or token values. 34 | It's used for pure diffusion sampling where token selection is random rather than 35 | based on model confidence. 36 | 37 | Mathematical formulation: 38 | score(x) = log(rand(0, 1)) 39 | 40 | Args: 41 | logits: Tensor of shape (batch_size, seq_len, vocab_size) - Raw logits from model (unused) 42 | x0: Tensor of shape (batch_size, seq_len) - Token indices (used only for shape) 43 | 44 | Returns: 45 | Tensor of shape (batch_size, seq_len) - Random log scores for each position 46 | """ 47 | return torch.rand_like(x0.float()).log() 48 | 49 | 50 | def diff_top2(logits: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: 51 | """ 52 | Compute the difference between the top 2 probabilities for each position. 53 | 54 | This scoring function measures model confidence as the gap between the most likely 55 | and second most likely tokens at each position. A larger difference indicates higher 56 | confidence. 
57 | 58 | Mathematical formulation: 59 | score(x) = log_softmax(logits)[top_1] - log_softmax(logits)[top_2] 60 | 61 | Args: 62 | logits: Tensor of shape (batch_size, seq_len, vocab_size) - Raw logits from model 63 | x0: Tensor of shape (batch_size, seq_len) - Token indices (unused in this function) 64 | 65 | Returns: 66 | Tensor of shape (batch_size, seq_len) - Difference between top 2 log probabilities 67 | """ 68 | logits = logits.log_softmax(dim=-1) 69 | top2_logits = logits.topk(2, dim=-1).values 70 | diff = top2_logits[:, :, 0] - top2_logits[:, :, 1] 71 | return diff 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /src/path_planning/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from typing import Optional, Tuple, Union 5 | 6 | def topk_lowest_masking(scores: torch.Tensor, cutoff_len: torch.Tensor) -> torch.Tensor: 7 | """ 8 | Creates a mask identifying the k lowest-scoring positions in a tensor. 9 | 10 | This function selects positions with the lowest scores, where k is specified by cutoff_len. 11 | It's used in P2 sampling to determine which positions should be masked. 12 | 13 | Args: 14 | scores: Tensor of shape (batch_size, seq_len) containing scores for each position 15 | cutoff_len: Tensor of shape (batch_size, 1) specifying how many positions to select in each sequence 16 | 17 | Returns: 18 | Boolean mask tensor of shape (batch_size, seq_len) where True indicates positions with lowest scores 19 | """ 20 | sorted_index = scores.sort(-1)[0] 21 | cutoff = sorted_index.gather(dim=-1, index=cutoff_len) 22 | masking = scores < cutoff 23 | return masking 24 | 25 | def topk_highest_masking(scores: torch.Tensor, cutoff_len: torch.Tensor) -> torch.Tensor: 26 | """ 27 | Creates a mask identifying the k highest-scoring positions in a tensor. 28 | 29 | This function selects positions with the highest scores, where k is specified by cutoff_len. 30 | It's the opposite of topk_lowest_masking and can be used when you want to select the most confident tokens. 31 | 32 | Args: 33 | scores: Tensor of shape (batch_size, seq_len) containing scores for each position 34 | cutoff_len: Tensor of shape (batch_size, 1) specifying how many positions to select in each sequence 35 | 36 | Returns: 37 | Boolean mask tensor of shape (batch_size, seq_len) where True indicates positions with highest scores 38 | """ 39 | sorted_index = scores.sort(-1, descending=True)[0] 40 | cutoff = sorted_index.gather(dim=-1, index=cutoff_len) 41 | masking = scores >= cutoff 42 | return masking 43 | 44 | def seed_everything(seed: Optional[int] = None) -> None: 45 | """ 46 | Set the seed for reproducibility across various libraries. 47 | 48 | This function sets random seeds for Python's random module, NumPy, and PyTorch 49 | to ensure reproducible results across runs. 50 | 51 | Args: 52 | seed: Integer seed value. If None, no seeding is performed. 
53 | """ 54 | if seed is None: 55 | return 56 | random.seed(seed) 57 | np.random.seed(seed) 58 | torch.manual_seed(seed) 59 | if torch.cuda.is_available(): 60 | torch.cuda.manual_seed(seed) 61 | torch.cuda.manual_seed_all(seed) # if using multi-GPU 62 | torch.backends.cudnn.deterministic = True 63 | torch.backends.cudnn.benchmark = False 64 | 65 | def stochastic_sample_from_categorical( 66 | logits: torch.Tensor, 67 | temperature: float = 1.0, 68 | noise_scale: float = 1.0 69 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 70 | """ 71 | Sample from a categorical distribution with temperature scaling and Gumbel noise. 72 | 73 | This function implements stochastic sampling from logits with temperature control. 74 | When temperature > 0, Gumbel noise is added to introduce randomness in sampling. 75 | 76 | Mathematical formulation: 77 | 1. Convert logits to probabilities: p = softmax(logits/temperature + noise_scale * gumbel_noise) 78 | 2. Sample from the resulting distribution 79 | 80 | Args: 81 | logits: Tensor of shape (batch_size, seq_len, vocab_size) containing unnormalized log probabilities 82 | temperature: Temperature parameter controlling randomness (higher = more random) 83 | noise_scale: Scale factor for the Gumbel noise 84 | 85 | Returns: 86 | A tuple containing: 87 | - tokens: Tensor of shape (batch_size, seq_len) containing sampled token indices 88 | - scores: Tensor of shape (batch_size, seq_len) containing log probabilities of selected tokens 89 | - modified_logits: Tensor of shape (batch_size, seq_len, vocab_size) containing temperature-scaled logits 90 | """ 91 | dtype = logits.dtype 92 | logits = logits.to(torch.float64) 93 | if temperature != 0: 94 | gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits, dtype=torch.float64) + 1e-8) + 1e-8) 95 | logits = logits / temperature + noise_scale * gumbel_noise 96 | scores, tokens = logits.log_softmax(dim=-1).max(dim=-1) 97 | return tokens, scores.to(dtype), logits.to(dtype) 98 | 99 | --------------------------------------------------------------------------------
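As a quick usage sketch of the Gumbel-max sampler above (tensor values are purely illustrative): at temperature 0 it reduces to a plain argmax over the vocabulary, while higher temperatures add Gumbel noise and make the draw increasingly random.

    import torch
    from path_planning.utils import seed_everything, stochastic_sample_from_categorical

    seed_everything(0)
    logits = torch.tensor([[[2.0, 0.5, -1.0],
                            [0.1, 0.2, 3.0]]])  # (batch=1, seq_len=2, vocab=3)

    greedy, _, _ = stochastic_sample_from_categorical(logits, temperature=0.0)
    noisy, _, _ = stochastic_sample_from_categorical(logits, temperature=1.0)
    print(greedy)  # tensor([[0, 2]]) -- the argmax tokens
    print(noisy)   # may deviate from the argmax because of the Gumbel noise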