├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── asset ├── .DS_Store ├── SANA.png ├── hku.png ├── kvcache.jpg ├── llada.gif ├── mit_han.png ├── output.gif ├── overall_performance.jpg ├── parallel.gif ├── pseudo_code.jpg ├── sana.jpg └── speedup.jpg ├── dream ├── demo_multiturn_chat.py ├── eval.md ├── eval.py ├── eval_gsm8k.sh ├── eval_humaneval.sh ├── model │ ├── __init__.py │ ├── configuration_dream.py │ ├── generation_utils.py │ ├── generation_utils_block.py │ ├── modeling_dream.py │ └── tokenization_dream.py ├── postprocess_code.py └── sanitize.py ├── index.html ├── llada ├── chat.py ├── eval.md ├── eval_gsm8k.sh ├── eval_humaneval.sh ├── eval_llada.py ├── generate.py ├── model │ ├── __init__.py │ ├── configuration_llada.py │ └── modeling_llada.py ├── postprocess_code.py └── sanitize.py ├── paper └── fast_dllm.pdf └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | llada/output/ 6 | llada/evals_results/ 7 | dream/evals_results/ 8 | .DS_Store 9 | **/.DS_Store -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guide 2 | 3 | Thank you for your interest in the Fast-DLLM project! We welcome all forms of contributions, including but not limited to: 4 | 5 | - Bug reports 6 | - Feature suggestions 7 | - Documentation improvements 8 | - Code fixes 9 | - New features 10 | 11 | ## Development Process 12 | 13 | 1. Fork the repository 14 | 2. Create your feature branch (`git checkout -b feature/AmazingFeature`) 15 | 3. Commit your changes (`git commit -m 'Add some AmazingFeature'`) 16 | 4. Push to the branch (`git push origin feature/AmazingFeature`) 17 | 5. Open a Pull Request 18 | 19 | ## Developer Certificate of Origin 20 | 21 | Version 1.1 22 | 23 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 24 | 25 | Everyone is permitted to copy and distribute verbatim copies of this 26 | license document, but changing it is not allowed. 27 | 28 | ### Developer's Certificate of Origin 1.1 29 | 30 | By making a contribution to this project, I certify that: 31 | 32 | (a) The contribution was created in whole or in part by me and I 33 | have the right to submit it under the open source license 34 | indicated in the file; or 35 | 36 | (b) The contribution is based upon previous work that, to the best 37 | of my knowledge, is covered under an appropriate open source 38 | license and I have the right under that license to submit that 39 | work with modifications, whether created in whole or in part 40 | by me, under the same open source license (unless I am 41 | permitted to submit under a different license), as indicated 42 | in the file; or 43 | 44 | (c) The contribution was provided directly to me by some other 45 | person who certified (a), (b) or (c) and I have not modified 46 | it. 47 | 48 | (d) I understand and agree that this project and the contribution 49 | are public and that a record of the contribution (including all 50 | personal information I submit with it, including my sign-off) is 51 | maintained indefinitely and may be redistributed consistent with 52 | this project or the open source license(s) involved. 53 | 54 | ## Code Style 55 | 56 | Please ensure your code follows the project's code style guidelines. 
We use the following tools to maintain code quality: 57 | 58 | - Code formatting tools 59 | - Code linting tools 60 | - Unit tests 61 | 62 | ## Submitting Pull Requests 63 | 64 | Before submitting a Pull Request, please ensure: 65 | 66 | 1. Your code passes all tests 67 | 2. You have updated relevant documentation 68 | 3. Your commit messages are clear and descriptive 69 | 4. Your code follows the project's code style guidelines 70 | 71 | ## Issue Reporting 72 | 73 | If you find any issues or have suggestions, please submit them through GitHub Issues. Before submitting an issue, please ensure: 74 | 75 | 1. The issue hasn't been reported already 76 | 2. You have provided sufficient information to reproduce the issue 77 | 3. You have attempted to resolve the issue yourself 78 | 79 | Thank you for contributing! -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 Nvidia 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast-DLLM 2 | 3 | Fast-DLLM is a diffusion-based Large Language Model (LLM) inference acceleration framework that supports efficient inference for models like Dream and LLaDA. 4 | 5 | ## Project Structure 6 | 7 | ``` 8 | . 
9 | ├── dream/ # Dream model related code 10 | ├── llada/ # LLaDA model related code 11 | └── .gitignore # Git ignore configuration 12 | ``` 13 | 14 | ## Features 15 | 16 | - Fast inference support for Dream and LLaDA models 17 | - Multiple inference optimization strategies 18 | - Code generation and evaluation capabilities 19 | - Interactive chat interface 20 | 21 | ## Installation 22 | 23 | 1. Clone the repository: 24 | ```bash 25 | git clone https://github.com/your-username/fast-dllm.git 26 | cd fast-dllm 27 | ``` 28 | 29 | 2. Install dependencies: 30 | ```bash 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | ## Usage 35 | 36 | ### 1. Using LLaDA Model 37 | 38 | #### Interactive Chat 39 | ```bash 40 | python llada/chat.py --gen_length 128 --steps 128 --block_size 32 41 | ``` 42 | 43 | Parameter descriptions: 44 | - `--gen_length`: Maximum length of generated text 45 | - `--steps`: Number of sampling steps 46 | - `--block_size`: Cache block size 47 | - `--use_cache`: Whether to use cache 48 | - `--if_cache_position`: Whether to use dual cache 49 | - `--threshold`: Confidence threshold 50 | 51 | #### Model Evaluation 52 | For detailed evaluation instructions on GSM8K and HumanEval benchmarks, please refer to [LLaDA Evaluation Guide](llada/eval.md). 53 | 54 | ### 2. Using Dream Model 55 | 56 | For detailed evaluation instructions on GSM8K and HumanEval benchmarks, please refer to [Dream Evaluation Guide](dream/eval.md). 57 | 58 | ## Contributing 59 | 60 | Issues and Pull Requests are welcome! 61 | 62 | ## License 63 | 64 | This project is licensed under the Apache License 2.0. See the [LICENSE](LICENSE) file for details. -------------------------------------------------------------------------------- /asset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/.DS_Store -------------------------------------------------------------------------------- /asset/SANA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/SANA.png -------------------------------------------------------------------------------- /asset/hku.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/hku.png -------------------------------------------------------------------------------- /asset/kvcache.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/kvcache.jpg -------------------------------------------------------------------------------- /asset/llada.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/llada.gif -------------------------------------------------------------------------------- /asset/mit_han.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/mit_han.png -------------------------------------------------------------------------------- /asset/output.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/output.gif -------------------------------------------------------------------------------- /asset/overall_performance.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/overall_performance.jpg -------------------------------------------------------------------------------- /asset/parallel.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/parallel.gif -------------------------------------------------------------------------------- /asset/pseudo_code.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/pseudo_code.jpg -------------------------------------------------------------------------------- /asset/sana.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/sana.jpg -------------------------------------------------------------------------------- /asset/speedup.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/asset/speedup.jpg -------------------------------------------------------------------------------- /dream/demo_multiturn_chat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 17 | 18 | import torch 19 | from transformers import AutoModel, AutoTokenizer 20 | import time 21 | from model.modeling_dream import DreamModel 22 | 23 | import types 24 | # Load model and tokenizer 25 | 26 | # Read use_cache from user input 27 | use_cache = True if input("Use cache? 
(y/n): ").lower() == 'y' else False 28 | 29 | if use_cache: 30 | model_path = "Dream-org/Dream-v0-Instruct-7B" 31 | model = DreamModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True) 32 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 33 | model = model.to("cuda").eval() 34 | 35 | from model.generation_utils_block import DreamGenerationMixin 36 | model.diffusion_generate = types.MethodType(DreamGenerationMixin.diffusion_generate, model) 37 | model._sample = types.MethodType(DreamGenerationMixin._sample, model) 38 | else: 39 | model_path = "Dream-org/Dream-v0-Instruct-7B" 40 | model = DreamModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True) 41 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 42 | model = model.to("cuda").eval() 43 | 44 | 45 | # Initialize conversation history 46 | messages = [] 47 | 48 | print("Multi-turn conversation with Dream-v0-Instruct-7B") 49 | print("Type 'exit' to end the conversation") 50 | print("----------------------------------------------") 51 | 52 | while True: 53 | # Get user input 54 | user_input = input("You: ") 55 | 56 | # Check if user wants to exit 57 | if user_input.lower() == 'exit': 58 | print("Conversation ended.") 59 | break 60 | 61 | # Add user message to conversation history 62 | messages.append({"role": "user", "content": user_input}) 63 | 64 | # Format input with chat template 65 | inputs = tokenizer.apply_chat_template( 66 | messages, return_tensors="pt", return_dict=True, add_generation_prompt=True 67 | ) 68 | input_ids = inputs.input_ids.to(device="cuda") 69 | attention_mask = inputs.attention_mask.to(device="cuda") 70 | 71 | def generation_tokens_hook_func(step, x, logits): 72 | print(f"############ Step {step} ############") 73 | # print(tokenizer.decode(h[0].tolist())) 74 | print(tokenizer.decode(x[0].tolist()).split(tokenizer.eos_token)[0].replace(tokenizer.mask_token, " "), end="\r") 75 | time.sleep(0.01) 76 | return x 77 | 78 | # Generate response 79 | output = model.diffusion_generate( 80 | input_ids, 81 | attention_mask=attention_mask, 82 | max_new_tokens=128, 83 | output_history=True, 84 | return_dict_in_generate=True, 85 | steps=128, 86 | temperature=0., 87 | top_p=None, 88 | alg="entropy", 89 | alg_temp=0.1, 90 | top_k=None, 91 | block_length=32, 92 | # generation_tokens_hook_func=generation_tokens_hook_func 93 | ) 94 | 95 | # Process response 96 | generation = tokenizer.decode(output.sequences[0][len(input_ids[0]):].tolist()) 97 | generation = generation.split(tokenizer.eos_token)[0].strip() 98 | 99 | # Print response 100 | print("Model:", generation) 101 | 102 | # Add model response to conversation history 103 | messages.append({"role": "assistant", "content": generation}) 104 | 105 | 106 | '''An example conversation (maybe different due to randomness) 107 | <|im_start|>system 108 | You are a helpful assistant.<|im_end|> 109 | <|im_start|>user 110 | Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?<|im_end|> 111 | <|im_start|>assistant 112 | Janet sells 16 - 3 - 4 = 9 eggs per day. 
113 | She makes 9 * $2 = $18 per day.<|im_end|> 114 | <|im_start|>user 115 | what if her duck lay three more eggs<|im_end|> 116 | <|im_start|>assistant 117 | If Janet's ducks lay three more eggs per day, she would have 16 + 3 = 19 eggs per day.<|im_end|> 118 | <|im_start|>user 119 | yes, so how many dollars she make<|im_end|> 120 | <|im_start|>assistant 121 | Janet sells 19 - 3 - 4 = 12 eggs per day. 122 | She makes 12 * $2 = $24 per day. 123 | ''' -------------------------------------------------------------------------------- /dream/eval.md: -------------------------------------------------------------------------------- 1 | # Dream Model Evaluation Guide 2 | 3 | This document provides detailed instructions for evaluating the Dream model on GSM8K math problem solving and HumanEval code generation tasks. 4 | 5 | ## Environment Setup 6 | 7 | Before running any evaluation, set the following environment variables: 8 | ```bash 9 | export HF_ALLOW_CODE_EVAL=1 10 | export HF_DATASETS_TRUST_REMOTE_CODE=true 11 | ``` 12 | 13 | ## GSM8K Evaluation 14 | 15 | GSM8K is a dataset of 8,000 grade school math problems designed to evaluate mathematical reasoning capabilities. 16 | 17 | ### Common Parameters 18 | 19 | ```bash 20 | task=gsm8k 21 | length=256 22 | block_length=32 23 | num_fewshot=5 24 | steps=$((length / block_length)) 25 | model="Dream-org/Dream-v0-Base-7B" 26 | ``` 27 | 28 | ### Evaluation Methods 29 | 30 | 1. **Baseline** 31 | ```bash 32 | accelerate launch eval.py --model dream \ 33 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${length},add_bos_token=true,alg=entropy,show_speed=True \ 34 | --tasks ${task} \ 35 | --num_fewshot ${num_fewshot} \ 36 | --batch_size 1 37 | ``` 38 | 39 | 2. **Prefix Cache** 40 | ```bash 41 | accelerate launch eval.py --model dream \ 42 | --model_args pretrained=${model},max_new_tokens=256,diffusion_steps=256,add_bos_token=true,alg=entropy,use_cache=true,show_speed=True \ 43 | --tasks ${task} \ 44 | --num_fewshot ${num_fewshot} \ 45 | --batch_size 1 46 | ``` 47 | 48 | 3. **Parallel Generation** 49 | ```bash 50 | accelerate launch eval.py --model dream \ 51 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,show_speed=True \ 52 | --tasks ${task} \ 53 | --num_fewshot ${num_fewshot} \ 54 | --batch_size 1 55 | ``` 56 | 57 | 4. **Prefix Cache + Parallel** 58 | ```bash 59 | accelerate launch eval.py --model dream \ 60 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,use_cache=true \ 61 | --tasks ${task} \ 62 | --num_fewshot ${num_fewshot} \ 63 | --batch_size 1 64 | ``` 65 | 66 | 5. 
**Dual Cache + Parallel** 67 | ```bash 68 | accelerate launch eval.py --model dream \ 69 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,use_cache=true,dual_cache=true \ 70 | --tasks ${task} \ 71 | --num_fewshot ${num_fewshot} \ 72 | --batch_size 1 73 | ``` 74 | 75 | ### Parameter Descriptions 76 | 77 | - `task`: Evaluation task (gsm8k) 78 | - `length`: Generation length 79 | - `block_length`: Block size for parallel generation 80 | - `num_fewshot`: Number of few-shot examples 81 | - `steps`: Number of generation steps 82 | - `model`: Model name (Dream-v0-Base-7B) 83 | - `use_cache`: Enable prefix cache 84 | - `dual_cache`: Enable dual cache 85 | - `threshold`: Confidence threshold for parallel generation 86 | - `show_speed`: Display speed metrics 87 | - `alg`: Generation algorithm (entropy or confidence_threshold) 88 | 89 | ## HumanEval Evaluation 90 | 91 | HumanEval is a dataset of 164 Python programming problems designed to evaluate code generation capabilities. 92 | 93 | ### Common Parameters 94 | 95 | ```bash 96 | task=humaneval 97 | length=256 98 | block_length=32 99 | steps=$((length / block_length)) 100 | model="Dream-org/Dream-v0-Base-7B" 101 | ``` 102 | 103 | ### Evaluation Methods 104 | 105 | 1. **Baseline** 106 | ```bash 107 | accelerate launch eval.py --model dream \ 108 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${length},add_bos_token=true,alg=entropy,show_speed=True,escape_until=true \ 109 | --tasks ${task} \ 110 | --batch_size 1 \ 111 | --output_path evals_results/baseline/humaneval-ns0-${length} --log_samples \ 112 | --confirm_run_unsafe_code 113 | ``` 114 | 115 | 2. **Prefix Cache** 116 | ```bash 117 | accelerate launch eval.py --model dream \ 118 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${length},add_bos_token=true,alg=entropy,use_cache=true,show_speed=True,escape_until=true \ 119 | --tasks ${task} \ 120 | --batch_size 1 \ 121 | --output_path evals_results/cache/humaneval-ns0-${length} --log_samples \ 122 | --confirm_run_unsafe_code 123 | ``` 124 | 125 | 3. **Parallel Generation** 126 | ```bash 127 | accelerate launch eval.py --model dream \ 128 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,show_speed=True,escape_until=true \ 129 | --tasks ${task} \ 130 | --batch_size 1 \ 131 | --output_path evals_results/parallel/humaneval-ns0-${length} --log_samples \ 132 | --confirm_run_unsafe_code 133 | ``` 134 | 135 | 4. **Prefix Cache + Parallel** 136 | ```bash 137 | accelerate launch eval.py --model dream \ 138 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,use_cache=true,escape_until=true \ 139 | --tasks ${task} \ 140 | --batch_size 1 \ 141 | --output_path evals_results/cache_parallel/humaneval-ns0-${length} --log_samples \ 142 | --confirm_run_unsafe_code 143 | ``` 144 | 145 | 5. 
**Dual Cache + Parallel** 146 | ```bash 147 | accelerate launch eval.py --model dream \ 148 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,use_cache=true,dual_cache=true,escape_until=true \ 149 | --tasks ${task} \ 150 | --batch_size 1 \ 151 | --output_path evals_results/dual_cache_parallel/humaneval-ns0-${length} --log_samples \ 152 | --confirm_run_unsafe_code 153 | ``` 154 | 155 | ### Additional Parameters for HumanEval 156 | 157 | - `escape_until`: Skip truncating generations at the task's stop (`until`) strings; for code generation the raw samples are cleaned up in post-processing instead 158 | - `confirm_run_unsafe_code`: Confirm running unsafe code for evaluation 159 | - `log_samples`: Log generated samples for analysis 160 | 161 | ### Post-processing 162 | 163 | For HumanEval evaluation, post-processing is required: 164 | ```bash 165 | python postprocess_code.py {the samples_xxx.jsonl file under output_path} 166 | ``` 167 | 168 | ## Notes 169 | 170 | 1. All evaluations use the Dream-v0-Base-7B model 171 | 2. Results are saved in the `evals_results` directory 172 | 3. For HumanEval, samples are logged for postprocessing 173 | 4. Speed metrics are shown for all evaluations 174 | 5. Different optimization strategies can be combined (e.g., prefix cache + parallel decoding, or dual cache + parallel decoding) 175 | 6. HumanEval evaluation requires additional safety confirmations -------------------------------------------------------------------------------- /dream/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 17 | 18 | import logging 19 | import gc 20 | from datetime import timedelta 21 | from typing import List, Optional, Tuple, Type, TypeVar, Union 22 | import torch 23 | import torch.nn.functional as F 24 | import transformers 25 | from accelerate import ( 26 | Accelerator, 27 | InitProcessGroupKwargs, 28 | ) 29 | from datasets import Dataset 30 | from packaging import version 31 | from tqdm import tqdm 32 | 33 | from lm_eval import utils 34 | from lm_eval.api.instance import Instance 35 | from lm_eval.api.model import LM 36 | from lm_eval.api.registry import register_model 37 | from lm_eval.models.utils import get_dtype 38 | from lm_eval.__main__ import cli_evaluate 39 | from model.generation_utils_block import DreamGenerationMixin 40 | import types 41 | from model.configuration_dream import DreamConfig 42 | from model.modeling_dream import DreamModel 43 | import time 44 | import os 45 | import json 46 | 47 | eval_logger = logging.getLogger(__name__) 48 | T = TypeVar("T", bound="LM") 49 | 50 | @register_model("dream") 51 | class Dream(LM): 52 | def __init__( 53 | self, 54 | pretrained: Union[str, transformers.PreTrainedModel], 55 | batch_size: Optional[Union[int, str]] = 1, 56 | device: Optional[str] = "cuda", 57 | dtype: Optional[Union[str, torch.dtype]] = "auto", 58 | max_new_tokens: Optional[int] = 128, 59 | max_length: Optional[int] = 2048, 60 | add_bos_token: Optional[bool] = False, 61 | nll_type: Optional[str] = "mc", 62 | log_type: Optional[str] = "ftb", 63 | mc_num: Optional[int] = 128, 64 | classifier_free_guidance: Optional[float] = 1.0, 65 | sampling_eps: Optional[float] = 1e-3, 66 | diffusion_steps: Optional[int] = 128, 67 | trust_remote_code: Optional[bool] = True, 68 | parallelize: Optional[bool] = False, 69 | autogptq: Optional[Union[bool, str]] = False, 70 | temperature: Optional[float] = 0.0, 71 | top_p: Optional[float] = None, 72 | top_k: Optional[float] = None, 73 | alg: Optional[str] = "entropy", 74 | alg_temp: Optional[float] = 0.0, 75 | escape_until: Optional[bool] = False, 76 | threshold: Optional[float] = 0.9, 77 | apply_chat_template: Optional[bool] = False, 78 | use_cache: Optional[bool] = False, 79 | dual_cache: Optional[bool] = False, 80 | save_dir: Optional[str] = None, 81 | **kwargs, 82 | ) -> None: 83 | super().__init__() 84 | 85 | # prepare for parallelism 86 | assert isinstance(device, str) 87 | assert isinstance(pretrained, str) 88 | assert isinstance(batch_size, (int, str)) 89 | 90 | gpus = torch.cuda.device_count() 91 | accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) 92 | accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) 93 | if accelerator.num_processes > 1: 94 | self.accelerator = accelerator 95 | 96 | if "npu" in accelerator.device.type: 97 | gpus = torch.npu.device_count() 98 | 99 | # using one process with no model parallelism 100 | if not (parallelize or accelerator.num_processes > 1): 101 | # use user-passed device 102 | device_list = set( 103 | ["cuda", "cpu"] 104 | + [f"cuda:{i}" for i in range(gpus)] 105 | + ["mps", "mps:0"] 106 | + [f"npu:{i}" for i in range(gpus)] 107 | ) 108 | if device and device in device_list: 109 | self._device = torch.device(device) 110 | eval_logger.info(f"Using device '{device}'") 111 | if device in ("mps", "mps:0") and version.parse( 112 | torch.__version__ 113 | ) < version.parse("2.1"): 114 | raise RuntimeError( 115 | f"mps requires torch >= 2.1. 
You have {torch.__version__}" 116 | ) 117 | else: 118 | eval_logger.info("Device not specified") 119 | eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") 120 | self._device = ( 121 | torch.device("cuda") 122 | if torch.cuda.is_available() 123 | else torch.device("cpu") 124 | ) 125 | else: # Parallelism managed by accelerate 126 | if device != "cuda": 127 | eval_logger.info( 128 | f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." 129 | ) 130 | # TODO: include in warning that `load_in_8bit` etc. affect this too 131 | self._device = ( 132 | self.accelerator.device 133 | if hasattr(self, "accelerator") 134 | else torch.device(device) 135 | ) 136 | 137 | self.batch_size_per_gpu = batch_size 138 | if isinstance(batch_size, str): 139 | self.batch_size_per_gpu = int(batch_size) 140 | self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code) 141 | 142 | if isinstance(pretrained, str): 143 | if gpus >= 1 or str(self.device) == "mps": 144 | # TODO: can remove this whole snippet except in the mps case, perhaps? 145 | if not (parallelize or autogptq or hasattr(self, "accelerator")): 146 | # place model onto device requested manually, 147 | # if not using HF Accelerate or device_map 148 | # or any other option that preloads model onto device 149 | try: 150 | self.model.to(self.device) 151 | except ValueError: 152 | eval_logger.debug( 153 | "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore." 154 | ) 155 | # multigpu data-parallel support when launched with accelerate 156 | if gpus > 1: 157 | if accelerator.num_processes > 1: 158 | if parallelize: 159 | eval_logger.warning( 160 | "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available." 161 | ) 162 | elif gpus > accelerator.num_processes: 163 | eval_logger.warning( 164 | "WARNING: The number of total system GPUs does not match the number of spawned processes. " 165 | "If you would like to use data parallelism, please launch the script " 166 | "with 'accelerate launch *script*'. " 167 | f"Current run will proceed with {accelerator.num_processes} devices." 168 | ) 169 | if self.accelerator.is_local_main_process: 170 | eval_logger.info( 171 | f"Using {gpus} devices with data parallelism" 172 | ) 173 | 174 | self._device = torch.device(f"{accelerator.device}") 175 | self.accelerator = accelerator 176 | 177 | self._rank = self.accelerator.local_process_index 178 | self._world_size = self.accelerator.num_processes 179 | else: 180 | # if we aren't launching via accelerate, ditch 181 | self._rank = 0 182 | self._world_size = 1 183 | else: 184 | # if a PreTrainedModel was passed into HFLM, we forgo distributed setup. 
185 | eval_logger.warning( 186 | "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration" 187 | ) 188 | self._rank = 0 189 | self._world_size = 1 190 | 191 | self.max_length = max_length 192 | self.add_bos_token = add_bos_token 193 | # generation params 194 | self.max_new_tokens = max_new_tokens 195 | self.diffusion_steps = diffusion_steps 196 | self.temperature = temperature 197 | self.top_p = top_p 198 | self.top_k = top_k 199 | self.alg = alg 200 | self.alg_temp = alg_temp 201 | self.escape_until = escape_until 202 | self.threshold = threshold 203 | # loglikelihood params 204 | self.nll_type = nll_type 205 | self.log_type = log_type 206 | self.mc_num = mc_num 207 | self.classifier_free_guidance = classifier_free_guidance 208 | self.sampling_eps = sampling_eps 209 | self.if_apply_chat_template = apply_chat_template 210 | self.use_cache = use_cache 211 | self.dual_cache = dual_cache 212 | self.generated_token_num = 0 213 | self.save_dir = save_dir 214 | @property 215 | def batch_size(self): 216 | return self.batch_size_per_gpu 217 | 218 | @property 219 | def device(self): 220 | return self._device 221 | 222 | @property 223 | def rank(self): 224 | return self._rank 225 | 226 | @property 227 | def world_size(self): 228 | return self._world_size 229 | 230 | def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code): 231 | self.model = ( 232 | DreamModel.from_pretrained( 233 | pretrained, 234 | torch_dtype=get_dtype(dtype), 235 | trust_remote_code=trust_remote_code, 236 | ) 237 | .eval() 238 | ).to(self.device) 239 | self.model.diffusion_generate = types.MethodType(DreamGenerationMixin.diffusion_generate, self.model) 240 | self.model._sample = types.MethodType(DreamGenerationMixin._sample, self.model) 241 | 242 | self.tokenizer = transformers.AutoTokenizer.from_pretrained( 243 | pretrained, trust_remote_code=trust_remote_code 244 | ) 245 | 246 | def tok_decode(self, tokens, skip_special_tokens=True): 247 | return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) 248 | 249 | def tok_encode(self, text, add_special_tokens=True): 250 | return self.tokenizer( 251 | text, return_tensors="pt", add_special_tokens=add_special_tokens 252 | ).input_ids 253 | @classmethod 254 | def create_from_arg_string( 255 | cls: Type[T], arg_string: str, additional_config: Optional[dict] = None 256 | ) -> T: 257 | """ 258 | Creates an instance of the LM class using the given argument string and additional config. 259 | 260 | Parameters: 261 | - arg_string: A string containing arguments in the format key1=value1,key2=value2. 262 | - additional_config: Optional dictionary containing additional configuration parameters. 263 | 264 | Returns: 265 | - Instance of the LM class. 266 | """ 267 | additional_config = {} if additional_config is None else additional_config 268 | args = utils.simple_parse_args_string(arg_string) 269 | args2 = {k: v for k, v in additional_config.items() if v is not None} 270 | return cls(**args, **args2) 271 | 272 | def apply_chat_template( 273 | self, chat_history, add_generation_prompt: bool = True 274 | ) -> str: 275 | """ 276 | Method to apply a chat template to a list of chat history between user and model. 
277 | """ 278 | chat_templated = self.tokenizer.apply_chat_template( 279 | chat_history, 280 | tokenize=False, 281 | add_generation_prompt=add_generation_prompt, 282 | continue_final_message=not add_generation_prompt, 283 | ) 284 | 285 | return chat_templated 286 | 287 | @property 288 | def tokenizer_name(self) -> str: 289 | return self.tokenizer.name_or_path.replace("/", "__") 290 | 291 | def _generate_batch(self, prompts: List[str]) -> List[str]: 292 | if self.if_apply_chat_template: 293 | messages = [{"role": "user", "content": prompts[0]}] 294 | prompts = [self.apply_chat_template(messages)] 295 | else: 296 | if self.add_bos_token: 297 | prompts = [self.tokenizer.bos_token + p for p in prompts] 298 | # tokenize 299 | prompt_ids = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").input_ids 300 | if prompt_ids.shape[1] > self.max_length-self.max_new_tokens: 301 | eval_logger.warning(f"Prompt length {prompt_ids.shape[1]} is larger than {self.max_length-self.max_new_tokens}, cutoff on the left side") 302 | prompt_ids = prompt_ids[:, -(self.max_length-self.max_new_tokens):] 303 | 304 | attn_mask = prompt_ids.ne(self.tokenizer.pad_token_id) 305 | prompt_ids = prompt_ids.to(device=self.device) 306 | attn_mask = attn_mask.to(device=self.device) 307 | 308 | generation_ids = self.model.diffusion_generate( 309 | prompt_ids, 310 | attention_mask=attn_mask, 311 | max_new_tokens=self.max_new_tokens, 312 | output_history=False, 313 | return_dict_in_generate=True, 314 | steps=self.diffusion_steps, 315 | temperature=self.temperature, 316 | top_p=self.top_p, 317 | top_k=self.top_k, 318 | alg=self.alg, 319 | alg_temp=self.alg_temp, 320 | threshold=self.threshold, 321 | dual_cache=self.dual_cache, 322 | ) 323 | 324 | # decode 325 | self.generated_token_num += (generation_ids.sequences[0][prompt_ids.shape[1] :] != self.tokenizer.eos_token_id).sum().item() 326 | print(f"generated_token_num: {self.generated_token_num}") 327 | responses = [ 328 | self.tokenizer.decode(g[len(p) :].tolist()).split(self.tokenizer.eos_token)[0] 329 | for p, g in zip(prompt_ids, generation_ids.sequences) 330 | ] 331 | print('=' * 20) 332 | print('question: ', prompts[0]) 333 | print('answer: ', responses[0]) 334 | print('=' * 20, end='\n\n') 335 | return responses 336 | 337 | def generate_until(self, requests: List[Instance], disable_tqdm: bool = False): 338 | res = [] 339 | if self.use_cache: 340 | from model.generation_utils_block import DreamGenerationMixin 341 | self.model.diffusion_generate = types.MethodType(DreamGenerationMixin.diffusion_generate, self.model) 342 | self.model._sample = types.MethodType(DreamGenerationMixin._sample, self.model) 343 | else: 344 | from model.generation_utils import DreamGenerationMixin 345 | self.model.diffusion_generate = types.MethodType(DreamGenerationMixin.diffusion_generate, self.model) 346 | self.model._sample = types.MethodType(DreamGenerationMixin._sample, self.model) 347 | 348 | processed_count = 0 349 | if self.save_dir is not None: 350 | os.makedirs(self.save_dir, exist_ok=True) 351 | rank = self.rank 352 | save_path = os.path.join(self.save_dir, f'rank_{rank}.jsonl') 353 | print(f"save_path: {save_path}") 354 | if os.path.exists(save_path): 355 | print(f"load from {save_path}") 356 | with open(save_path, 'r', encoding='utf-8') as f: 357 | res = [json.loads(line) for line in f] 358 | processed_count = len(res) 359 | print(f"processed_count: {processed_count}") 360 | 361 | pbar = tqdm( 362 | total=len(requests), 363 | # disable=(disable_tqdm or (self.rank != 0)), 
desc="Running generate_until requests", 365 | ) 366 | start_time = time.time() 367 | for batch_idx in range(0, len(requests), self.batch_size): 368 | batch_requests = requests[batch_idx : batch_idx + self.batch_size] 369 | contexts, gen_args = zip(*[req.arguments for req in batch_requests]) 370 | 371 | if batch_idx < processed_count: 372 | pbar.update(len(contexts)) 373 | continue 374 | 375 | responses = self._generate_batch(contexts) 376 | if not self.escape_until: 377 | for i, r in enumerate(responses): 378 | for s in gen_args[0]['until']: 379 | r = r.split(s)[0] 380 | responses[i] = r 381 | 382 | # if self.rank == 0: 383 | # print(f"Context:\n{contexts[0]}\nResponse:\n{responses[0]}\n") 384 | 385 | res.extend(responses) 386 | pbar.update(len(contexts)) 387 | 388 | if self.save_dir is not None: 389 | # Incrementally append newly generated answers 390 | for i, r in enumerate(responses): 391 | with open(save_path, 'a', encoding='utf-8') as f: 392 | f.write(json.dumps(r, ensure_ascii=False) + '\n') 393 | 394 | end_time = time.time() 395 | print(f"Time taken: {end_time - start_time} seconds") 396 | print(f"Generated token num: {self.generated_token_num}") 397 | print(f"Generated token num per second: {self.generated_token_num / (end_time - start_time)}") 398 | 399 | return res 400 | 401 | def _forward_process(self, batch): 402 | b, l = batch.shape 403 | # sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1 404 | u0 = torch.rand(1, device=batch.device, dtype=torch.float32) 405 | indices = torch.arange(b, device=batch.device).float() 406 | t = (u0 + indices / b) % 1 407 | 408 | p_mask = (1 - self.sampling_eps) * t + self.sampling_eps 409 | 410 | p_mask = p_mask[:, None].repeat(1, l) 411 | 412 | mask_indices = torch.rand((b, l), device=batch.device) < p_mask 413 | # always unmask bos and eos 414 | mask_indices[:, 0] = False 415 | mask_indices[:, -1] = False 416 | 417 | noisy_batch = torch.where(mask_indices, self.tokenizer.mask_token_id, batch) 418 | return noisy_batch, p_mask 419 | 420 | @torch.no_grad() 421 | def get_logits(self, batch, prompt_index): 422 | ''' 423 | prompt_index : 1D bool tensor, length=batch.shape[1] 424 | ''' 425 | if self.classifier_free_guidance > 1.: 426 | assert len(prompt_index) == batch.shape[1] 427 | prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1) 428 | un_batch = batch.clone() 429 | un_batch[prompt_index] = self.tokenizer.mask_token_id 430 | batch = torch.cat([batch, un_batch]) 431 | 432 | input = batch 433 | 434 | with torch.amp.autocast('cuda', dtype=torch.bfloat16): 435 | logits = self.model(input).logits 436 | # since bos always unmask, the first logits will not be used 437 | logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1) 438 | 439 | if self.classifier_free_guidance > 1.: 440 | logits, un_logits = torch.chunk(logits, 2, dim=0) 441 | logits = un_logits + self.classifier_free_guidance * (logits - un_logits) 442 | return logits[:, :batch.shape[1]] 443 | 444 | @torch.no_grad() 445 | def _eval_target_nll_mc(self, prefix, target): 446 | if prefix is None: 447 | seq = target[None, :] 448 | else: 449 | seq = torch.concatenate([prefix, target])[None, :] 450 | seq = seq.repeat((self.batch_size, 1)).to(self.device) 451 | 452 | if self.log_type == 'ftb': 453 | prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) 454 | else: 455 | prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix) 456 | 457 | loss_acc = [] 458 | for _ in range(max(self.mc_num // self.batch_size, 1)): 459 | perturbed_seq = seq.clone() 460 | # 
eval_logger.info("before noising") 461 | perturbed_seq_, p_mask = self._forward_process(seq) 462 | # eval_logger.info("end noising") 463 | if self.log_type == 'ftb': 464 | perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):] 465 | elif self.log_type == 'btf': 466 | perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)] 467 | elif self.log_type == 'union': 468 | perturbed_seq = perturbed_seq_ 469 | else: 470 | raise NotImplementedError(self.log_type) 471 | 472 | mask_indices = perturbed_seq == self.tokenizer.mask_token_id 473 | logits = self.get_logits(perturbed_seq, prompt_index) 474 | loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices] 475 | loss = loss.sum() / self.batch_size 476 | loss_acc.append(loss.item()) 477 | 478 | return sum(loss_acc) / len(loss_acc) 479 | 480 | @torch.no_grad() 481 | def _eval_target_nll_ar(self, prefix, target): 482 | prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2 483 | assert self.log_type in ['ftb', 'btf'] 484 | assert self.nll_type in ['ar_ftb', 'ar_btf'] 485 | 486 | if self.log_type == 'ftb': 487 | prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1] 488 | else: 489 | prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1] 490 | 491 | if self.log_type == 'ftb': 492 | perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2 493 | else: 494 | perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1 495 | 496 | mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool) 497 | if self.nll_type == 'ar_ftb': 498 | mask_index = torch.triu(mask_index) 499 | else: 500 | mask_index = torch.tril(mask_index) 501 | perturbed_[mask_index] = self.tokenizer.mask_token_id 502 | if self.log_type == 'ftb': 503 | perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1) 504 | else: 505 | perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1) 506 | 507 | logits_ = [] 508 | num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1 509 | for i in range(num): 510 | end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq) 511 | perturbed_seq_ = perturbed_seq[i * self.batch_size: end] 512 | perturbed_seq_ = perturbed_seq_.to(self.device) 513 | if len(perturbed_seq_.shape) == 1: 514 | perturbed_seq_ = perturbed_seq_.unsqueeze(0) 515 | logits = self.get_logits(perturbed_seq_, prompt_index) 516 | logits_.append(logits.cpu()) 517 | logits = torch.cat(logits_, dim=0) 518 | 519 | temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool) 520 | if self.nll_type == 'ar_ftb': 521 | temp_index = torch.triu(temp_index, diagonal=1) 522 | else: 523 | temp_index = torch.tril(temp_index, diagonal=-1) 524 | mask_index[temp_index] = False 525 | if self.log_type == 'ftb': 526 | logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1) 527 | else: 528 | logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1) 529 | 530 | if self.log_type == 'ftb': 531 | loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item() 532 | else: 533 | loss = F.cross_entropy(logits[logits_index], prefix[0], 
reduction='sum').cpu().item() 534 | return loss 535 | 536 | def _encode_pair(self, context, continuation): 537 | if self.add_bos_token: 538 | context = self.tokenizer.bos_token + context 539 | 540 | n_spaces = len(context) - len(context.rstrip()) 541 | if n_spaces > 0: 542 | continuation = context[-n_spaces:] + continuation 543 | context = context[:-n_spaces] 544 | 545 | whole_enc = self.tokenizer.encode(context + continuation) + [self.tokenizer.eos_token_id] 546 | context_enc = self.tokenizer.encode(context) 547 | 548 | context_enc_len = len(context_enc) 549 | continuation_enc = whole_enc[context_enc_len:] 550 | 551 | # by default truncate on the left 552 | cutoff_length = max(len(whole_enc) - self.max_length, 0) 553 | if cutoff_length > 0: 554 | eval_logger.warning(f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side") 555 | context_remain = context_enc_len-cutoff_length 556 | if context_remain > 0: 557 | context_enc = context_enc[-context_remain:] 558 | else: 559 | eval_logger.warning(f"All context (prompt) is truncated.") 560 | context_enc = "" 561 | continuation_enc = whole_enc[-self.max_length:] 562 | return context_enc, continuation_enc 563 | 564 | def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: 565 | def _tokenize(e): 566 | prefix, target = self._encode_pair(e["prefix"], e["target"]) 567 | return { 568 | "prefix_text": e["prefix"], 569 | "target_text": e["target"], 570 | "prefix": prefix, 571 | "target": target, 572 | } 573 | 574 | ds = [] 575 | ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests] 576 | ds = Dataset.from_list(ds) 577 | print(ds[0]) 578 | ds = ds.map(_tokenize) 579 | ds = ds.with_format("torch") 580 | 581 | out = [] 582 | with torch.no_grad(): 583 | for elem in tqdm(ds, desc="Computing likelihood..."): 584 | prefix = elem["prefix"] 585 | target = elem["target"] 586 | # likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py 587 | if self.nll_type == 'mc': 588 | ll = -self._eval_target_nll_mc(prefix, target) 589 | if self.log_type == 'union': 590 | ll = ll / (len(target) + len(prefix)) 591 | elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf': 592 | ll = -self._eval_target_nll_ar(prefix, target) 593 | else: 594 | raise NotImplementedError(self.nll_type) 595 | 596 | # TODO: greedy decoding 597 | is_target_greedy_dec = False 598 | 599 | out.append((ll, 1.0 if is_target_greedy_dec else 0.0)) 600 | return out 601 | 602 | def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: 603 | raise NotImplementedError 604 | 605 | 606 | if __name__ == "__main__": 607 | cli_evaluate() -------------------------------------------------------------------------------- /dream/eval_gsm8k.sh: -------------------------------------------------------------------------------- 1 | # Set the environment variables first before running the command. 
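# Note on the shared settings used by the runs below: steps=$((length / block_length)) = 256 / 32 = 8,
# so the confidence_threshold ("parallel") runs pass diffusion_steps=8, while the baseline and
# prefix-cache runs pass diffusion_steps equal to the full generation length (256).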
2 | export HF_ALLOW_CODE_EVAL=1 3 | export HF_DATASETS_TRUST_REMOTE_CODE=true 4 | 5 | task=gsm8k 6 | length=256 7 | block_length=32 8 | num_fewshot=5 9 | steps=$((length / block_length)) 10 | model="Dream-org/Dream-v0-Base-7B" 11 | 12 | # baseline 13 | accelerate launch eval.py --model dream \ 14 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${length},add_bos_token=true,alg=entropy,show_speed=True \ 15 | --tasks ${task} \ 16 | --num_fewshot ${num_fewshot} \ 17 | --batch_size 1 18 | 19 | # prefix cache 20 | accelerate launch eval.py --model dream \ 21 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${length},add_bos_token=true,alg=entropy,use_cache=true,show_speed=True \ 22 | --tasks ${task} \ 23 | --num_fewshot ${num_fewshot} \ 24 | --batch_size 1 25 | 26 | # parallel 27 | accelerate launch eval.py --model dream \ 28 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,show_speed=True \ 29 | --tasks ${task} \ 30 | --num_fewshot ${num_fewshot} \ 31 | --batch_size 1 32 | 33 | # prefix cache+parallel 34 | accelerate launch eval.py --model dream \ 35 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,use_cache=true \ 36 | --tasks ${task} \ 37 | --num_fewshot ${num_fewshot} \ 38 | --batch_size 1 39 | 40 | # dual cache+parallel 41 | accelerate launch eval.py --model dream \ 42 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,use_cache=true,dual_cache=true \ 43 | --tasks ${task} \ 44 | --num_fewshot ${num_fewshot} \ 45 | --batch_size 1 -------------------------------------------------------------------------------- /dream/eval_humaneval.sh: -------------------------------------------------------------------------------- 1 | # Set the environment variables first before running the command.
2 | export HF_ALLOW_CODE_EVAL=1 3 | export HF_DATASETS_TRUST_REMOTE_CODE=true 4 | 5 | task=humaneval 6 | length=256 7 | block_length=32 8 | steps=$((length / block_length)) 9 | model="Dream-org/Dream-v0-Base-7B" 10 | 11 | # baseline 12 | accelerate launch eval.py --model dream \ 13 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${length},add_bos_token=true,alg=entropy,show_speed=True,escape_until=true \ 14 | --tasks ${task} \ 15 | --batch_size 1 \ 16 | --output_path evals_results/baseline/humaneval-ns0-${length} --log_samples \ 17 | --confirm_run_unsafe_code 18 | 19 | # prefix cache 20 | accelerate launch eval.py --model dream \ 21 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${length},add_bos_token=true,alg=entropy,use_cache=true,show_speed=True,escape_until=true \ 22 | --tasks ${task} \ 23 | --batch_size 1 \ 24 | --output_path evals_results/cache/humaneval-ns0-${length} --log_samples \ 25 | --confirm_run_unsafe_code 26 | 27 | # parallel 28 | accelerate launch eval.py --model dream \ 29 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,show_speed=True,escape_until=true \ 30 | --tasks ${task} \ 31 | --batch_size 1 \ 32 | --output_path evals_results/parallel/humaneval-ns0-${length} --log_samples \ 33 | --confirm_run_unsafe_code 34 | 35 | 36 | # prefix cache+parallel 37 | accelerate launch eval.py --model dream \ 38 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,use_cache=true,escape_until=true \ 39 | --tasks ${task} \ 40 | --batch_size 1 \ 41 | --output_path evals_results/cache_parallel/humaneval-ns0-${length} --log_samples \ 42 | --confirm_run_unsafe_code 43 | 44 | # dual cache+parallel 45 | accelerate launch eval.py --model dream \ 46 | --model_args pretrained=${model},max_new_tokens=${length},diffusion_steps=${steps},add_bos_token=true,alg=confidence_threshold,threshold=0.9,use_cache=true,dual_cache=true,escape_until=true \ 47 | --tasks ${task} \ 48 | --batch_size 1 \ 49 | --output_path evals_results/dual_cache_parallel/humaneval-ns0-${length} --log_samples \ 50 | --confirm_run_unsafe_code 51 | 52 | ## NOTICE: use postprocess for humaneval 53 | python postprocess_code.py {the samples_xxx.jsonl file under output_path} 54 | -------------------------------------------------------------------------------- /dream/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 17 | 18 | 19 | from .configuration_dream import DreamConfig 20 | from .modeling_dream import DreamModel 21 | 22 | __all__ = ["DreamConfig", "DreamModel"] -------------------------------------------------------------------------------- /dream/model/configuration_dream.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 17 | 18 | """Dream model configuration""" 19 | 20 | from transformers.configuration_utils import PretrainedConfig 21 | from transformers.modeling_rope_utils import rope_config_validation 22 | from transformers.utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | class DreamConfig(PretrainedConfig): 29 | model_type = "Dream" 30 | keys_to_ignore_at_inference = ["past_key_values"] 31 | 32 | def __init__( 33 | self, 34 | vocab_size=151936, 35 | hidden_size=4096, 36 | intermediate_size=22016, 37 | num_hidden_layers=32, 38 | num_attention_heads=32, 39 | num_key_value_heads=32, 40 | hidden_act="silu", 41 | max_position_embeddings=32768, 42 | initializer_range=0.02, 43 | rms_norm_eps=1e-6, 44 | use_cache=False, # cache not used in diffusion 45 | tie_word_embeddings=False, 46 | rope_theta=10000.0, 47 | rope_scaling=None, 48 | use_sliding_window=False, 49 | sliding_window=4096, 50 | max_window_layers=28, 51 | attention_dropout=0.0, 52 | mask_token_id=151666, 53 | pad_token_id=151643, 54 | **kwargs, 55 | ): 56 | self.vocab_size = vocab_size 57 | self.max_position_embeddings = max_position_embeddings 58 | self.hidden_size = hidden_size 59 | self.intermediate_size = intermediate_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.use_sliding_window = use_sliding_window 63 | self.sliding_window = sliding_window if use_sliding_window else None 64 | self.max_window_layers = max_window_layers 65 | 66 | # for backward compatibility 67 | if num_key_value_heads is None: 68 | num_key_value_heads = num_attention_heads 69 | 70 | self.num_key_value_heads = num_key_value_heads 71 | self.hidden_act = hidden_act 72 | self.initializer_range = initializer_range 73 | self.rms_norm_eps = rms_norm_eps 74 | self.use_cache = use_cache 75 | self.rope_theta = rope_theta 76 | self.rope_scaling = rope_scaling 77 | self.attention_dropout = attention_dropout 78 | # Validate the correctness of rotary position embeddings parameters 79 | # BC: if there is a 'type' field, move it to 'rope_type'. 
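        # Illustrative example (editor-added; values are not from the original config):
        # {"type": "linear", "factor": 2.0} becomes {"type": "linear", "rope_type": "linear", "factor": 2.0}.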
80 | if self.rope_scaling is not None and "type" in self.rope_scaling: 81 | self.rope_scaling["rope_type"] = self.rope_scaling["type"] 82 | rope_config_validation(self) 83 | 84 | super().__init__( 85 | tie_word_embeddings=tie_word_embeddings, 86 | **kwargs, 87 | ) 88 | self.mask_token_id = mask_token_id 89 | self.pad_token_id = pad_token_id 90 | -------------------------------------------------------------------------------- /dream/model/generation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 18 | 19 | import time 20 | import warnings 21 | import copy 22 | from dataclasses import dataclass 23 | from typing import Any, Dict, Optional, Tuple, Union 24 | 25 | import torch 26 | import torch.distributions as dists 27 | from torch.nn import functional as F 28 | from transformers import __version__ 29 | from transformers.generation.configuration_utils import ( 30 | GenerationConfig 31 | ) 32 | from transformers.utils import ( 33 | ModelOutput, 34 | is_torchdynamo_compiling, 35 | logging, 36 | ) 37 | 38 | logger = logging.get_logger(__name__) 39 | 40 | 41 | def top_p_logits(logits, top_p=None): 42 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 43 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 44 | sorted_indices_to_remove = cumulative_probs > top_p 45 | # Shift the indices to the right to keep the first token above the threshold 46 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 47 | sorted_indices_to_remove[..., 0] = 0 48 | 49 | mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device) 50 | mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove) 51 | logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min) 52 | return logits 53 | 54 | def top_k_logits(logits, top_k=None): 55 | top_k = min(top_k, logits.size(-1)) # Safety check 56 | # Remove all tokens with a probability less than the last token of the top-k 57 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 58 | logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min) 59 | return logits 60 | 61 | 62 | def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False): 63 | 64 | if temperature > 0: 65 | logits = logits / temperature 66 | if top_p is not None and top_p < 1: 67 | logits = top_p_logits(logits, top_p) 68 | if top_k is not None: 69 | logits = top_k_logits(logits, top_k) 70 | probs = torch.softmax(logits, dim=-1) 71 | 72 | if temperature > 0: 73 | try: 74 | x0 = dists.Categorical(probs=probs).sample() 75 | confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) 76 | except: 77 | confidence, x0 = probs.max(dim=-1) 78 | 
else: 79 | confidence, x0 = probs.max(dim=-1) 80 | 81 | if margin_confidence: 82 | sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) 83 | # Extract top1 and top2 probabilities 84 | top1_probs = sorted_probs[:, 0] 85 | top2_probs = sorted_probs[:, 1] 86 | # Calculate confidence as top1 - top2 87 | confidence = top1_probs - top2_probs 88 | 89 | if neg_entropy: 90 | epsilon = 1e-10 91 | log_probs = torch.log(probs + epsilon) 92 | confidence = torch.sum(probs * log_probs, dim=-1) 93 | 94 | return confidence, x0 95 | 96 | 97 | @dataclass 98 | class DreamModelOutput(ModelOutput): 99 | sequences: torch.LongTensor = None 100 | history: Optional[Tuple[torch.FloatTensor]] = None 101 | 102 | 103 | class DreamGenerationConfig(GenerationConfig): 104 | def __init__(self, **kwargs): 105 | self.temperature: float = kwargs.pop("temperature", 0.0) 106 | self.top_p: Optional[float] = kwargs.pop("top_p", None) 107 | self.top_k: Optional[int] = kwargs.pop("top_k", None) 108 | self.max_length = kwargs.pop("max_length", 20) 109 | self.max_new_tokens = kwargs.pop("max_new_tokens", None) 110 | # diffusion specific params 111 | self.eps: float = kwargs.pop("eps", 1e-3) 112 | self.steps: int = kwargs.pop("steps", 512) 113 | self.alg: str = kwargs.pop("alg", 'origin') 114 | self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None) 115 | 116 | # Parameters that define the output variables of `generate` 117 | self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1) 118 | self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False) 119 | self.output_history: bool = kwargs.pop("output_history", False) 120 | 121 | # Special tokens that can be used at generation time 122 | self.mask_token_id = kwargs.pop("mask_token_id", None) 123 | self.pad_token_id = kwargs.pop("pad_token_id", None) 124 | self.bos_token_id = kwargs.pop("bos_token_id", None) 125 | self.eos_token_id = kwargs.pop("eos_token_id", None) 126 | 127 | # Wild card 128 | self.generation_kwargs = kwargs.pop("generation_kwargs", {}) 129 | 130 | # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub 131 | # interface. 132 | self._from_model_config = kwargs.pop("_from_model_config", False) 133 | self._commit_hash = kwargs.pop("_commit_hash", None) 134 | self.transformers_version = kwargs.pop("transformers_version", __version__) 135 | 136 | # Additional attributes without default values 137 | if not self._from_model_config: 138 | # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a 139 | # model's default configuration file 140 | for key, value in kwargs.items(): 141 | try: 142 | setattr(self, key, value) 143 | except AttributeError as err: 144 | logger.error(f"Can't set {key} with value {value} for {self}") 145 | raise err 146 | 147 | # Validate the values of the attributes 148 | self.validate(is_init=True) 149 | 150 | def validate(self, is_init=False): 151 | pass 152 | 153 | class DreamGenerationMixin: 154 | @staticmethod 155 | def _expand_inputs_for_generation( 156 | expand_size: int = 1, 157 | input_ids: Optional[torch.LongTensor] = None, 158 | attention_mask: Optional[torch.LongTensor] = None 159 | ) -> Tuple[torch.LongTensor, Dict[str, Any]]: 160 | """Expands tensors from [batch_size, ...] 
to [batch_size * expand_size, ...]""" 161 | # Do not call torch.repeat_interleave if expand_size is 1 because it clones 162 | # the input tensor and thus requires more memory although no change is applied 163 | if expand_size == 1: 164 | return input_ids, attention_mask 165 | if input_ids is not None: 166 | input_ids = input_ids.repeat_interleave(expand_size, dim=0) 167 | if attention_mask is not None: 168 | attention_mask = attention_mask.repeat_interleave(expand_size, dim=0) 169 | return input_ids, attention_mask 170 | 171 | def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): 172 | """Performs validation related to the resulting generated length""" 173 | 174 | # Can't throw warnings/exceptions during compilation 175 | if is_torchdynamo_compiling(): 176 | return 177 | 178 | # 1. Max length warnings related to poor parameterization 179 | if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: 180 | # 20 is the default max_length of the generation config 181 | warnings.warn( 182 | f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " 183 | "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " 184 | "generation.", 185 | UserWarning, 186 | ) 187 | if input_ids_length >= generation_config.max_length: 188 | input_ids_string = "input_ids" 189 | raise ValueError( 190 | f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" 191 | f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" 192 | " increasing `max_length` or, better yet, setting `max_new_tokens`." 193 | ) 194 | 195 | def _prepare_generated_length( 196 | self, 197 | generation_config, 198 | has_default_max_length, 199 | input_ids_length, 200 | ): 201 | """Prepared max and min length in generation configs to avoid clashes between similar attributes""" 202 | 203 | if generation_config.max_new_tokens is not None: 204 | if not has_default_max_length and generation_config.max_length is not None: 205 | logger.warning( 206 | f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" 207 | f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " 208 | "Please refer to the documentation for more information. " 209 | "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" 210 | ) 211 | generation_config.max_length = generation_config.max_new_tokens + input_ids_length 212 | 213 | elif has_default_max_length: 214 | if generation_config.max_length == DreamGenerationConfig().max_length: 215 | generation_config.max_length = generation_config.max_length + input_ids_length 216 | max_position_embeddings = getattr(self.config, "max_position_embeddings", None) 217 | if max_position_embeddings is not None: 218 | generation_config.max_length = min(generation_config.max_length, max_position_embeddings) 219 | 220 | return generation_config 221 | 222 | def _prepare_generation_config( 223 | self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict 224 | ) -> DreamGenerationConfig: 225 | """ 226 | Prepares the base generation config, then applies any generation configuration options from kwargs. This 227 | function handles retrocompatibility with respect to configuration files. 
228 | """ 229 | # priority: `generation_config` argument > `model.generation_config` (the default generation config) 230 | using_model_generation_config = False 231 | if generation_config is None: 232 | generation_config = DreamGenerationConfig.from_model_config(self.config) 233 | using_model_generation_config = True 234 | 235 | # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` 236 | # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an 237 | # exception will be raised in `_validate_model_kwargs` 238 | if not is_torchdynamo_compiling(): 239 | generation_config = copy.deepcopy(generation_config) 240 | _kwargs = generation_config.update(**kwargs) 241 | # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model 242 | if not using_model_generation_config: 243 | if generation_config.bos_token_id is None: 244 | generation_config.bos_token_id = self.generation_config.bos_token_id 245 | if generation_config.eos_token_id is None: 246 | generation_config.eos_token_id = self.generation_config.eos_token_id 247 | if generation_config.pad_token_id is None: 248 | generation_config.pad_token_id = self.generation_config.pad_token_id 249 | if generation_config.mask_token_id is None: 250 | generation_config.mask_token_id = self.generation_config.mask_token_id 251 | 252 | return generation_config 253 | 254 | def _prepare_special_tokens( 255 | self, 256 | generation_config: DreamGenerationConfig, 257 | device: Optional[Union[torch.device, str]] = None, 258 | ): 259 | """ 260 | Prepares the special tokens for generation, overwriting the generation config with their processed versions 261 | converted to tensor. 262 | Note that `generation_config` is changed in place and stops being serializable after this method is called. 263 | That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the 264 | function). However, if called outside `generate`, consider creating a copy of `generation_config` first. 265 | """ 266 | 267 | # Convert special tokens to tensors 268 | def _tensor_or_none(token, device=None): 269 | if token is None: 270 | return token 271 | 272 | device = device if device is not None else self.device 273 | if isinstance(token, torch.Tensor): 274 | return token.to(device) 275 | return torch.tensor(token, device=device, dtype=torch.long) 276 | 277 | bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device) 278 | eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device) 279 | pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device) 280 | mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device) 281 | 282 | # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). 
283 | if eos_token_tensor is not None and eos_token_tensor.ndim == 0: 284 | eos_token_tensor = eos_token_tensor.unsqueeze(0) 285 | 286 | # Set pad token if unset (and there are conditions to do so) 287 | if pad_token_tensor is None and eos_token_tensor is not None: 288 | pad_token_tensor = eos_token_tensor[0] 289 | logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") 290 | 291 | # Update generation config with the updated special tokens tensors 292 | # NOTE: this must be written into a different attribute name than the one holding the original special tokens 293 | # (in their non-tensor form), in order to enable end-to-end compilation. See 294 | # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations 295 | generation_config._bos_token_tensor = bos_token_tensor 296 | generation_config._eos_token_tensor = eos_token_tensor 297 | generation_config._pad_token_tensor = pad_token_tensor 298 | generation_config._mask_token_tensor = mask_token_tensor 299 | 300 | @torch.no_grad() 301 | def diffusion_generate( 302 | self, 303 | inputs: Optional[torch.Tensor] = None, 304 | generation_config: Optional[DreamGenerationConfig] = None, 305 | **kwargs, 306 | ) -> Union[DreamModelOutput, torch.LongTensor]: 307 | # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call 308 | generation_config = self._prepare_generation_config(generation_config, **kwargs) 309 | generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x) 310 | generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits) 311 | 312 | # 2. Define model inputs 313 | assert inputs is not None 314 | input_ids = inputs 315 | device = input_ids.device 316 | attention_mask = kwargs.pop("attention_mask", None) 317 | self._prepare_special_tokens(generation_config, device=device) 318 | 319 | # 3. Prepare `max_length`. 320 | input_ids_length = input_ids.shape[-1] 321 | has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None 322 | generation_config = self._prepare_generated_length( 323 | generation_config=generation_config, 324 | has_default_max_length=has_default_max_length, 325 | input_ids_length=input_ids_length, 326 | ) 327 | 328 | self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) 329 | 330 | # 4. Check input_ids 331 | if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type: 332 | warnings.warn( 333 | "You are calling .generate() with the `input_ids` being on a device type different" 334 | f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" 335 | f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." 336 | " Please make sure that you have put `input_ids` to the" 337 | f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" 338 | " running `.generate()`.", 339 | UserWarning, 340 | ) 341 | if ( 342 | hasattr(generation_config, "pad_token_id") and 343 | torch.any(input_ids == generation_config.pad_token_id) and 344 | attention_mask is None 345 | ): 346 | warnings.warn( 347 | "Padding was detected but no attention mask is passed here. 
For correct " 348 | "generation results, please set `attention_mask` when batch-padding inputs.", 349 | UserWarning, 350 | ) 351 | 352 | input_ids, attention_mask = self._expand_inputs_for_generation( 353 | expand_size=generation_config.num_return_sequences, 354 | input_ids=input_ids, 355 | attention_mask=attention_mask 356 | ) 357 | threshold = kwargs.get("threshold", 0.9) 358 | 359 | result = self._sample( 360 | input_ids, 361 | attention_mask=attention_mask, 362 | generation_config=generation_config, 363 | generation_tokens_hook_func=generation_tokens_hook_func, 364 | generation_logits_hook_func=generation_logits_hook_func, 365 | threshold=threshold 366 | ) 367 | return result 368 | 369 | def _sample( 370 | self, 371 | input_ids: torch.LongTensor, 372 | attention_mask: Optional[torch.LongTensor], 373 | generation_config: DreamGenerationConfig, 374 | generation_tokens_hook_func, 375 | generation_logits_hook_func, 376 | threshold: Optional[float] = 0.9 377 | ) -> Union[DreamModelOutput, torch.LongTensor]: 378 | # init values 379 | output_history = generation_config.output_history 380 | return_dict_in_generate = generation_config.return_dict_in_generate 381 | max_length = generation_config.max_length 382 | mask_token_id = generation_config.mask_token_id 383 | steps = generation_config.steps 384 | eps = generation_config.eps 385 | alg = generation_config.alg 386 | alg_temp = generation_config.alg_temp 387 | temperature = generation_config.temperature 388 | top_p = generation_config.top_p 389 | top_k = generation_config.top_k 390 | 391 | histories = [] if (return_dict_in_generate and output_history) else None 392 | start_time = time.time() 393 | # pad input_ids to max_length 394 | x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) 395 | 396 | if attention_mask is not None and torch.any(attention_mask == 0.0): 397 | # we do not mask the [MASK] tokens so value = 1.0 398 | attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0) 399 | tok_idx = attention_mask.long().cumsum(-1) - 1 400 | tok_idx.masked_fill_(attention_mask == 0, 1) 401 | # attention_mask is of shape [B, N] 402 | # broadcast to [B, 1, N, N] 403 | attention_mask = torch.logical_and( 404 | attention_mask.unsqueeze(1).unsqueeze(-2), 405 | attention_mask.unsqueeze(1).unsqueeze(-1), 406 | ) 407 | else: 408 | tok_idx = None 409 | attention_mask = "full" 410 | 411 | timesteps = torch.linspace(1, eps, steps + 1, device=x.device) 412 | 413 | # this allows user-defined token control of the intermediate steps 414 | x = generation_tokens_hook_func(None, x, None) 415 | i = 0 416 | if alg == 'confidence_threshold': 417 | mask_index = (x == mask_token_id) 418 | assert mask_index.sum() % steps == 0, "mask_index.sum() must be divisible by steps" 419 | assert x.shape[0] == 1, "batch size must be 1" 420 | 421 | number_transfer_tokens = mask_index.sum().item() // steps 422 | left_tokens_last_step = 0 423 | while i < steps: 424 | mask_index = (x == mask_token_id) 425 | logits = self(x, attention_mask, tok_idx).logits 426 | logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1) 427 | 428 | # this allows user-defined logits control of the intermediate steps 429 | logits = generation_logits_hook_func(i, x, logits) 430 | 431 | mask_logits = logits[mask_index] 432 | if not alg == 'confidence_threshold': 433 | t = timesteps[i] 434 | s = timesteps[i + 1] 435 | 436 | if alg == 'origin': 437 | p_transfer = 1 - s / t if i < steps - 1 else 1 438 | x0 = torch.zeros_like(x[mask_index], 
device=self.device, dtype=torch.long) + mask_token_id 439 | transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer 440 | _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k) 441 | x[mask_index] = x0.clone() 442 | elif alg == 'confidence_threshold': 443 | confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k) 444 | x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id 445 | x_[mask_index] = x0.clone() 446 | full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype) 447 | full_confidence[mask_index] = confidence 448 | current_transfer_tokens = number_transfer_tokens + left_tokens_last_step 449 | left_tokens_last_step = 0 450 | selected_confidence, select_index = torch.topk(full_confidence, current_transfer_tokens) 451 | transfer_index = torch.zeros_like(x, device=x.device, dtype=torch.bool) 452 | select_index = select_index.to(x.device) 453 | transfer_index[0, select_index[0]] = True 454 | for k in range(1, current_transfer_tokens): 455 | if selected_confidence[0, k] < threshold: 456 | if i < steps - 1: 457 | left_tokens_last_step += 1 458 | transfer_index[0, select_index[0, k]] = False 459 | else: 460 | number_transfer_tokens = 0 461 | steps += 1 462 | left_tokens_last_step += 1 463 | transfer_index[0, select_index[0, k]] = False 464 | 465 | x[transfer_index] = x_[transfer_index].clone() 466 | 467 | else: 468 | if alg == 'maskgit_plus': 469 | confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k) 470 | elif alg == 'topk_margin': 471 | confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True) 472 | elif alg == 'entropy': 473 | confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True) 474 | else: 475 | raise RuntimeError(f"Unknown alg: {alg}") 476 | num_mask_token = mask_index.sum() / mask_index.shape[0] 477 | number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token) 478 | full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype) 479 | full_confidence[mask_index] = confidence 480 | if number_transfer_tokens > 0: 481 | if alg_temp is None or alg_temp == 0: 482 | _, transfer_index = torch.topk(full_confidence, number_transfer_tokens) 483 | else: 484 | full_confidence = full_confidence / alg_temp 485 | full_confidence = F.softmax(full_confidence, dim=-1) 486 | transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens) 487 | x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id 488 | x_[mask_index] = x0.clone() 489 | row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index) 490 | x[row_indices,transfer_index] = x_[row_indices,transfer_index] 491 | 492 | # this allows user-defined token control of the intermediate steps 493 | x = generation_tokens_hook_func(i, x, logits) 494 | 495 | if histories is not None: 496 | histories.append(x.clone()) 497 | i += 1 498 | 499 | print(f'used steps: {steps}') 500 | end_time = time.time() 501 | print(f'used time: {end_time - start_time}') 502 | if return_dict_in_generate: 503 | return DreamModelOutput( 504 | sequences=x, 505 | history=histories, 506 | ) 507 | else: 508 | return x -------------------------------------------------------------------------------- 
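A minimal usage sketch of the confidence-threshold ("parallel") decoder defined above. Everything in it is an assumption for illustration: it presumes the snippet is run from the `dream/` directory so the local, Fast-DLLM-modified model code is imported, and the prompt plus the 256-token / 8-step split simply mirror the eval scripts.

```python
# Sketch only (assumptions noted above), not part of the original repository.
import torch
from model import DreamModel          # local package: dream/model/__init__.py
from transformers import AutoTokenizer

model_name = "Dream-org/Dream-v0-Base-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = DreamModel.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda").eval()

prompt = "Question: What is 15 + 27?\nAnswer:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

out = model.diffusion_generate(
    input_ids,
    max_new_tokens=256,          # number of [MASK] slots appended after the prompt
    steps=8,                     # length / block_length, as in the eval scripts
    temperature=0.0,             # greedy choice inside sample_tokens
    alg="confidence_threshold",  # keep the top token plus any token whose confidence > threshold
    threshold=0.9,
)
print(tokenizer.decode(out[0, input_ids.shape[1]:].tolist(), skip_special_tokens=True))
```

Passing `use_cache=true` or `dual_cache=true` in the eval scripts presumably routes generation through the block-wise variant in `generation_utils_block.py`, shown next.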
/dream/model/generation_utils_block.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 17 | 18 | import warnings 19 | import copy 20 | from dataclasses import dataclass 21 | from typing import Any, Dict, Optional, Tuple, Union 22 | 23 | import torch 24 | import torch.distributions as dists 25 | from torch.nn import functional as F 26 | from transformers import __version__ 27 | from transformers.generation.configuration_utils import ( 28 | GenerationConfig 29 | ) 30 | from transformers.utils import ( 31 | ModelOutput, 32 | is_torchdynamo_compiling, 33 | logging, 34 | ) 35 | 36 | logger = logging.get_logger(__name__) 37 | 38 | def get_num_transfer_tokens(mask_index, steps): 39 | ''' 40 | In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. 41 | Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), 42 | the expected number of tokens transitioned at each step should be consistent. 43 | 44 | This function is designed to precompute the number of tokens that need to be transitioned at each step. 
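    A small worked example (added for illustration): with 10 masked tokens and steps = 4,
    base = 10 // 4 = 2 and remainder = 10 % 4 = 2, so the per-step schedule is [3, 3, 2, 2],
    which sums back to the 10 masked tokens.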
45 | ''' 46 | mask_num = mask_index.sum(dim=1, keepdim=True) 47 | 48 | base = mask_num // steps 49 | remainder = mask_num % steps 50 | 51 | num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base 52 | 53 | for i in range(mask_num.size(0)): 54 | num_transfer_tokens[i, :remainder[i]] += 1 55 | 56 | return num_transfer_tokens 57 | 58 | 59 | def top_p_logits(logits, top_p=None): 60 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 61 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 62 | sorted_indices_to_remove = cumulative_probs > top_p 63 | # Shift the indices to the right to keep the first token above the threshold 64 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 65 | sorted_indices_to_remove[..., 0] = 0 66 | 67 | mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device) 68 | mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove) 69 | logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min) 70 | return logits 71 | 72 | def top_k_logits(logits, top_k=None): 73 | top_k = min(top_k, logits.size(-1)) # Safety check 74 | # Remove all tokens with a probability less than the last token of the top-k 75 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 76 | logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min) 77 | return logits 78 | 79 | 80 | def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False): 81 | 82 | if temperature > 0: 83 | logits = logits / temperature 84 | if top_p is not None and top_p < 1: 85 | logits = top_p_logits(logits, top_p) 86 | if top_k is not None: 87 | logits = top_k_logits(logits, top_k) 88 | probs = torch.softmax(logits, dim=-1) 89 | 90 | if temperature > 0: 91 | try: 92 | x0 = dists.Categorical(probs=probs).sample() 93 | confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) 94 | except: 95 | confidence, x0 = probs.max(dim=-1) 96 | else: 97 | confidence, x0 = probs.max(dim=-1) 98 | 99 | if margin_confidence: 100 | sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) 101 | # Extract top1 and top2 probabilities 102 | top1_probs = sorted_probs[:, 0] 103 | top2_probs = sorted_probs[:, 1] 104 | # Calculate confidence as top1 - top2 105 | confidence = top1_probs - top2_probs 106 | 107 | if neg_entropy: 108 | epsilon = 1e-10 109 | log_probs = torch.log(probs + epsilon) 110 | confidence = torch.sum(probs * log_probs, dim=-1) 111 | 112 | return confidence, x0 113 | 114 | 115 | @dataclass 116 | class DreamModelOutput(ModelOutput): 117 | sequences: torch.LongTensor = None 118 | history: Optional[Tuple[torch.FloatTensor]] = None 119 | 120 | 121 | class DreamGenerationConfig(GenerationConfig): 122 | def __init__(self, **kwargs): 123 | self.temperature: float = kwargs.pop("temperature", 0.0) 124 | self.top_p: Optional[float] = kwargs.pop("top_p", None) 125 | self.top_k: Optional[int] = kwargs.pop("top_k", None) 126 | self.max_length = kwargs.pop("max_length", 20) 127 | self.max_new_tokens = kwargs.pop("max_new_tokens", None) 128 | # diffusion specific params 129 | self.eps: float = kwargs.pop("eps", 1e-3) 130 | self.steps: int = kwargs.pop("steps", 512) 131 | self.alg: str = kwargs.pop("alg", 'origin') 132 | self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None) 133 | 134 | # Parameters that define the output variables of `generate` 135 | self.num_return_sequences: int = 
kwargs.pop("num_return_sequences", 1) 136 | self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False) 137 | self.output_history: bool = kwargs.pop("output_history", False) 138 | 139 | # Special tokens that can be used at generation time 140 | self.mask_token_id = kwargs.pop("mask_token_id", None) 141 | self.pad_token_id = kwargs.pop("pad_token_id", None) 142 | self.bos_token_id = kwargs.pop("bos_token_id", None) 143 | self.eos_token_id = kwargs.pop("eos_token_id", None) 144 | 145 | # Wild card 146 | self.generation_kwargs = kwargs.pop("generation_kwargs", {}) 147 | 148 | # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub 149 | # interface. 150 | self._from_model_config = kwargs.pop("_from_model_config", False) 151 | self._commit_hash = kwargs.pop("_commit_hash", None) 152 | self.transformers_version = kwargs.pop("transformers_version", __version__) 153 | 154 | # Additional attributes without default values 155 | if not self._from_model_config: 156 | # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a 157 | # model's default configuration file 158 | for key, value in kwargs.items(): 159 | try: 160 | setattr(self, key, value) 161 | except AttributeError as err: 162 | logger.error(f"Can't set {key} with value {value} for {self}") 163 | raise err 164 | 165 | # Validate the values of the attributes 166 | self.validate(is_init=True) 167 | 168 | def validate(self, is_init=False): 169 | pass 170 | 171 | class DreamGenerationMixin: 172 | @staticmethod 173 | def _expand_inputs_for_generation( 174 | expand_size: int = 1, 175 | input_ids: Optional[torch.LongTensor] = None, 176 | attention_mask: Optional[torch.LongTensor] = None 177 | ) -> Tuple[torch.LongTensor, Dict[str, Any]]: 178 | """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" 179 | # Do not call torch.repeat_interleave if expand_size is 1 because it clones 180 | # the input tensor and thus requires more memory although no change is applied 181 | if expand_size == 1: 182 | return input_ids, attention_mask 183 | if input_ids is not None: 184 | input_ids = input_ids.repeat_interleave(expand_size, dim=0) 185 | if attention_mask is not None: 186 | attention_mask = attention_mask.repeat_interleave(expand_size, dim=0) 187 | return input_ids, attention_mask 188 | 189 | def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): 190 | """Performs validation related to the resulting generated length""" 191 | 192 | # Can't throw warnings/exceptions during compilation 193 | if is_torchdynamo_compiling(): 194 | return 195 | 196 | # 1. Max length warnings related to poor parameterization 197 | if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: 198 | # 20 is the default max_length of the generation config 199 | warnings.warn( 200 | f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " 201 | "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " 202 | "generation.", 203 | UserWarning, 204 | ) 205 | if input_ids_length >= generation_config.max_length: 206 | input_ids_string = "input_ids" 207 | raise ValueError( 208 | f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" 209 | f" {generation_config.max_length}. This can lead to unexpected behavior. 
You should consider" 210 | " increasing `max_length` or, better yet, setting `max_new_tokens`." 211 | ) 212 | 213 | def _prepare_generated_length( 214 | self, 215 | generation_config, 216 | has_default_max_length, 217 | input_ids_length, 218 | ): 219 | """Prepared max and min length in generation configs to avoid clashes between similar attributes""" 220 | 221 | if generation_config.max_new_tokens is not None: 222 | if not has_default_max_length and generation_config.max_length is not None: 223 | logger.warning( 224 | f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" 225 | f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " 226 | "Please refer to the documentation for more information. " 227 | "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" 228 | ) 229 | generation_config.max_length = generation_config.max_new_tokens + input_ids_length 230 | 231 | elif has_default_max_length: 232 | if generation_config.max_length == DreamGenerationConfig().max_length: 233 | generation_config.max_length = generation_config.max_length + input_ids_length 234 | max_position_embeddings = getattr(self.config, "max_position_embeddings", None) 235 | if max_position_embeddings is not None: 236 | generation_config.max_length = min(generation_config.max_length, max_position_embeddings) 237 | 238 | return generation_config 239 | 240 | def _prepare_generation_config( 241 | self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict 242 | ) -> DreamGenerationConfig: 243 | """ 244 | Prepares the base generation config, then applies any generation configuration options from kwargs. This 245 | function handles retrocompatibility with respect to configuration files. 246 | """ 247 | # priority: `generation_config` argument > `model.generation_config` (the default generation config) 248 | using_model_generation_config = False 249 | if generation_config is None: 250 | generation_config = DreamGenerationConfig.from_model_config(self.config) 251 | using_model_generation_config = True 252 | 253 | # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` 254 | # will mutate the object with `.update`. 
As such, passing these arguments through `kwargs` is disabled -- an 255 | # exception will be raised in `_validate_model_kwargs` 256 | if not is_torchdynamo_compiling(): 257 | generation_config = copy.deepcopy(generation_config) 258 | _kwargs = generation_config.update(**kwargs) 259 | # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model 260 | if not using_model_generation_config: 261 | if generation_config.bos_token_id is None: 262 | generation_config.bos_token_id = self.generation_config.bos_token_id 263 | if generation_config.eos_token_id is None: 264 | generation_config.eos_token_id = self.generation_config.eos_token_id 265 | if generation_config.pad_token_id is None: 266 | generation_config.pad_token_id = self.generation_config.pad_token_id 267 | if generation_config.mask_token_id is None: 268 | generation_config.mask_token_id = self.generation_config.mask_token_id 269 | 270 | return generation_config 271 | 272 | def _prepare_special_tokens( 273 | self, 274 | generation_config: DreamGenerationConfig, 275 | device: Optional[Union[torch.device, str]] = None, 276 | ): 277 | """ 278 | Prepares the special tokens for generation, overwriting the generation config with their processed versions 279 | converted to tensor. 280 | Note that `generation_config` is changed in place and stops being serializable after this method is called. 281 | That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the 282 | function). However, if called outside `generate`, consider creating a copy of `generation_config` first. 283 | """ 284 | 285 | # Convert special tokens to tensors 286 | def _tensor_or_none(token, device=None): 287 | if token is None: 288 | return token 289 | 290 | device = device if device is not None else self.device 291 | if isinstance(token, torch.Tensor): 292 | return token.to(device) 293 | return torch.tensor(token, device=device, dtype=torch.long) 294 | 295 | bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device) 296 | eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device) 297 | pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device) 298 | mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device) 299 | 300 | # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). 301 | if eos_token_tensor is not None and eos_token_tensor.ndim == 0: 302 | eos_token_tensor = eos_token_tensor.unsqueeze(0) 303 | 304 | # Set pad token if unset (and there are conditions to do so) 305 | if pad_token_tensor is None and eos_token_tensor is not None: 306 | pad_token_tensor = eos_token_tensor[0] 307 | logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") 308 | 309 | # Update generation config with the updated special tokens tensors 310 | # NOTE: this must be written into a different attribute name than the one holding the original special tokens 311 | # (in their non-tensor form), in order to enable end-to-end compilation. 
See 312 | # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations 313 | generation_config._bos_token_tensor = bos_token_tensor 314 | generation_config._eos_token_tensor = eos_token_tensor 315 | generation_config._pad_token_tensor = pad_token_tensor 316 | generation_config._mask_token_tensor = mask_token_tensor 317 | 318 | @torch.no_grad() 319 | def diffusion_generate( 320 | self, 321 | inputs: Optional[torch.Tensor] = None, 322 | generation_config: Optional[DreamGenerationConfig] = None, 323 | **kwargs, 324 | ) -> Union[DreamModelOutput, torch.LongTensor]: 325 | # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call 326 | generation_config = self._prepare_generation_config(generation_config, **kwargs) 327 | 328 | # 2. Define model inputs 329 | assert inputs is not None 330 | input_ids = inputs 331 | device = input_ids.device 332 | attention_mask = kwargs.pop("attention_mask", None) 333 | self._prepare_special_tokens(generation_config, device=device) 334 | 335 | # 3. Prepare `max_length`. 336 | input_ids_length = input_ids.shape[-1] 337 | has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None 338 | generation_config = self._prepare_generated_length( 339 | generation_config=generation_config, 340 | has_default_max_length=has_default_max_length, 341 | input_ids_length=input_ids_length, 342 | ) 343 | 344 | self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) 345 | 346 | # 4. Check input_ids 347 | if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type: 348 | warnings.warn( 349 | "You are calling .generate() with the `input_ids` being on a device type different" 350 | f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" 351 | f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." 352 | " Please make sure that you have put `input_ids` to the" 353 | f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" 354 | " running `.generate()`.", 355 | UserWarning, 356 | ) 357 | if ( 358 | hasattr(generation_config, "pad_token_id") and 359 | torch.any(input_ids == generation_config.pad_token_id) and 360 | attention_mask is None 361 | ): 362 | warnings.warn( 363 | "Padding was detected but no attention mask is passed here. 
For correct " 364 | "generation results, please set `attention_mask` when batch-padding inputs.", 365 | UserWarning, 366 | ) 367 | 368 | input_ids, attention_mask = self._expand_inputs_for_generation( 369 | expand_size=generation_config.num_return_sequences, 370 | input_ids=input_ids, 371 | attention_mask=attention_mask 372 | ) 373 | threshold = kwargs.get("threshold", 0.9) 374 | block_length = kwargs.get("block_length", 32) 375 | dual_cache = kwargs.get("dual_cache", False) 376 | 377 | result = self._sample( 378 | input_ids, 379 | attention_mask=attention_mask, 380 | generation_config=generation_config, 381 | threshold=threshold, 382 | block_length=block_length, 383 | dual_cache=dual_cache 384 | ) 385 | return result 386 | 387 | def _sample( 388 | self, 389 | input_ids: torch.LongTensor, 390 | attention_mask: Optional[torch.LongTensor], 391 | generation_config: DreamGenerationConfig, 392 | threshold: Optional[float] = 0.9, 393 | block_length: Optional[int] = 32, 394 | dual_cache: bool = False, 395 | ) -> Union[DreamModelOutput, torch.LongTensor]: 396 | # init values 397 | 398 | output_history = generation_config.output_history 399 | return_dict_in_generate = generation_config.return_dict_in_generate 400 | max_length = generation_config.max_length 401 | mask_token_id = generation_config.mask_token_id 402 | steps = generation_config.steps 403 | temperature = generation_config.temperature 404 | top_p = generation_config.top_p 405 | top_k = generation_config.top_k 406 | alg = generation_config.alg 407 | alg_temp = generation_config.alg_temp 408 | 409 | histories = [] if (return_dict_in_generate and output_history) else None 410 | 411 | # pad input_ids to max_length 412 | x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id) 413 | gen_length = max_length - input_ids.shape[1] 414 | 415 | # Handle block configuration 416 | if block_length is None: 417 | block_length = gen_length # Default: single block (original behavior) 418 | 419 | assert gen_length % block_length == 0, f"gen_length ({gen_length}) must be divisible by block_length ({block_length})" 420 | num_blocks = gen_length // block_length 421 | 422 | assert steps % num_blocks == 0, f"steps ({steps}) must be divisible by num_blocks ({num_blocks})" 423 | steps_per_block = steps // num_blocks 424 | timesteps = torch.linspace(1, generation_config.eps, steps_per_block + 1, device=x.device) 425 | 426 | if attention_mask is not None and torch.any(attention_mask == 0.0): 427 | # we do not mask the [MASK] tokens so value = 1.0 428 | attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0) 429 | tok_idx = attention_mask.long().cumsum(-1) - 1 430 | tok_idx.masked_fill_(attention_mask == 0, 1) 431 | # attention_mask is of shape [B, N] 432 | # broadcast to [B, 1, N, N] 433 | attention_mask = torch.logical_and( 434 | attention_mask.unsqueeze(1).unsqueeze(-2), 435 | attention_mask.unsqueeze(1).unsqueeze(-1), 436 | ) 437 | else: 438 | tok_idx = None 439 | attention_mask = "full" 440 | 441 | # Initialize cache for the prompt 442 | past_key_values = None 443 | 444 | # Process each block 445 | for num_block in range(num_blocks): 446 | 447 | current_block_start = input_ids.shape[1] + num_block * block_length 448 | current_block_end = current_block_start + block_length 449 | 450 | # update cache 451 | model_output = self(x, attention_mask, tok_idx, use_cache=True) 452 | past_key_values = model_output.past_key_values 453 | logits = model_output.logits 454 | logits = torch.cat([logits[:,:1], 
logits[:, :-1]], dim=1) 455 | confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k) 456 | x[:, current_block_start] = x0[:, current_block_start] 457 | 458 | # Extract only previous block cache 459 | if not dual_cache: 460 | new_past_key_values = [] 461 | for i in range(len(past_key_values)): 462 | new_past_key_values.append(()) 463 | for j in range(len(past_key_values[i])): 464 | new_past_key_values[i] += (past_key_values[i][j][:, :current_block_start, :],) 465 | past_key_values = new_past_key_values 466 | else: 467 | replace_position = torch.zeros_like(x, dtype=torch.bool) 468 | replace_position[:, current_block_start:current_block_end] = 1 469 | 470 | i = 1 471 | while True: 472 | # Use cache for generation 473 | if dual_cache: 474 | mask_index = (x[:, current_block_start:current_block_end] == mask_token_id) 475 | else: 476 | mask_index = (x[:, current_block_start:] == mask_token_id) 477 | 478 | # Prepare attention mask for cached generation 479 | if attention_mask != "full": 480 | # Adjust attention mask for current position 481 | current_attention_mask = attention_mask[:, :, :, current_block_start:] 482 | else: 483 | current_attention_mask = attention_mask 484 | 485 | if dual_cache: 486 | model_output = self(x[:, current_block_start:current_block_end], current_attention_mask, 487 | tok_idx[:, current_block_start:current_block_end] if tok_idx is not None else None, 488 | past_key_values=past_key_values, use_cache=True, dual_cache=dual_cache, replace_position=replace_position) 489 | else: 490 | model_output = self(x[:, current_block_start:], current_attention_mask, 491 | tok_idx[:, current_block_start:] if tok_idx is not None else None, 492 | past_key_values=past_key_values, use_cache=True) 493 | logits = model_output.logits 494 | logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1) 495 | if alg == 'confidence_threshold': 496 | mask_logits = logits[mask_index] 497 | 498 | confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k) 499 | 500 | if dual_cache: 501 | x_ = torch.zeros_like(x[:, current_block_start:current_block_end], device=self.device, dtype=torch.long) + mask_token_id 502 | full_confidence = torch.full_like(x[:, current_block_start:current_block_end], -torch.inf, device=self.device, dtype=logits.dtype) 503 | else: 504 | x_ = torch.zeros_like(x[:, current_block_start:], device=self.device, dtype=torch.long) + mask_token_id 505 | full_confidence = torch.full_like(x[:, current_block_start:], -torch.inf, device=self.device, dtype=logits.dtype) 506 | 507 | x_[mask_index] = x0.clone() 508 | full_confidence[mask_index] = confidence 509 | full_confidence[:, block_length:] = -torch.inf 510 | 511 | current_transfer_tokens = (x[:, current_block_start:current_block_end] == mask_token_id).sum() 512 | 513 | selected_confidence, select_index = torch.topk(full_confidence, current_transfer_tokens) 514 | transfer_index = torch.zeros_like(x_, device=x.device, dtype=torch.bool) 515 | 516 | select_index = select_index.to(x.device) 517 | transfer_index[0, select_index[0]] = True 518 | for k in range(1, current_transfer_tokens): 519 | if selected_confidence[0, k] < threshold: 520 | transfer_index[0, select_index[0, k]] = False 521 | if dual_cache: 522 | x[:, current_block_start:current_block_end][transfer_index] = x_[transfer_index] 523 | else: 524 | x[:, current_block_start:][transfer_index] = x_[transfer_index] 525 | else: 526 | if i == steps_per_block: 527 | break 528 | t = timesteps[i] 529 | s = timesteps[i + 1] 530 | 
mask_index[:, block_length:] = False 531 | mask_logits = logits[mask_index] 532 | confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True) 533 | num_mask_token = mask_index.sum() / mask_index.shape[0] 534 | number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps_per_block - 1 else int(num_mask_token) 535 | if dual_cache: 536 | full_confidence = torch.full_like(x[:, current_block_start:current_block_end], -torch.inf, device=self.device, dtype=logits.dtype) 537 | else: 538 | full_confidence = torch.full_like(x[:, current_block_start:], -torch.inf, device=self.device, dtype=logits.dtype) 539 | full_confidence[mask_index] = confidence 540 | full_confidence[:, block_length:] = -torch.inf 541 | 542 | if number_transfer_tokens > 0: 543 | if alg_temp is None or alg_temp == 0: 544 | _, transfer_index = torch.topk(full_confidence, number_transfer_tokens) 545 | else: 546 | full_confidence = full_confidence / alg_temp 547 | full_confidence = F.softmax(full_confidence, dim=-1) 548 | transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens) 549 | if dual_cache: 550 | x_ = torch.zeros_like(x[:, current_block_start:current_block_end], device=self.device, dtype=torch.long) + mask_token_id 551 | else: 552 | x_ = torch.zeros_like(x[:, current_block_start:], device=self.device, dtype=torch.long) + mask_token_id 553 | x_[mask_index] = x0.clone() 554 | row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index) 555 | if dual_cache: 556 | x[:, current_block_start:current_block_end][row_indices,transfer_index] = x_[row_indices,transfer_index] 557 | else: 558 | x[:, current_block_start:][row_indices,transfer_index] = x_[row_indices,transfer_index] 559 | i += 1 560 | 561 | if (x[:, current_block_start:current_block_end] == mask_token_id).sum() == 0: 562 | break 563 | 564 | 565 | if return_dict_in_generate: 566 | return DreamModelOutput( 567 | sequences=x, 568 | history=histories, 569 | ) 570 | else: 571 | return x -------------------------------------------------------------------------------- /dream/model/tokenization_dream.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 18 | 19 | """Tokenization classes for Dream.""" 20 | 21 | import json 22 | import os 23 | import unicodedata 24 | from functools import lru_cache 25 | from typing import Optional, Tuple 26 | 27 | import regex as re 28 | 29 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer 30 | from transformers.utils import logging 31 | 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | VOCAB_FILES_NAMES = { 36 | "vocab_file": "vocab.json", 37 | "merges_file": "merges.txt", 38 | } 39 | 40 | 41 | MAX_MODEL_INPUT_SIZES = {"dream/dream-tokenizer": 32768} 42 | 43 | PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 44 | 45 | 46 | @lru_cache() 47 | # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode 48 | def bytes_to_unicode(): 49 | """ 50 | Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control 51 | characters the bpe code barfs on. 52 | 53 | The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab 54 | if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for 55 | decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup 56 | tables between utf-8 bytes and unicode strings. 57 | """ 58 | bs = ( 59 | list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 60 | ) 61 | cs = bs[:] 62 | n = 0 63 | for b in range(2**8): 64 | if b not in bs: 65 | bs.append(b) 66 | cs.append(2**8 + n) 67 | n += 1 68 | cs = [chr(n) for n in cs] 69 | return dict(zip(bs, cs)) 70 | 71 | 72 | # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs 73 | def get_pairs(word): 74 | """ 75 | Return set of symbol pairs in a word. 76 | 77 | Word is represented as tuple of symbols (symbols being variable-length strings). 78 | """ 79 | pairs = set() 80 | prev_char = word[0] 81 | for char in word[1:]: 82 | pairs.add((prev_char, char)) 83 | prev_char = char 84 | return pairs 85 | 86 | 87 | class DreamTokenizer(PreTrainedTokenizer): 88 | """ 89 | Construct a Dream tokenizer. Based on byte-level Byte-Pair-Encoding. 90 | 91 | Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will 92 | be encoded differently whether it is at the beginning of the sentence (without space) or not: 93 | 94 | ```python 95 | >>> from transformers import AutoTokenizer 96 | 97 | >>> tokenizer = AutoTokenizer.from_pretrained("Dream-org/Dream-v0-Base-7B", trust_remote_code=True) 98 | >>> tokenizer("Hello world")["input_ids"] 99 | [9707, 1879] 100 | 101 | >>> tokenizer(" Hello world")["input_ids"] 102 | [21927, 1879] 103 | ``` 104 | This is expected. 105 | 106 | You should not use GPT2Tokenizer instead, because of the different pretokenization rules. 107 | 108 | This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to 109 | this superclass for more information regarding those methods. 110 | 111 | Args: 112 | vocab_file (`str`): 113 | Path to the vocabulary file. 114 | merges_file (`str`): 115 | Path to the merges file. 
116 | errors (`str`, *optional*, defaults to `"replace"`): 117 | Paradigm to follow when decoding bytes to UTF-8. See 118 | [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. 119 | unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 120 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 121 | token instead. 122 | bos_token (`str`, *optional*): 123 | The beginning of sequence token. Not applicable for this tokenizer. 124 | eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 125 | The end of sequence token. 126 | pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): 127 | The token used for padding, for example when batching sequences of different lengths. 128 | clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): 129 | Whether or not the model should cleanup the spaces that were added when splitting the input text during the 130 | tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces. 131 | split_special_tokens (`bool`, *optional*, defaults to `False`): 132 | Whether or not the special tokens should be split during the tokenization process. The default behavior is 133 | to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") = 134 | ['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<', 135 | '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment. 136 | """ 137 | 138 | vocab_files_names = VOCAB_FILES_NAMES 139 | model_input_names = ["input_ids", "attention_mask"] 140 | 141 | def __init__( 142 | self, 143 | vocab_file, 144 | merges_file, 145 | errors="replace", 146 | unk_token="<|endoftext|>", 147 | bos_token=None, 148 | eos_token="<|endoftext|>", 149 | pad_token="<|endoftext|>", 150 | clean_up_tokenization_spaces=False, 151 | split_special_tokens=False, 152 | **kwargs, 153 | ): 154 | # Dream vocab does not contain control tokens; added tokens need to be special 155 | bos_token = ( 156 | AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) 157 | if isinstance(bos_token, str) 158 | else bos_token 159 | ) 160 | eos_token = ( 161 | AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) 162 | if isinstance(eos_token, str) 163 | else eos_token 164 | ) 165 | unk_token = ( 166 | AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False) 167 | if isinstance(unk_token, str) 168 | else unk_token 169 | ) 170 | pad_token = ( 171 | AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) 172 | if isinstance(pad_token, str) 173 | else pad_token 174 | ) 175 | 176 | with open(vocab_file, encoding="utf-8") as vocab_handle: 177 | self.encoder = json.load(vocab_handle) 178 | self.decoder = {v: k for k, v in self.encoder.items()} 179 | self.errors = errors # how to handle errors in decoding 180 | self.byte_encoder = bytes_to_unicode() 181 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 182 | bpe_merges = [] 183 | with open(merges_file, encoding="utf-8") as merges_handle: 184 | for i, line in enumerate(merges_handle): 185 | line = line.strip() 186 | if (i == 0 and line.startswith("#version:")) or not line: 187 | continue 188 | bpe_merges.append(tuple(line.split())) 189 | self.bpe_ranks = dict(zip(bpe_merges, 
range(len(bpe_merges)))) 190 | # NOTE: the cache can grow without bound and will get really large for long running processes 191 | # (esp. for texts of language that do not use space between word, e.g. Chinese); technically 192 | # not a memory leak but appears as one. 193 | # GPT2Tokenizer has the same problem, so let's be consistent. 194 | self.cache = {} 195 | 196 | self.pat = re.compile(PRETOKENIZE_REGEX) 197 | 198 | if kwargs.get("add_prefix_space", False): 199 | logger.warning_once( 200 | f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect." 201 | ) 202 | 203 | super().__init__( 204 | errors=errors, 205 | bos_token=bos_token, 206 | eos_token=eos_token, 207 | pad_token=pad_token, 208 | unk_token=unk_token, 209 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 210 | split_special_tokens=split_special_tokens, 211 | **kwargs, 212 | ) 213 | 214 | @property 215 | def vocab_size(self) -> int: 216 | return len(self.encoder) 217 | 218 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab 219 | def get_vocab(self): 220 | return dict(self.encoder, **self.added_tokens_encoder) 221 | 222 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe 223 | def bpe(self, token): 224 | if token in self.cache: 225 | return self.cache[token] 226 | word = tuple(token) 227 | pairs = get_pairs(word) 228 | 229 | if not pairs: 230 | return token 231 | 232 | while True: 233 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 234 | if bigram not in self.bpe_ranks: 235 | break 236 | first, second = bigram 237 | new_word = [] 238 | i = 0 239 | while i < len(word): 240 | try: 241 | j = word.index(first, i) 242 | except ValueError: 243 | new_word.extend(word[i:]) 244 | break 245 | else: 246 | new_word.extend(word[i:j]) 247 | i = j 248 | 249 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 250 | new_word.append(first + second) 251 | i += 2 252 | else: 253 | new_word.append(word[i]) 254 | i += 1 255 | new_word = tuple(new_word) 256 | word = new_word 257 | if len(word) == 1: 258 | break 259 | else: 260 | pairs = get_pairs(word) 261 | word = " ".join(word) 262 | self.cache[token] = word 263 | return word 264 | 265 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize 266 | def _tokenize(self, text): 267 | """Tokenize a string.""" 268 | bpe_tokens = [] 269 | for token in re.findall(self.pat, text): 270 | token = "".join( 271 | self.byte_encoder[b] for b in token.encode("utf-8") 272 | ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) 273 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) 274 | return bpe_tokens 275 | 276 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id 277 | def _convert_token_to_id(self, token): 278 | """Converts a token (str) in an id using the vocab.""" 279 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 280 | 281 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token 282 | def _convert_id_to_token(self, index): 283 | """Converts an index (integer) in a token (str) using the vocab.""" 284 | return self.decoder.get(index) 285 | 286 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string 287 | def convert_tokens_to_string(self, tokens): 288 | """Converts a sequence of tokens (string) in a single string.""" 289 | 
text = "".join(tokens) 290 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) 291 | return text 292 | 293 | def decode( 294 | self, 295 | token_ids, 296 | skip_special_tokens: bool = False, 297 | clean_up_tokenization_spaces: Optional[bool] = False, 298 | spaces_between_special_tokens: bool = False, 299 | **kwargs, 300 | ) -> str: 301 | # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers 302 | # and cannot be configured elsewhere, but it should default to False for DreamTokenizer 303 | return super().decode( 304 | token_ids, 305 | skip_special_tokens=skip_special_tokens, 306 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 307 | spaces_between_special_tokens=spaces_between_special_tokens, 308 | **kwargs, 309 | ) 310 | 311 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary 312 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 313 | if not os.path.isdir(save_directory): 314 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 315 | return 316 | vocab_file = os.path.join( 317 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 318 | ) 319 | merge_file = os.path.join( 320 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] 321 | ) 322 | 323 | with open(vocab_file, "w", encoding="utf-8") as f: 324 | f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") 325 | 326 | index = 0 327 | with open(merge_file, "w", encoding="utf-8") as writer: 328 | writer.write("#version: 0.2\n") 329 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 330 | if index != token_index: 331 | logger.warning( 332 | f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." 333 | " Please check that the tokenizer is not corrupted!" 334 | ) 335 | index = token_index 336 | writer.write(" ".join(bpe_tokens) + "\n") 337 | index += 1 338 | 339 | return vocab_file, merge_file 340 | 341 | def prepare_for_tokenization(self, text, **kwargs): 342 | text = unicodedata.normalize("NFC", text) 343 | return (text, kwargs) -------------------------------------------------------------------------------- /dream/postprocess_code.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 17 | 18 | import evaluate as hf_evaluate 19 | import os 20 | import sys 21 | from sanitize import sanitize 22 | 23 | os.environ["HF_ALLOW_CODE_EVAL"] = "1" 24 | pass_at_k = hf_evaluate.load("code_eval") 25 | 26 | def pass_at_1(references, predictions): 27 | return pass_at_k.compute( 28 | references=references, 29 | predictions=predictions, 30 | k=[1], 31 | )[0]["pass@1"] 32 | 33 | import json 34 | 35 | 36 | def read_jsonl(file_path): 37 | data = [] 38 | with open(file_path, 'r') as file: 39 | for line in file: 40 | data.append(json.loads(line)) 41 | return data 42 | 43 | file_path = sys.argv[1] 44 | data = read_jsonl(file_path) 45 | 46 | references = [sample['target'] for sample in data] 47 | 48 | predictions = [[sanitize(sample['doc']['prompt'] + "\n" + sample['resps'][0][0].split('```python\n', 1)[-1].split('```')[0], 49 | sample['doc']["entry_point"])] 50 | for sample in data] 51 | 52 | pass_at_1s = [pass_at_1([reference], [prediction]) for reference, prediction in zip(references, predictions)] 53 | print(sum(pass_at_1s)/len(pass_at_1s)) 54 | 55 | def write_jsonl(data, file_path): 56 | with open(file_path, 'w') as file: 57 | for item in data: 58 | file.write(json.dumps(item) + '\n') 59 | 60 | res = [{"task_id": sample['doc']['task_id'], "completion": pred, "pass_at_1": res} 61 | for sample, pred, res in zip(data, predictions, pass_at_1s)] 62 | write_jsonl(res, file_path+'.cleaned') -------------------------------------------------------------------------------- /dream/sanitize.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 17 | """Post-processing LLM-generated Python code implemented using tree-sitter.""" 18 | 19 | import os 20 | import sys 21 | import pathlib 22 | 23 | ROOT = os.path.dirname(os.path.abspath(__file__)) 24 | sys.path.extend([os.path.dirname(ROOT), os.path.dirname(os.path.dirname(ROOT))]) 25 | 26 | import ast 27 | import traceback 28 | 29 | from typing import Dict, List, Optional, Set, Tuple 30 | 31 | def refine_text(text: str) -> str: 32 | text = text.replace("\t", " ") 33 | text = text.replace("\r\n", "\n").replace("\r", "\n") 34 | return text.strip() + "\n" 35 | 36 | def syntax_check(code, verbose = False): 37 | try: 38 | ast.parse(code) 39 | return True 40 | except (SyntaxError, MemoryError): 41 | if verbose: 42 | traceback.print_exc() 43 | return False 44 | 45 | def extract_longest_valid_code(text: str) -> str: 46 | lines = text.splitlines() 47 | 48 | if len(lines) > 100: 49 | lines = lines[:100] 50 | max_valid_lines = 0 51 | max_valid_snippet = "" 52 | 53 | for i in range(len(lines)): 54 | for j in range(i, len(lines)): 55 | current_snippet = "\n".join(lines[i:j+1]) 56 | if syntax_check(current_snippet): 57 | valid_line_count = sum(1 for line in lines[i:j+1] if line.strip()) 58 | if valid_line_count > max_valid_lines: 59 | max_valid_lines = valid_line_count 60 | max_valid_snippet = current_snippet 61 | 62 | return max_valid_snippet 63 | 64 | def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]: 65 | name2deps = {} 66 | for name, node in nodes: 67 | deps = set() 68 | stack = [node] 69 | while stack: 70 | current = stack.pop() 71 | for child in ast.iter_child_nodes(current): 72 | if isinstance(child, ast.Name): 73 | deps.add(child.id) 74 | elif isinstance(child, ast.Attribute): 75 | deps.add(child.attr) 76 | else: 77 | stack.append(child) 78 | name2deps[name] = deps 79 | return name2deps 80 | 81 | def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]: 82 | visited = set() 83 | to_visit = [entrypoint] 84 | 85 | while to_visit: 86 | current = to_visit.pop(0) 87 | if current not in visited: 88 | visited.add(current) 89 | to_visit.extend(call_graph.get(current, set()) - visited) 90 | 91 | return visited 92 | 93 | def get_definition_name(node: ast.AST) -> Optional[str]: 94 | if isinstance(node, (ast.FunctionDef, ast.ClassDef)): 95 | return node.name 96 | elif isinstance(node, ast.Assign): 97 | targets = node.targets 98 | if targets and isinstance(targets[0], ast.Name): 99 | return targets[0].id 100 | return None 101 | 102 | def has_return_statement(node: ast.AST) -> bool: 103 | return any(isinstance(n, ast.Return) for n in ast.walk(node)) 104 | 105 | def sanitize(text: str, entrypoint: Optional[str] = None) -> str: 106 | 107 | text = refine_text(text) 108 | 109 | # text = python_extract(text) 110 | 111 | code = extract_longest_valid_code(text) 112 | tree = ast.parse(code) 113 | 114 | definitions = {} 115 | 116 | imports = [] 117 | 118 | for node in tree.body: 119 | if isinstance(node, (ast.Import, ast.ImportFrom)): 120 | imports.append(node) 121 | elif isinstance(node, ast.ClassDef): 122 | name = node.name 123 | definitions[name] = ('class', node) 124 | elif isinstance(node, ast.FunctionDef): 125 | name = node.name 126 | if has_return_statement(node): 127 | definitions[name] = ('function', node) 128 | elif isinstance(node, ast.Assign): 129 | name = get_definition_name(node) 130 | if name: 131 | definitions[name] = ('variable', 
node) 132 | 133 | if entrypoint: 134 | name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()]) 135 | reachable = get_function_dependency(entrypoint, name2deps) 136 | 137 | sanitized_output = [] 138 | 139 | for node in imports: 140 | sanitized_output.append(ast.unparse(node)) 141 | 142 | for name, (_, node) in definitions.items(): 143 | if not entrypoint or name in reachable: 144 | sanitized_output.append(ast.unparse(node)) 145 | 146 | return "\n".join(sanitized_output) -------------------------------------------------------------------------------- /llada/chat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from LLaDA repos: https://github.com/ML-GSAI/LLaDA 17 | 18 | import torch 19 | import argparse 20 | 21 | from generate import generate, generate_with_prefix_cache, generate_with_dual_cache 22 | from transformers import AutoTokenizer, AutoModel 23 | from model.modeling_llada import LLaDAModelLM 24 | 25 | def chat(args): 26 | device = 'cuda' 27 | model = LLaDAModelLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval() 28 | tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True) 29 | 30 | gen_length = args.gen_length 31 | steps = args.steps 32 | print('*' * 66) 33 | print(f'** Answer Length: {gen_length} | Sampling Steps: {steps} **') 34 | print('*' * 66) 35 | 36 | conversation_num = 0 37 | while True: 38 | user_input = input("Enter your question: ") 39 | 40 | m = [{"role": "user", "content": user_input}] 41 | user_input = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) 42 | input_ids = tokenizer(user_input)['input_ids'] 43 | input_ids = torch.tensor(input_ids).to(device).unsqueeze(0) 44 | 45 | if conversation_num == 0: 46 | prompt = input_ids 47 | else: 48 | prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1) 49 | print(f'use cache: {args.use_cache} use cache position: {args.if_cache_position} threshold: {args.threshold} block size: {args.block_size}') 50 | if args.use_cache: 51 | if args.if_cache_position: 52 | out, nfe = generate_with_dual_cache(model, prompt, steps=steps, gen_length=gen_length, block_length=args.block_size, temperature=0., remasking='low_confidence', threshold=args.threshold) 53 | else: 54 | out, nfe = generate_with_prefix_cache(model, prompt, steps=steps, gen_length=gen_length, block_length=args.block_size, temperature=0., remasking='low_confidence', threshold=args.threshold) 55 | else: 56 | out, nfe = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=args.block_size, temperature=0., remasking='low_confidence', threshold=args.threshold) 57 | 58 | answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0] 59 | print(f"Bot's 
reply: {answer}") 60 | print(f"Number of forward passes: {nfe}") 61 | 62 | # remove the 63 | prompt = out[out != 126081].unsqueeze(0) 64 | conversation_num += 1 65 | print('-----------------------------------------------------------------------') 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument("--gen_length", type=int, default=128) 71 | parser.add_argument("--steps", type=int, default=128) 72 | parser.add_argument("--block_size", type=int, default=32) 73 | parser.add_argument("--use_cache", action="store_true") 74 | parser.add_argument("--if_cache_position", action="store_true") 75 | parser.add_argument("--threshold", type=float, default=None) 76 | 77 | args = parser.parse_args() 78 | chat(args) 79 | 80 | -------------------------------------------------------------------------------- /llada/eval.md: -------------------------------------------------------------------------------- 1 | # LLaDA Model Evaluation Guide 2 | 3 | This document provides detailed instructions for evaluating the LLaDA model on GSM8K math problem solving and HumanEval code generation tasks. 4 | 5 | ## Environment Setup 6 | 7 | Before running any evaluation, set the following environment variables: 8 | ```bash 9 | export HF_ALLOW_CODE_EVAL=1 10 | export HF_DATASETS_TRUST_REMOTE_CODE=true 11 | ``` 12 | 13 | ## GSM8K Evaluation 14 | 15 | GSM8K is a dataset of 8,000 grade school math problems designed to evaluate mathematical reasoning capabilities. 16 | 17 | ### Common Parameters 18 | 19 | ```bash 20 | task=gsm8k 21 | length=256 22 | block_length=32 23 | num_fewshot=5 24 | steps=$((length / block_length)) 25 | ``` 26 | 27 | ### Evaluation Methods 28 | 29 | 1. **Baseline** 30 | ```bash 31 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \ 32 | --confirm_run_unsafe_code --model llada_dist \ 33 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${length},block_length=${block_length},show_speed=True 34 | ``` 35 | 36 | 2. **Prefix Cache** 37 | ```bash 38 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \ 39 | --confirm_run_unsafe_code --model llada_dist \ 40 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${length},block_length=${block_length},use_cache=True,show_speed=True 41 | ``` 42 | 43 | 3. **Parallel Generation** 44 | ```bash 45 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \ 46 | --confirm_run_unsafe_code --model llada_dist \ 47 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},threshold=0.9,show_speed=True 48 | ``` 49 | 50 | 4. **Prefix Cache + Parallel** 51 | ```bash 52 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \ 53 | --confirm_run_unsafe_code --model llada_dist \ 54 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},use_cache=True,threshold=0.9,show_speed=True 55 | ``` 56 | 57 | 5. 
**Dual Cache + Parallel** 58 | ```bash 59 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \ 60 | --confirm_run_unsafe_code --model llada_dist \ 61 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},use_cache=True,dual_cache=True,threshold=0.9,show_speed=True 62 | ``` 63 | 64 | ### Parameter Descriptions 65 | 66 | - `task`: Evaluation task (gsm8k) 67 | - `length`: Generation length 68 | - `block_length`: Block size for parallel generation 69 | - `num_fewshot`: Number of few-shot examples 70 | - `steps`: Number of generation steps 71 | - `use_cache`: Enable prefix cache 72 | - `dual_cache`: Enable dual cache 73 | - `threshold`: Confidence threshold for parallel generation 74 | - `show_speed`: Display speed metrics 75 | 76 | ## HumanEval Evaluation 77 | 78 | HumanEval is a dataset of 164 Python programming problems designed to evaluate code generation capabilities. 79 | 80 | ### Common Parameters 81 | 82 | ```bash 83 | task=humaneval 84 | length=256 85 | block_length=32 86 | steps=$((length / block_length)) 87 | ``` 88 | 89 | ### Evaluation Methods 90 | 91 | 1. **Baseline** 92 | ```bash 93 | accelerate launch eval_llada.py --tasks ${task} \ 94 | --confirm_run_unsafe_code --model llada_dist \ 95 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${length},block_length=${block_length},show_speed=True \ 96 | --output_path evals_results/baseline/humaneval-ns0-${length} --log_samples 97 | ``` 98 | 99 | 2. **Prefix Cache** 100 | ```bash 101 | accelerate launch eval_llada.py --tasks ${task} \ 102 | --confirm_run_unsafe_code --model llada_dist \ 103 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${length},block_length=${block_length},use_cache=True,show_speed=True \ 104 | --output_path evals_results/prefix_cache/humaneval-ns0-${length} --log_samples 105 | ``` 106 | 107 | 3. **Parallel Generation** 108 | ```bash 109 | accelerate launch eval_llada.py --tasks ${task} \ 110 | --confirm_run_unsafe_code --model llada_dist \ 111 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},threshold=0.9,show_speed=True \ 112 | --output_path evals_results/parallel/humaneval-ns0-${length} --log_samples 113 | ``` 114 | 115 | 4. **Prefix Cache + Parallel** 116 | ```bash 117 | accelerate launch eval_llada.py --tasks ${task} \ 118 | --confirm_run_unsafe_code --model llada_dist \ 119 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},use_cache=True,threshold=0.9,show_speed=True \ 120 | --output_path evals_results/cache_parallel/humaneval-ns0-${length} --log_samples 121 | ``` 122 | 123 | 5. **Dual Cache + Parallel** 124 | ```bash 125 | accelerate launch eval_llada.py --tasks ${task} \ 126 | --confirm_run_unsafe_code --model llada_dist \ 127 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},use_cache=True,dual_cache=True,threshold=0.9,show_speed=True \ 128 | --output_path evals_results/dual_cache_parallel/humaneval-ns0-${length} --log_samples 129 | ``` 130 | 131 | ### Post-processing 132 | 133 | For HumanEval evaluation, post-processing is required: 134 | ```bash 135 | python postprocess_code.py {the samples_xxx.jsonl file under output_path} 136 | ``` 137 | 138 | ## Notes 139 | 140 | 1. All evaluations use the LLaDA-8B-Instruct model 141 | 2. 
Results are saved in the `evals_results` directory
142 | 3. For HumanEval, samples are logged for post-processing
143 | 4. Speed metrics are shown for all evaluations
144 | 5. Different optimization strategies can be combined, as in methods 4 (Prefix Cache + Parallel) and 5 (Dual Cache + Parallel) above.
--------------------------------------------------------------------------------
/llada/eval_gsm8k.sh:
--------------------------------------------------------------------------------
1 | # Set the environment variables first before running the command.
2 | export HF_ALLOW_CODE_EVAL=1
3 | export HF_DATASETS_TRUST_REMOTE_CODE=true
4 | 
5 | task=gsm8k
6 | length=256
7 | block_length=32
8 | num_fewshot=5
9 | steps=$((length / block_length))
10 | 
11 | # baseline
12 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \
13 | --confirm_run_unsafe_code --model llada_dist \
14 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${length},block_length=${block_length},show_speed=True
15 | 
16 | 
17 | # prefix cache
18 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \
19 | --confirm_run_unsafe_code --model llada_dist \
20 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${length},block_length=${block_length},use_cache=True,show_speed=True
21 | 
22 | 
23 | # parallel
24 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \
25 | --confirm_run_unsafe_code --model llada_dist \
26 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},threshold=0.9,show_speed=True
27 | 
28 | 
29 | # prefix cache+parallel
30 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \
31 | --confirm_run_unsafe_code --model llada_dist \
32 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},use_cache=True,threshold=0.9,show_speed=True
33 | 
34 | # dual cache+parallel
35 | accelerate launch eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot} \
36 | --confirm_run_unsafe_code --model llada_dist \
37 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},use_cache=True,dual_cache=True,threshold=0.9,show_speed=True
38 | 
--------------------------------------------------------------------------------
/llada/eval_humaneval.sh:
--------------------------------------------------------------------------------
1 | # Set the environment variables first before running the command.
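# (Clarifying note mirroring llada/eval.md: the baseline and prefix-cache-only runs below pass
#  steps=${length}, so each forward pass commits a single token, while the threshold=0.9 "parallel"
#  runs pass steps=$((length / block_length)) and instead keep decoding each block until it is fully
#  unmasked, accepting in every step each token whose confidence clears the threshold, and at least one.)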
2 | export HF_ALLOW_CODE_EVAL=1 3 | export HF_DATASETS_TRUST_REMOTE_CODE=true 4 | 5 | task=humaneval 6 | length=256 7 | block_length=32 8 | steps=$((length / block_length)) 9 | 10 | # baseline 11 | accelerate launch eval_llada.py --tasks ${task} \ 12 | --confirm_run_unsafe_code --model llada_dist \ 13 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${length},block_length=${block_length},show_speed=True \ 14 | --output_path evals_results/baseline/humaneval-ns0-${length} --log_samples 15 | 16 | # prefix cache 17 | accelerate launch eval_llada.py --tasks ${task} \ 18 | --confirm_run_unsafe_code --model llada_dist \ 19 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${length},block_length=${block_length},use_cache=True,show_speed=True \ 20 | --output_path evals_results/prefix_cache/humaneval-ns0-${length} --log_samples 21 | 22 | # parallel 23 | accelerate launch eval_llada.py --tasks ${task} \ 24 | --confirm_run_unsafe_code --model llada_dist \ 25 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},threshold=0.9,show_speed=True \ 26 | --output_path evals_results/parallel/humaneval-ns0-${length} --log_samples 27 | 28 | # prefix cache+parallel 29 | accelerate launch eval_llada.py --tasks ${task} \ 30 | --confirm_run_unsafe_code --model llada_dist \ 31 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},use_cache=True,threshold=0.9,show_speed=True \ 32 | --output_path evals_results/cache_parallel/humaneval-ns0-${length} --log_samples 33 | 34 | # dual cache+parallel 35 | accelerate launch eval_llada.py --tasks ${task} \ 36 | --confirm_run_unsafe_code --model llada_dist \ 37 | --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=${length},steps=${steps},block_length=${block_length},use_cache=True,dual_cache=True,threshold=0.9,show_speed=True \ 38 | --output_path evals_results/dual_cache_parallel/humaneval-ns0-${length} --log_samples 39 | 40 | ## NOTICE: use postprocess for humaneval 41 | python postprocess_code.py {the samples_xxx.jsonl file under output_path} 42 | -------------------------------------------------------------------------------- /llada/eval_llada.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from LLaDA repos: https://github.com/ML-GSAI/LLaDA 17 | 18 | ''' 19 | This file is inspired by the code from https://github.com/ML-GSAI/SMDM 20 | ''' 21 | import accelerate 22 | import torch 23 | import re 24 | from pathlib import Path 25 | import random 26 | import numpy as np 27 | import torch.nn.functional as F 28 | from datasets import Dataset 29 | from lm_eval.__main__ import cli_evaluate 30 | from lm_eval.api.instance import Instance 31 | from lm_eval.api.model import LM 32 | from lm_eval.api.registry import register_model 33 | from tqdm import tqdm 34 | import os 35 | from transformers import AutoTokenizer, AutoModel, AutoConfig 36 | from generate import generate, generate_with_prefix_cache, generate_with_dual_cache 37 | from model.modeling_llada import LLaDAModelLM 38 | import json 39 | import time 40 | def set_seed(seed): 41 | torch.manual_seed(seed) 42 | random.seed(seed) 43 | np.random.seed(seed) 44 | 45 | torch.backends.cudnn.deterministic = True 46 | torch.backends.cudnn.benchmark = False 47 | 48 | 49 | @register_model("llada_dist") 50 | class LLaDAEvalHarness(LM): 51 | def __init__( 52 | self, 53 | model_path='', 54 | mask_id=126336, 55 | max_length=4096, 56 | batch_size=32, 57 | mc_num=128, 58 | is_check_greedy=True, 59 | steps=1024, 60 | gen_length=1024, 61 | block_length=1024, 62 | remasking='low_confidence', 63 | device="cuda", 64 | use_cache=False, 65 | threshold=None, 66 | save_dir=None, 67 | show_speed=False, 68 | dual_cache=False, 69 | **kwargs, 70 | ): 71 | ''' 72 | Args: 73 | model_path: LLaDA-8B-Base model path. 74 | mask_id: The token id of [MASK] is 126336. 75 | max_length: the max sequence length. 76 | batch_size: mini batch size. 77 | mc_num: Monte Carlo estimation iterations 78 | is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer 79 | is generated through greedy sampling conditioned on the prompt (note that this differs from conditional 80 | generation). We implement this verification through the suffix_greedy_prediction() function, which 81 | returns a True/False judgment used for accuracy calculation. 82 | When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function. 83 | However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality, 84 | we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False 85 | by default, significantly accelerating the evaluation process. 86 | cfg_scale: Unsupervised classifier-free guidance scale. 
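        use_cache / dual_cache / threshold / show_speed / save_dir: options added in this repo (see
            generate.py and generate_until below): enable the prefix or dual KV cache, set the
            confidence threshold for parallel decoding, print token/NFE throughput, and incrementally
            save per-rank generations.

    Example (an illustrative sketch, not part of the original docstring): lm-eval parses the
    `--model_args` string used in eval_gsm8k.sh / eval_humaneval.sh into keyword arguments for this
    constructor, which roughly corresponds to:

    ```python
    >>> lm = LLaDAEvalHarness(model_path='GSAI-ML/LLaDA-8B-Instruct', gen_length=256, steps=8,
    ...                       block_length=32, use_cache=True, dual_cache=True, threshold=0.9,
    ...                       show_speed=True)
    ```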
87 | ''' 88 | super().__init__() 89 | 90 | accelerator = accelerate.Accelerator() 91 | if accelerator.num_processes > 1: 92 | self.accelerator = accelerator 93 | else: 94 | self.accelerator = None 95 | 96 | model_kwargs = {} 97 | if self.accelerator is not None: 98 | model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}}) 99 | config = AutoConfig.from_pretrained(model_path) 100 | config.flash_attention = True 101 | self.model = LLaDAModelLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, config=config, **model_kwargs) 102 | self.model.eval() 103 | 104 | self.device = torch.device(device) 105 | if self.accelerator is not None: 106 | self.model = self.accelerator.prepare(self.model) 107 | self.device = torch.device(f'{self.accelerator.device}') 108 | self._rank = self.accelerator.local_process_index 109 | self._world_size = self.accelerator.num_processes 110 | else: 111 | self.model = self.model.to(device) 112 | 113 | self.mask_id = mask_id 114 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 115 | 116 | self.mc_num = mc_num 117 | self.batch_size = int(batch_size) 118 | assert mc_num % self.batch_size == 0 119 | self.sampling_eps = 0. 120 | self.max_length = max_length 121 | self.is_check_greedy = is_check_greedy 122 | 123 | self.steps = steps 124 | self.gen_length = gen_length 125 | self.block_length = block_length 126 | self.remasking = remasking 127 | self.use_cache = use_cache 128 | self.threshold = threshold 129 | self.is_instruct = True if 'instruct' in model_path.lower() else False 130 | self.save_dir = save_dir 131 | self.show_speed = show_speed 132 | self.dual_cache = dual_cache 133 | @property 134 | def rank(self): 135 | return self._rank 136 | 137 | @property 138 | def world_size(self): 139 | return self._world_size 140 | 141 | def _forward_process(self, batch, prompt_index): 142 | b, l = batch.shape 143 | 144 | target_len = (l - prompt_index.sum()).item() 145 | k = torch.randint(1, target_len + 1, (), device=batch.device) 146 | 147 | x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long() 148 | x = ((x - 1) % target_len) + 1 149 | assert x.min() >= 1 and x.max() <= target_len 150 | 151 | indices = torch.arange(target_len, device=batch.device).repeat(b, 1) 152 | is_mask = indices < x.unsqueeze(1) 153 | 154 | for i in range(b): 155 | is_mask[i] = is_mask[i][torch.randperm(target_len)] 156 | 157 | is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1) 158 | 159 | noisy_batch = torch.where(is_mask, self.mask_id, batch) 160 | 161 | return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l) 162 | 163 | @torch.no_grad() 164 | def get_logits(self, batch, prompt_index): 165 | if self.cfg > 0.: 166 | assert len(prompt_index) == batch.shape[1] 167 | prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1) 168 | un_batch = batch.clone() 169 | un_batch[prompt_index] = self.mask_id 170 | batch = torch.cat([batch, un_batch]) 171 | 172 | logits = self.model(batch).logits 173 | 174 | if self.cfg > 0.: 175 | logits, un_logits = torch.chunk(logits, 2, dim=0) 176 | logits = un_logits + (self.cfg + 1) * (logits - un_logits) 177 | return logits[:, :batch.shape[1]] 178 | 179 | @torch.no_grad() 180 | def get_loglikelihood(self, prefix, target): 181 | seq = torch.concatenate([prefix, target])[None, :] 182 | seq = seq.repeat((self.batch_size, 1)).to(self.device) 183 | 184 | prompt_index = 
torch.arange(seq.shape[1], device=self.device) < len(prefix) 185 | 186 | loss_acc = [] 187 | for _ in range(self.mc_num // self.batch_size): 188 | perturbed_seq, p_mask = self._forward_process(seq, prompt_index) 189 | 190 | mask_indices = perturbed_seq == self.mask_id 191 | 192 | logits = self.get_logits(perturbed_seq, prompt_index) 193 | 194 | loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices] 195 | loss = loss.sum() / self.batch_size 196 | loss_acc.append(loss.item()) 197 | 198 | return - sum(loss_acc) / len(loss_acc) 199 | 200 | @torch.no_grad() 201 | def suffix_greedy_prediction(self, prefix, target): 202 | if not self.is_check_greedy: 203 | return False 204 | 205 | seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device) 206 | prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) 207 | prefix, target = prefix.to(self.device), target.to(self.device) 208 | seq[0, :len(prefix)] = prefix 209 | 210 | for i in range(len(target)): 211 | mask_index = (seq == self.mask_id) 212 | logits = self.get_logits(seq, prompt_index)[mask_index] 213 | x0 = torch.argmax(logits, dim=-1) 214 | 215 | p = torch.softmax(logits.to(torch.float32), dim=-1) 216 | confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1) 217 | _, index = torch.sort(confidence, descending=True) 218 | x0[index[1:]] = self.mask_id 219 | seq[mask_index] = x0.clone() 220 | correct = target == seq[0, len(prefix):] 221 | correct = torch.all(correct) 222 | return correct 223 | 224 | def _encode_pair(self, context, continuation): 225 | n_spaces = len(context) - len(context.rstrip()) 226 | if n_spaces > 0: 227 | continuation = context[-n_spaces:] + continuation 228 | context = context[:-n_spaces] 229 | 230 | whole_enc = self.tokenizer(context + continuation)["input_ids"] 231 | context_enc = self.tokenizer(context)["input_ids"] 232 | 233 | context_enc_len = len(context_enc) 234 | continuation_enc = whole_enc[context_enc_len:] 235 | 236 | return context_enc, continuation_enc 237 | 238 | def loglikelihood(self, requests): 239 | def _tokenize(e): 240 | prefix, target = self._encode_pair(e["prefix"], e["target"]) 241 | return { 242 | "prefix_text": e["prefix"], 243 | "target_text": e["target"], 244 | "prefix": prefix, 245 | "target": target, 246 | } 247 | 248 | ds = [] 249 | ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests] 250 | ds = Dataset.from_list(ds) 251 | ds = ds.map(_tokenize) 252 | ds = ds.with_format("torch") 253 | prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds] 254 | 255 | assert max(prompt_len) <= 4096 256 | 257 | out = [] 258 | with torch.no_grad(): 259 | for elem in tqdm(ds, desc="Computing likelihood..."): 260 | prefix = elem["prefix"] 261 | target = elem["target"] 262 | 263 | ll = self.get_loglikelihood(prefix, target) 264 | 265 | is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target) 266 | 267 | out.append((ll, 1.0 if is_target_greedy_dec else 0.0)) 268 | torch.cuda.empty_cache() 269 | return out 270 | 271 | def loglikelihood_rolling(self, requests): 272 | raise NotImplementedError 273 | 274 | 275 | def generate_until(self, requests): 276 | output = [] 277 | num_tokens = 0 278 | num_nfe = 0 279 | processed_count = 0 280 | if self.save_dir is not None: 281 | os.makedirs(self.save_dir, exist_ok=True) 282 | rank = self.rank 283 | save_path = os.path.join(self.save_dir, f'rank_{rank}.jsonl') 284 | print(f"save_path: {save_path}") 285 | if 
os.path.exists(save_path): 286 | print(f"load from {save_path}") 287 | with open(save_path, 'r', encoding='utf-8') as f: 288 | output = [json.loads(line) for line in f] 289 | processed_count = len(output) 290 | print(f"processed_count: {processed_count}") 291 | start_time = time.time() 292 | for i, req in enumerate(tqdm(requests, desc="Generating...")): 293 | if i < processed_count: 294 | continue 295 | 296 | question = req.args[0] 297 | if self.is_instruct: 298 | m = [{"role": "user", "content": question}] 299 | user_input = self.tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) 300 | input_ids = self.tokenizer(user_input)['input_ids'] 301 | else: 302 | user_input = question 303 | input_ids = self.tokenizer(user_input)['input_ids'] 304 | 305 | stop_tokens = req.args[1]['until'] 306 | input_ids = torch.tensor(input_ids).to(self.device).unsqueeze(0) 307 | if self.use_cache: 308 | if self.dual_cache: 309 | generated_answer, nfe = generate_with_dual_cache(self.model, input_ids, steps=self.steps, gen_length=self.gen_length, block_length=self.block_length, 310 | temperature=0, remasking=self.remasking, mask_id=self.mask_id, threshold=self.threshold) 311 | else: 312 | generated_answer, nfe = generate_with_prefix_cache(self.model, input_ids, steps=self.steps, gen_length=self.gen_length, block_length=self.block_length, 313 | temperature=0, remasking=self.remasking, mask_id=self.mask_id, threshold=self.threshold) 314 | else: 315 | generated_answer, nfe = generate(self.model, input_ids, steps=self.steps, gen_length=self.gen_length, block_length=self.block_length, 316 | temperature=0, remasking=self.remasking, mask_id=self.mask_id, threshold=self.threshold) 317 | 318 | if self.is_instruct and 'task_id' in req.doc and str(req.doc['task_id']).lower().startswith('humaneval'): 319 | if self.show_speed: 320 | num_tokens += (generated_answer != 126081).sum() 321 | num_nfe += nfe 322 | generated_answer = self.tokenizer.decode(generated_answer[0][input_ids.shape[1]:], skip_special_tokens=True) 323 | else: 324 | generated_answer = self.tokenizer.decode(generated_answer[0][input_ids.shape[1]:], skip_special_tokens=False) 325 | for stop_seq in stop_tokens: 326 | if stop_seq in generated_answer: 327 | generated_answer = generated_answer.split(stop_seq)[0] 328 | 329 | # remove special tokens 330 | generated_answer_ids = torch.tensor(self.tokenizer(generated_answer)["input_ids"]) 331 | if self.show_speed: 332 | num_tokens += (generated_answer_ids != 126081).sum() 333 | num_nfe += nfe 334 | generated_answer = self.tokenizer.decode(generated_answer_ids, skip_special_tokens=True) 335 | output.append(generated_answer) 336 | 337 | if self.save_dir is not None: 338 | # 增量保存新生成的答案 339 | with open(save_path, 'a', encoding='utf-8') as f: 340 | f.write(json.dumps(generated_answer, ensure_ascii=False) + '\n') 341 | 342 | print('=' * 20) 343 | print('question: ', question) 344 | print('answer: ', generated_answer) 345 | print('=' * 20, end='\n\n') 346 | # self.accelerator.wait_for_everyone() 347 | end_time = time.time() 348 | if self.show_speed: 349 | print(f"Total number of tokens generated: {num_tokens}") 350 | print(f"Total time taken: {end_time - start_time} seconds") 351 | print(f"Tokens per second: {num_tokens / (end_time - start_time)}") 352 | print(f"Total NFE is {num_nfe}") 353 | return output 354 | 355 | 356 | if __name__ == "__main__": 357 | set_seed(1234) 358 | cli_evaluate() 359 | -------------------------------------------------------------------------------- /llada/generate.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from LLaDA repos: https://github.com/ML-GSAI/LLaDA 17 | 18 | import torch 19 | import numpy as np 20 | import torch.nn.functional as F 21 | import os 22 | from transformers import AutoTokenizer, AutoModel 23 | from model.modeling_llada import LLaDAModelLM 24 | 25 | def add_gumbel_noise(logits, temperature): 26 | ''' 27 | The Gumbel max is a method for sampling categorical distributions. 28 | According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality. 29 | Thus, we use float64. 30 | ''' 31 | if temperature == 0: 32 | return logits 33 | logits = logits.to(torch.float64) 34 | noise = torch.rand_like(logits, dtype=torch.float64) 35 | gumbel_noise = (- torch.log(noise)) ** temperature 36 | return logits.exp() / gumbel_noise 37 | 38 | 39 | def get_num_transfer_tokens(mask_index, steps): 40 | ''' 41 | In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals. 42 | Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)), 43 | the expected number of tokens transitioned at each step should be consistent. 44 | 45 | This function is designed to precompute the number of tokens that need to be transitioned at each step. 46 | ''' 47 | mask_num = mask_index.sum(dim=1, keepdim=True) 48 | 49 | base = mask_num // steps 50 | remainder = mask_num % steps 51 | 52 | num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base 53 | 54 | for i in range(mask_num.size(0)): 55 | num_transfer_tokens[i, :remainder[i]] += 1 56 | 57 | return num_transfer_tokens 58 | 59 | 60 | @ torch.no_grad() 61 | def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., 62 | remasking='low_confidence', mask_id=126336, threshold=None): 63 | ''' 64 | Args: 65 | model: Mask predictor. 66 | prompt: A tensor of shape (1, L). 67 | steps: Sampling steps, less than or equal to gen_length. 68 | gen_length: Generated answer length. 69 | block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. 70 | temperature: Categorical distribution sampling temperature. 71 | cfg_scale: Unsupervised classifier-free guidance scale. 72 | remasking: Remasking strategy. 'low_confidence' or 'random'. 73 | mask_id: The toke id of [MASK] is 126336. 
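        threshold: If not None, each step accepts every masked token in the current block whose
            confidence exceeds this value (and always at least the single most confident one);
            if None, the fixed per-step schedule from get_num_transfer_tokens is used.

    Returns:
        A tuple `(x, nfe)`, where `x` is the full sequence (prompt plus generated tokens) and `nfe`
        is the number of model forward passes; this is how chat.py and eval_llada.py consume it.

    Example (a minimal usage sketch added for clarity; it assumes `model`, `tokenizer` and a
    `(1, L)` prompt tensor `input_ids` are prepared as in `main()` at the bottom of this file):

    ```python
    >>> out, nfe = generate(model, input_ids, steps=128, gen_length=128,
    ...                     block_length=32, temperature=0., remasking='low_confidence')
    >>> print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
    ```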
74 | ''' 75 | x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) 76 | x[:, :prompt.shape[1]] = prompt.clone() 77 | 78 | assert gen_length % block_length == 0 79 | num_blocks = gen_length // block_length 80 | 81 | assert steps % num_blocks == 0 82 | steps = steps // num_blocks 83 | 84 | nfe = 0 85 | for num_block in range(num_blocks): 86 | block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length] == mask_id) 87 | num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) 88 | i = 0 89 | while True: 90 | nfe += 1 91 | mask_index = (x == mask_id) 92 | logits = model(x).logits 93 | mask_index[:, prompt.shape[1] + (num_block + 1) * block_length:] = 0 94 | x0, transfer_index = get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens[:, i] if threshold is None else None, threshold) 95 | x[transfer_index] = x0[transfer_index] 96 | i += 1 97 | if (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length] == mask_id).sum() == 0: 98 | break 99 | return x, nfe 100 | 101 | 102 | 103 | @ torch.no_grad() 104 | def generate_with_prefix_cache(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., 105 | remasking='low_confidence', mask_id=126336, threshold=None): 106 | ''' 107 | Args: 108 | model: Mask predictor. 109 | prompt: A tensor of shape (1, L). 110 | steps: Sampling steps, less than or equal to gen_length. 111 | gen_length: Generated answer length. 112 | block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. 113 | temperature: Categorical distribution sampling temperature. 114 | cfg_scale: Unsupervised classifier-free guidance scale. 115 | remasking: Remasking strategy. 'low_confidence' or 'random'. 116 | mask_id: The toke id of [MASK] is 126336. 
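        threshold: Same meaning as in `generate`.

    Returns:
        `(x, nfe)` with the same meaning as in `generate`.

    Note (added for clarity; see the loop below): after the first full forward pass, the key/value
    cache is truncated to the tokens before the current block, so every later step only re-encodes
    the suffix `x[:, current_block_start:]` against that cached prefix.

    Example (a minimal sketch mirroring the prefix cache + parallel setting in eval_gsm8k.sh; same
    `model`/`input_ids` setup as for `generate`):

    ```python
    >>> out, nfe = generate_with_prefix_cache(model, input_ids, steps=8, gen_length=256,
    ...                                       block_length=32, temperature=0., threshold=0.9)
    ```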
117 | ''' 118 | x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) 119 | x[:, :prompt.shape[1]] = prompt.clone() 120 | 121 | assert gen_length % block_length == 0 122 | num_blocks = gen_length // block_length 123 | 124 | assert steps % num_blocks == 0 125 | steps = steps // num_blocks 126 | 127 | nfe = 0 128 | 129 | for num_block in range(num_blocks): 130 | current_block_start = prompt.shape[1] + num_block * block_length 131 | current_block_end = current_block_start + block_length 132 | 133 | block_mask_index = (x[:, current_block_start:current_block_end] == mask_id) 134 | num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) 135 | 136 | output = model(x, use_cache=True) 137 | past_key_values = output.past_key_values 138 | 139 | mask_index = (x == mask_id) 140 | mask_index[:, current_block_end:] = 0 141 | x0, transfer_index = get_transfer_index(output.logits, temperature, remasking, mask_index, x, num_transfer_tokens[:, 0] if threshold is None else None, threshold) 142 | x[transfer_index] = x0[transfer_index] 143 | 144 | new_past_key_values = [] 145 | for i in range(len(past_key_values)): 146 | new_past_key_values.append(()) 147 | for j in range(len(past_key_values[i])): 148 | new_past_key_values[i] += (past_key_values[i][j][:, :, :current_block_start],) 149 | 150 | past_key_values = new_past_key_values 151 | nfe += 1 152 | 153 | i = 1 154 | while True: 155 | nfe += 1 156 | mask_index = (x[:, current_block_start:] == mask_id) 157 | mask_index[:, block_length:] = 0 158 | 159 | logits = model(x[:, current_block_start:], past_key_values=past_key_values, use_cache=True).logits 160 | 161 | logits_with_noise = add_gumbel_noise(logits, temperature=temperature) 162 | x0 = torch.argmax(logits_with_noise, dim=-1) # b, l 163 | 164 | x0, transfer_index = get_transfer_index(logits, temperature, remasking, mask_index, 165 | x[:, current_block_start:], num_transfer_tokens[:, i] if threshold is None else None, threshold) 166 | x[:, current_block_start:][transfer_index] = x0[transfer_index] 167 | if (x[:, current_block_start:current_block_end] == mask_id).sum() == 0: 168 | break 169 | i += 1 170 | 171 | 172 | return x, nfe 173 | 174 | 175 | @ torch.no_grad() 176 | def generate_with_dual_cache(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0., 177 | remasking='low_confidence', mask_id=126336, threshold=None): 178 | ''' 179 | Args: 180 | model: Mask predictor. 181 | prompt: A tensor of shape (1, L). 182 | steps: Sampling steps, less than or equal to gen_length. 183 | gen_length: Generated answer length. 184 | block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking. 185 | temperature: Categorical distribution sampling temperature. 186 | cfg_scale: Unsupervised classifier-free guidance scale. 187 | remasking: Remasking strategy. 'low_confidence' or 'random'. 188 | mask_id: The toke id of [MASK] is 126336. 
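        threshold: Same meaning as in `generate`.

    Returns:
        `(x, nfe)` with the same meaning as in `generate`.

    Note (added for clarity; see the loop below): unlike the prefix-cache variant, the cache is
    built from a full-sequence forward pass at the start of each block, after which every step feeds
    only the current block slice through the model together with `replace_position`, a boolean mask
    marking where that block sits in the cached sequence.

    Example (a minimal usage sketch; `model`, `tokenizer` and `input_ids` prepared as in `main()`
    below, which uses exactly these settings):

    ```python
    >>> out, nfe = generate_with_dual_cache(model, input_ids, steps=128, gen_length=128,
    ...                                     block_length=32, temperature=0., remasking='low_confidence')
    >>> print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])
    ```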
189 | ''' 190 | x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device) 191 | x[:, :prompt.shape[1]] = prompt.clone() 192 | 193 | assert gen_length % block_length == 0 194 | num_blocks = gen_length // block_length 195 | 196 | assert steps % num_blocks == 0 197 | steps = steps // num_blocks 198 | 199 | nfe = 0 200 | for num_block in range(num_blocks): 201 | current_block_start = prompt.shape[1] + num_block * block_length 202 | current_block_end = current_block_start + block_length 203 | 204 | block_mask_index = (x[:, current_block_start:current_block_end] == mask_id) 205 | num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps) 206 | 207 | # cache init and update 208 | output = model(x, use_cache=True) 209 | past_key_values = output.past_key_values 210 | mask_index = (x == mask_id) 211 | mask_index[:, current_block_end:] = 0 212 | x0, transfer_index = get_transfer_index(output.logits, temperature, remasking, mask_index, x, num_transfer_tokens[:, 0] if threshold is None else None, threshold) 213 | x[transfer_index] = x0[transfer_index] 214 | nfe += 1 215 | 216 | i = 1 217 | replace_position = torch.zeros_like(x, dtype=torch.bool) 218 | replace_position[:, current_block_start:current_block_end] = 1 219 | while True: 220 | nfe += 1 221 | mask_index = (x[:, current_block_start:current_block_end] == mask_id) 222 | # cache position is the position between current_block_start and current_block_end 223 | logits = model(x[:, current_block_start:current_block_end], past_key_values=past_key_values, use_cache=True, replace_position=replace_position).logits 224 | 225 | x0, transfer_index = get_transfer_index(logits, temperature, remasking, mask_index, 226 | x[:, current_block_start:current_block_end], num_transfer_tokens[:, i] if threshold is None else None, threshold) 227 | x[:, current_block_start:current_block_end][transfer_index] = x0[transfer_index] 228 | if (x[:, current_block_start:current_block_end] == mask_id).sum() == 0: 229 | break 230 | i += 1 231 | 232 | return x, nfe 233 | 234 | 235 | def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None): 236 | logits_with_noise = add_gumbel_noise(logits, temperature=temperature) 237 | x0 = torch.argmax(logits_with_noise, dim=-1) # b, l 238 | 239 | if remasking == 'low_confidence': 240 | p = F.softmax(logits.to(torch.float64), dim=-1) 241 | x0_p = torch.squeeze( 242 | torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l 243 | elif remasking == 'random': 244 | x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) 245 | else: 246 | raise NotImplementedError(remasking) 247 | 248 | x0 = torch.where(mask_index, x0, x) 249 | confidence = torch.where(mask_index, x0_p, -np.inf) 250 | 251 | transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) 252 | if threshold is not None: 253 | num_transfer_tokens = mask_index.sum(dim=1, keepdim=True) 254 | for j in range(confidence.shape[0]): 255 | _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j]) 256 | transfer_index[j, select_index] = True 257 | if threshold is not None: 258 | for k in range(1, num_transfer_tokens[j]): 259 | if confidence[j, select_index[k]] < threshold: 260 | transfer_index[j, select_index[k]] = False 261 | return x0, transfer_index 262 | 263 | def main(): 264 | device = 'cuda' 265 | 266 | model = LLaDAModelLM.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval() 267 | tokenizer = 
AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True) 268 | 269 | prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?" 270 | 271 | # Add special tokens for the Instruct model. The Base model does not require the following two lines. 272 | m = [{"role": "user", "content": prompt}, ] 273 | prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False) 274 | 275 | input_ids = tokenizer(prompt)['input_ids'] 276 | input_ids = torch.tensor(input_ids).to(device).unsqueeze(0) 277 | 278 | out = generate_with_dual_cache(model, input_ids, steps=128, gen_length=128, block_length=32, temperature=0., remasking='low_confidence') 279 | print(tokenizer.batch_decode(out[0][:, input_ids.shape[1]:], skip_special_tokens=True)[0]) 280 | 281 | if __name__ == '__main__': 282 | main() 283 | -------------------------------------------------------------------------------- /llada/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from LLaDA repos: https://github.com/ML-GSAI/LLaDA 17 | 18 | from .configuration_llada import LLaDAConfig 19 | from .modeling_llada import LLaDAModelLM 20 | 21 | __all__ = ['LLaDAConfig', 'LLaDAModelLM'] -------------------------------------------------------------------------------- /llada/model/configuration_llada.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | #
15 | # SPDX-License-Identifier: Apache-2.0
16 | # Modified from LLaDA repos: https://github.com/ML-GSAI/LLaDA
17 |
18 | """
19 | LLaDA configuration
20 | """
21 | from transformers import AutoConfig, PretrainedConfig
22 |
23 | from enum import Enum
24 | from os import PathLike
25 | from typing import Union
26 | from dataclasses import asdict, dataclass, field
27 | from glob import glob
28 | from pathlib import Path
29 | from typing import (
30 |     Any,
31 |     Dict,
32 |     Iterable,
33 |     List,
34 |     Optional,
35 |     Tuple,
36 |     Type,
37 |     TypeVar,
38 |     Union,
39 |     cast,
40 | )
41 |
42 |
43 | __all__ = [
44 |     "ActivationType",
45 |     "ActivationCheckpointingStrategy",
46 |     "BlockType",
47 |     "LayerNormType",
48 |     "InitFnType",
49 |     "ModelConfig",
50 | ]
51 |
52 | PathOrStr = Union[str, PathLike]
53 |
54 |
55 | class StrEnum(str, Enum):
56 |     """
57 |     This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
58 |     We include this here for compatibility with older versions of Python.
59 |     """
60 |
61 |     def __str__(self) -> str:
62 |         return self.value
63 |
64 |     def __repr__(self) -> str:
65 |         return f"'{str(self)}'"
66 |
67 |
68 | class LayerNormType(StrEnum):
69 |     default = "default"
70 |     """
71 |     The default LayerNorm implementation, equivalent to PyTorch's built-in version.
72 |     """
73 |
74 |     low_precision = "low_precision"
75 |     """
76 |     A low-precision version of the default LayerNorm.
77 |     """
78 |
79 |     rms = "rms"
80 |     """
81 |     An RMSNorm implementation. When using ``torch.compile`` this is
82 |     probably the fastest implementation.
83 |     """
84 |
85 |     gemma_rms = "gemma_rms"
86 |     """
87 |     An RMSNorm implementation by Gemma. When using ``torch.compile`` this is
88 |     probably the fastest implementation.
89 |     """
90 |
91 |     amd_compatible = "amd_compatible"
92 |     """
93 |     LayerNorm implemented manually to work around an issue with ROCm.
94 |     """
95 |
96 |
97 | class ActivationType(StrEnum):
98 |     gelu = "gelu"
99 |     relu = "relu"
100 |     silu = "silu"
101 |     swiglu = "swiglu"
102 |
103 |
104 | class BlockType(StrEnum):
105 |     sequential = "sequential"
106 |     parallel = "parallel"
107 |
108 |     llama = "llama"
109 |     """
110 |     A block similar to the sequential block with slightly different
111 |     implementations of operations like attention to imitate the behavior of Llama.
112 |     """
113 |
114 |
115 | class InitFnType(StrEnum):
116 |     mitchell = "mitchell"
117 |     """
118 |     The strategy suggested to us by Mitchell Wortsman from UW.
119 |     This uses a truncated normal distribution with an adaptive standard deviation that depends
120 |     on the size of the weights as well as the depth of the layer.
121 |     """
122 |
123 |     normal = "normal"
124 |     """
125 |     All weights are initialized from the same normal distribution.
126 |     """
127 |
128 |     kaiming_normal = "kaiming_normal"
129 |     """
130 |     All weights are initialized with the Kaiming method from a normal distribution.
131 |     Note this currently won't work with FSDP.
132 |     """
133 |
134 |     fan_in = "fan_in"
135 |     """
136 |     "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
137 |     is the input dimensionality of the kernel.
138 |     """
139 |
140 |     full_megatron = "full_megatron"
141 |     """
142 |     This is what metaseq calls "full megatron init". It is the init used for Llama 2.
143 |     """
144 |
145 |
146 | @dataclass
147 | class ModelConfig():
148 |     """
149 |     LLaDA (model) configuration.
150 |     """
151 |
152 |     # Note that the defaults for these attributes are equivalent to the base GPT2 model.
153 |
154 |     d_model: int = 768
155 |     """
156 |     The hidden size of the model.
157 |     """
158 |
159 |     n_heads: int = 12
160 |     """
161 |     The number of self-attention heads.
162 |     """
163 |
164 |     n_kv_heads: Optional[int] = None
165 |     """
166 |     The number of heads to use for keys and values. Defaults to `n_heads`.
167 |     Set this to ``None`` or ``n_heads`` for normal multi-head attention.
168 |     Set this to 1 for multi-query attention.
169 |     Set it to some in-between value for Llama2-style grouped query attention.
170 |     """
171 |
172 |     n_layers: int = 12
173 |     """
174 |     The number of layers/blocks.
175 |     """
176 |
177 |     mlp_ratio: int = 4
178 |     """
179 |     The ratio of the inner MLP dimensionality to ``d_model``.
180 |     This is only used when ``mlp_hidden_size`` is not set.
181 |     """
182 |
183 |     mlp_hidden_size: Optional[int] = None
184 |     """
185 |     Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
186 |     """
187 |
188 |     activation_type: ActivationType = ActivationType.swiglu
189 |     """
190 |     The activation function to use within the MLP layers.
191 |     """
192 |
193 |     block_type: BlockType = BlockType.sequential
194 |     """
195 |     The transformer block implementation.
196 |     """
197 |
198 |     block_group_size: int = 1
199 |     """
200 |     The number of blocks to group together into a single parent block.
201 |     This has no effect on the number of parameters in the model and is only used to wrap groups
202 |     of blocks together with a single FSDP wrapper during training.
203 |     """
204 |
205 |     alibi: bool = False
206 |     """
207 |     If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
208 |     """
209 |
210 |     alibi_bias_max: float = 8.0
211 |     """
212 |     Maximum absolute value of ALiBi bias.
213 |     """
214 |
215 |     rope: bool = False
216 |     """
217 |     Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
218 |     """
219 |
220 |     rope_full_precision: bool = True
221 |     """
222 |     If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
223 |     apply RoPE at the precision of the input.
224 |     """
225 |
226 |     flash_attention: bool = False
227 |     """
228 |     If ``True``, use ``FlashAttention``.
229 |     """
230 |
231 |     attention_dropout: float = 0.1
232 |     """
233 |     The dropout probability within the attention modules.
234 |     """
235 |
236 |     multi_query_attention: Optional[bool] = None
237 |     """
238 |     Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
239 |     and is more efficient during inference.
240 |     """
241 |
242 |     attention_layer_norm: bool = False
243 |     """
244 |     Apply layer norm to the keys and queries within the attention mechanism.
245 |     This can help stabilize training.
246 |     """
247 |
248 |     residual_dropout: float = 0.1
249 |     """
250 |     The dropout probability for the MLP and attention output within each block.
251 |     """
252 |
253 |     embedding_dropout: float = 0.1
254 |     """
255 |     The dropout probability for embeddings.
256 |     """
257 |
258 |     input_emb_norm: bool = False
259 |     """
260 |     An input hidden_states norm implementation by Gemma.
261 |     """
262 |
263 |     layer_norm_type: LayerNormType = LayerNormType.default
264 |     """
265 |     The layernorm implementation to use.
266 |     """
267 |
268 |     layer_norm_with_affine: bool = True
269 |     """
270 |     Whether to include bias and weight parameters for the layer norms.
271 |     This only affects layer norms that are immediately followed by a linear layer in the forward pass,
272 |     so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
273 |     to ``False``.
274 |     """
275 |
276 |     rms_norm_eps: float = 1e-05
277 |     """
278 |     The rms layernorm eps param.
279 |     """
280 |
281 |     attention_layer_norm_with_affine: bool = True
282 |     """
283 |     Toggle affine transform for the QK norms.
284 |     """
285 |
286 |     max_sequence_length: int = 1024
287 |     """
288 |     The maximum input sequence length supported by the model.
289 |     """
290 |
291 |     train_max_sequence_length: int = 1024
292 |     """
293 |     The maximum input sequence length supported by the model during training.
294 |     """
295 |
296 |     rope_theta: float = 10000.0
297 |     """
298 |     The rope base param.
299 |     """
300 |
301 |     include_qkv_bias: Optional[bool] = False
302 |     """
303 |     Whether or not to include bias parameters in qkv linear layers.
304 |     """
305 |
306 |     include_bias: bool = False
307 |     """
308 |     Whether or not to include bias parameters in linear layers.
309 |     In PaLM, they got rid of all bias terms because they found that large
310 |     models tend to have near 0 bias terms anyway.
311 |     """
312 |
313 |     bias_for_layer_norm: Optional[bool] = None
314 |     """
315 |     Whether or not to include bias parameters in layer norm.
316 |     This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
317 |     layer norm.
318 |     When this is None (the default), it inherits the setting from include_bias.
319 |     """
320 |
321 |     scale_logits: bool = False
322 |     """
323 |     If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
324 |     """
325 |
326 |     vocab_size: int = 50257
327 |     """
328 |     Vocabulary size of the model.
329 |     """
330 |
331 |     embedding_size: Optional[int] = 50304
332 |     """
333 |     The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
334 |     to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
335 |     next multiple of 128 that's greater than ``vocab_size`` can improve throughput
336 |     substantially.
337 |     """
338 |
339 |     weight_tying: bool = True
340 |     """
341 |     Whether to tie output linear weights to the input embedding.
342 |     """
343 |
344 |     eos_token_id: int = 50256
345 |     """
346 |     The ID of the end-of-sentence special token.
347 |     """
348 |
349 |     pad_token_id: int = 50256
350 |     """
351 |     The ID of the token to use for padding. Defaults to the ID of the EOS token.
352 |     """
353 |
354 |     mask_token_id: Optional[int] = 50256
355 |     """
356 |     The ID of the token to use as the mask token. Defaults to the ID of the EOS token.
357 |     """
358 |
359 |     init_device: Optional[str] = None
360 |     """
361 |     The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
362 |     """
363 |
364 |     init_fn: InitFnType = InitFnType.normal
365 |     """
366 |     The weight initialization strategy.
367 |     """
368 |
369 |     init_std: float = 0.02
370 |     """
371 |     The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
372 |     as "normal".
373 |     """
374 |
375 |     init_cutoff_factor: Optional[float] = None
376 |     """
377 |     A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
378 |     as "normal". Setting this to None means values are not cut off.
379 |     """
380 |
381 |     precision: Optional[str] = None
382 |     """
383 |     Precision used to train/evaluate with. You shouldn't set this directly.
384 |     See :data:`TrainConfig.precision` instead.
385 |     """
386 |
387 |     @property
388 |     def effective_n_kv_heads(self) -> int:
389 |         if self.n_kv_heads is None:
390 |             if self.multi_query_attention is True:
391 |                 return 1
392 |             else:
393 |                 return self.n_heads
394 |         else:
395 |             if self.multi_query_attention is None:
396 |                 return self.n_kv_heads
397 |             if self.multi_query_attention:
398 |                 n_kv_heads_should_be = 1
399 |             else:
400 |                 n_kv_heads_should_be = self.n_heads
401 |             if self.n_kv_heads == n_kv_heads_should_be:
402 |                 return n_kv_heads_should_be
403 |             else:
404 |                 raise Exception(
405 |                     "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
406 |                 )
407 |
408 | class ActivationCheckpointingStrategy(StrEnum):
409 |     whole_layer = "whole_layer"
410 |     """
411 |     Checkpoint every transformer layer.
412 |     """
413 |
414 |     one_in_two = "one_in_two"
415 |     """
416 |     Checkpoint one in two transformer layers.
417 |     """
418 |
419 |     one_in_three = "one_in_three"
420 |     """
421 |     Checkpoint one in three transformer layers.
422 |     """
423 |
424 |     one_in_four = "one_in_four"
425 |     """
426 |     Checkpoint one in four transformer layers.
427 |     """
428 |
429 |     two_in_three = "two_in_three"
430 |     """
431 |     Checkpoint two out of every three transformer layers.
432 |     """
433 |
434 |     three_in_four = "three_in_four"
435 |     """
436 |     Checkpoint three out of every four transformer layers.
437 |     """
438 |
439 |     four_in_five = "four_in_five"
440 |     """
441 |     Checkpoint four out of every five transformer layers.
442 |     """
443 |
444 |     nine_in_ten = "nine_in_ten"
445 |     """
446 |     Checkpoint nine out of every ten transformer layers.
447 |     """
448 |
449 |     fine_grained = "fine_grained"
450 |     """
451 |     Focus checkpointing where recomputation is cheap and saves the most memory.
452 |     """
453 |
454 |
455 | class LLaDAConfig(PretrainedConfig):
456 |     model_type = "llada"
457 |     keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm
458 |
459 |     def __init__(self, use_cache: bool = False, **kwargs):
460 |         model_config = ModelConfig()
461 |         all_kwargs = model_config.__dict__
462 |         all_kwargs.update(kwargs)
463 |         all_kwargs.update({"use_cache": use_cache})
464 |         all_kwargs.update(
465 |             {
466 |                 "architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
467 |             }
468 |         )
469 |         super().__init__(**all_kwargs)
470 |
471 |     @property
472 |     def num_attention_heads(self):
473 |         return self.n_heads
474 |
475 |     @property
476 |     def num_hidden_layers(self):
477 |         return self.n_layers
478 |
479 |     @property
480 |     def hidden_size(self):
481 |         return self.d_model
482 |
483 |
484 | # Register the config class so that it is available for transformer pipelines, auto-loading etc.
485 | AutoConfig.register("llada", LLaDAConfig)
486 |
--------------------------------------------------------------------------------
/llada/postprocess_code.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 17 | 18 | import evaluate as hf_evaluate 19 | import os 20 | import sys 21 | from sanitize import sanitize 22 | 23 | os.environ["HF_ALLOW_CODE_EVAL"] = "1" 24 | pass_at_k = hf_evaluate.load("code_eval") 25 | 26 | def pass_at_1(references, predictions): 27 | return pass_at_k.compute( 28 | references=references, 29 | predictions=predictions, 30 | k=[1], 31 | )[0]["pass@1"] 32 | 33 | import json 34 | 35 | 36 | def read_jsonl(file_path): 37 | data = [] 38 | with open(file_path, 'r') as file: 39 | for line in file: 40 | data.append(json.loads(line)) 41 | return data 42 | 43 | file_path = sys.argv[1] 44 | data = read_jsonl(file_path) 45 | 46 | references = [sample['target'] for sample in data] 47 | 48 | predictions = [[sanitize(sample['doc']['prompt'] + "\n" + sample['resps'][0][0].split('```python\n', 1)[-1].split('```')[0], 49 | sample['doc']["entry_point"])] 50 | for sample in data] 51 | 52 | pass_at_1s = [pass_at_1([reference], [prediction]) for reference, prediction in zip(references, predictions)] 53 | print(sum(pass_at_1s)/len(pass_at_1s)) 54 | 55 | def write_jsonl(data, file_path): 56 | with open(file_path, 'w') as file: 57 | for item in data: 58 | file.write(json.dumps(item) + '\n') 59 | 60 | res = [{"task_id": sample['doc']['task_id'], "completion": pred, "pass_at_1": res} 61 | for sample, pred, res in zip(data, predictions, pass_at_1s)] 62 | write_jsonl(res, file_path+'.cleaned') -------------------------------------------------------------------------------- /llada/sanitize.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | # Modified from Dream repos: https://github.com/HKUNLP/Dream 17 | 18 | """Post-processing LLM-generated Python code implemented using tree-sitter.""" 19 | 20 | import os 21 | import sys 22 | import pathlib 23 | 24 | ROOT = os.path.dirname(os.path.abspath(__file__)) 25 | sys.path.extend([os.path.dirname(ROOT), os.path.dirname(os.path.dirname(ROOT))]) 26 | 27 | import ast 28 | import traceback 29 | 30 | from typing import Dict, List, Optional, Set, Tuple 31 | 32 | def refine_text(text: str) -> str: 33 | text = text.replace("\t", " ") 34 | text = text.replace("\r\n", "\n").replace("\r", "\n") 35 | return text.strip() + "\n" 36 | 37 | def syntax_check(code, verbose = False): 38 | try: 39 | ast.parse(code) 40 | return True 41 | except (SyntaxError, MemoryError): 42 | if verbose: 43 | traceback.print_exc() 44 | return False 45 | 46 | def extract_longest_valid_code(text: str) -> str: 47 | lines = text.splitlines() 48 | 49 | if len(lines) > 100: 50 | lines = lines[:100] 51 | max_valid_lines = 0 52 | max_valid_snippet = "" 53 | 54 | for i in range(len(lines)): 55 | for j in range(i, len(lines)): 56 | current_snippet = "\n".join(lines[i:j+1]) 57 | if syntax_check(current_snippet): 58 | valid_line_count = sum(1 for line in lines[i:j+1] if line.strip()) 59 | if valid_line_count > max_valid_lines: 60 | max_valid_lines = valid_line_count 61 | max_valid_snippet = current_snippet 62 | 63 | return max_valid_snippet 64 | 65 | def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]: 66 | name2deps = {} 67 | for name, node in nodes: 68 | deps = set() 69 | stack = [node] 70 | while stack: 71 | current = stack.pop() 72 | for child in ast.iter_child_nodes(current): 73 | if isinstance(child, ast.Name): 74 | deps.add(child.id) 75 | elif isinstance(child, ast.Attribute): 76 | deps.add(child.attr) 77 | else: 78 | stack.append(child) 79 | name2deps[name] = deps 80 | return name2deps 81 | 82 | def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]: 83 | visited = set() 84 | to_visit = [entrypoint] 85 | 86 | while to_visit: 87 | current = to_visit.pop(0) 88 | if current not in visited: 89 | visited.add(current) 90 | to_visit.extend(call_graph.get(current, set()) - visited) 91 | 92 | return visited 93 | 94 | def get_definition_name(node: ast.AST) -> Optional[str]: 95 | if isinstance(node, (ast.FunctionDef, ast.ClassDef)): 96 | return node.name 97 | elif isinstance(node, ast.Assign): 98 | targets = node.targets 99 | if targets and isinstance(targets[0], ast.Name): 100 | return targets[0].id 101 | return None 102 | 103 | def has_return_statement(node: ast.AST) -> bool: 104 | return any(isinstance(n, ast.Return) for n in ast.walk(node)) 105 | 106 | def sanitize(text: str, entrypoint: Optional[str] = None) -> str: 107 | 108 | text = refine_text(text) 109 | 110 | # text = python_extract(text) 111 | 112 | code = extract_longest_valid_code(text) 113 | tree = ast.parse(code) 114 | 115 | definitions = {} 116 | 117 | imports = [] 118 | 119 | for node in tree.body: 120 | if isinstance(node, (ast.Import, ast.ImportFrom)): 121 | imports.append(node) 122 | elif isinstance(node, ast.ClassDef): 123 | name = node.name 124 | definitions[name] = ('class', node) 125 | elif isinstance(node, ast.FunctionDef): 126 | name = node.name 127 | if has_return_statement(node): 128 | definitions[name] = ('function', node) 129 | elif isinstance(node, ast.Assign): 130 | name = get_definition_name(node) 131 | if name: 132 | definitions[name] = 
('variable', node) 133 | 134 | if entrypoint: 135 | name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()]) 136 | reachable = get_function_dependency(entrypoint, name2deps) 137 | 138 | sanitized_output = [] 139 | 140 | for node in imports: 141 | sanitized_output.append(ast.unparse(node)) 142 | 143 | for name, (_, node) in definitions.items(): 144 | if not entrypoint or name in reachable: 145 | sanitized_output.append(ast.unparse(node)) 146 | 147 | return "\n".join(sanitized_output) -------------------------------------------------------------------------------- /paper/fast_dllm.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Fast-dLLM/80f83ed424a6b6816789906a08d3b084c095f4e0/paper/fast_dllm.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.49.0 2 | lm_eval==0.4.8 3 | accelerate==0.34.2 4 | antlr4-python3-runtime==4.11 5 | math_verify 6 | sympy 7 | hf_xet 8 | --------------------------------------------------------------------------------
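
Note on the decoding rule in `llada/generate.py` above: the threshold branch of `get_transfer_index` is the heart of the parallel decoding step. For each still-masked position it takes the greedy token and its probability, always commits the single most confident position so the block loop makes progress, and additionally commits every position whose confidence reaches the threshold, all within one forward pass; `generate_with_dual_cache` applies this rule block by block on top of the cached prefix, which is what reduces the number of model calls counted in `nfe`. The sketch below is a minimal, self-contained illustration of that selection rule for a single sequence at temperature 0; the function name `select_confident_positions`, the toy vocabulary, and the random logits are illustrative assumptions, not part of the repository.

import torch
import torch.nn.functional as F

def select_confident_positions(logits, mask_index, threshold=0.9):
    # Illustrative sketch, not the repo API: decide which masked positions to
    # unmask in one parallel step, mirroring the threshold branch of
    # get_transfer_index (greedy token choice, i.e. temperature 0).
    probs = F.softmax(logits.to(torch.float64), dim=-1)      # (L, V)
    x0 = probs.argmax(dim=-1)                                 # greedy token per position
    conf = probs.gather(-1, x0.unsqueeze(-1)).squeeze(-1)     # its probability
    conf = conf.masked_fill(~mask_index, float("-inf"))       # ignore already-decoded positions

    # Always keep the most confident masked position, then any others whose
    # confidence clears the threshold (assumes at least one position is masked).
    order = conf.argsort(descending=True)
    keep = torch.zeros_like(mask_index)
    keep[order[0]] = True
    keep |= mask_index & (conf >= threshold)
    return x0, keep

if __name__ == "__main__":
    torch.manual_seed(0)
    seq_len, vocab = 8, 16                      # toy sizes for the demo
    logits = torch.randn(seq_len, vocab) * 3.0  # stand-in for one model forward pass
    mask_index = torch.tensor([False, False, True, True, True, True, True, True])
    x0, keep = select_confident_positions(logits, mask_index, threshold=0.9)
    print("greedy tokens:", x0.tolist())
    print("unmask now   :", keep.tolist())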