├── .gitignore ├── LICENSE ├── README.md ├── accelerate_config.yaml ├── asset ├── paper.pdf ├── pipeline.png ├── radar.png └── test.jpg ├── demo_Dream.py ├── demo_LLaDA.py ├── demo_LLaDA_V.py ├── dllm_cache ├── __init__.py ├── cache │ ├── Cache.py │ ├── Config.py │ └── __init__.py └── hooks │ ├── __init__.py │ ├── cache_hook_Dream.py │ ├── cache_hook_LLaDA.py │ └── cache_hook_LLaDA_V.py ├── eval_model ├── Dream.py ├── LLaDA.py └── __init__.py ├── evaluation_script.py ├── install.sh ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── configuration_llada.py │ │ ├── llava_llada.py │ │ └── modeling_llada.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── hf_vision.py │ │ ├── imagebind.py │ │ ├── open_clip_encoder.py │ │ └── siglip_encoder.py │ ├── multimodal_projector │ │ ├── builder.py │ │ └── pooler_projector.py │ ├── multimodal_resampler │ │ ├── builder.py │ │ ├── masked_drop.py │ │ ├── perceiver.py │ │ ├── qformer.py │ │ └── spatial_pool.py │ └── utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── gradio_multi_image.py │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── sglang_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── llava_trainer_eval.py │ ├── train.py │ ├── train_dpo.py │ └── train_mem.py └── utils.py ├── metrics ├── get_mmlu_acc.py └── humaneval_pass@1.py ├── requirements.txt ├── scripts ├── run_Dream_bbh_Instruct.sh ├── run_Dream_bbh_base.sh ├── run_Dream_gpqa_Instruct.sh ├── run_Dream_gpqa_base.sh ├── run_Dream_gsm8k_Instruct.sh ├── run_Dream_gsm8k_base.sh ├── run_Dream_humaneval_Instruct.sh ├── run_Dream_humaneval_base.sh ├── run_Dream_mbpp_Instruct.sh ├── run_Dream_mbpp_base.sh ├── run_Dream_minerva_math_Instruct.sh ├── run_Dream_minerva_math_base.sh ├── run_Dream_mmlu_generative_Instruct.sh ├── run_Dream_mmlu_generative_base.sh ├── run_Dream_mmlu_pro_Instruct.sh ├── run_Dream_mmlu_pro_base.sh ├── run_LLaDA_bbh_Instruct.sh ├── run_LLaDA_bbh_base.sh ├── run_LLaDA_gpqa_Instruct.sh ├── run_LLaDA_gpqa_base.sh ├── run_LLaDA_gsm8k_Instruct.sh ├── run_LLaDA_gsm8k_base.sh ├── run_LLaDA_humaneval_Instruct.sh ├── run_LLaDA_humaneval_base.sh ├── run_LLaDA_long_bench_Instruct.sh ├── run_LLaDA_mbpp_Instruct.sh ├── run_LLaDA_mbpp_base.sh ├── run_LLaDA_minerva_math_Instruct.sh ├── run_LLaDA_minerva_math_base.sh ├── run_LLaDA_mmlu_generative_Instruct.sh ├── run_LLaDA_mmlu_generative_base.sh ├── run_LLaDA_mmlu_pro_Instruct.sh └── run_LLaDA_mmlu_pro_base.sh └── utils ├── __init__.py ├── generate_function.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | !.vscode/settings.json 3 | !.vscode/tasks.json 4 | !.vscode/launch.json 5 | !.vscode/extensions.json 6 | !.vscode/*.code-snippets 7 | 8 | # Local History for Visual Studio Code 9 | .history/ 10 | 11 | # Built Visual Studio Code Extensions 12 | *.vsix 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | 
*.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | cover/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | .pybuilder/ 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | # For a library or package, you might want to ignore these files since the code is 100 | # intended to run in multiple environments; otherwise, check them in: 101 | # .python-version 102 | 103 | # pipenv 104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 107 | # install all needed dependencies. 108 | #Pipfile.lock 109 | 110 | # UV 111 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | #uv.lock 115 | 116 | # poetry 117 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 118 | # This is especially recommended for binary packages to ensure reproducibility, and is more 119 | # commonly ignored for libraries. 120 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 121 | #poetry.lock 122 | 123 | # pdm 124 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 125 | #pdm.lock 126 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 127 | # in version control. 128 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 129 | .pdm.toml 130 | .pdm-python 131 | .pdm-build/ 132 | 133 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 134 | __pypackages__/ 135 | 136 | # Celery stuff 137 | celerybeat-schedule 138 | celerybeat.pid 139 | 140 | # SageMath parsed files 141 | *.sage.py 142 | 143 | # Environments 144 | .env 145 | .venv 146 | env/ 147 | venv/ 148 | ENV/ 149 | env.bak/ 150 | venv.bak/ 151 | 152 | # Spyder project settings 153 | .spyderproject 154 | .spyproject 155 | 156 | # Rope project settings 157 | .ropeproject 158 | 159 | # mkdocs documentation 160 | /site 161 | 162 | # mypy 163 | .mypy_cache/ 164 | .dmypy.json 165 | dmypy.json 166 | 167 | # Pyre type checker 168 | .pyre/ 169 | 170 | # pytype static type analyzer 171 | .pytype/ 172 | 173 | # Cython debug symbols 174 | cython_debug/ 175 | 176 | # PyCharm 177 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 178 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 179 | # and can be added to the global gitignore or merged into this file. For a more nuclear 180 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 181 | #.idea/ 182 | 183 | # Ruff stuff: 184 | .ruff_cache/ 185 | 186 | # PyPI configuration file 187 | .pypirc 188 | 189 | models/* 190 | LLaDA/models/ 191 | attention_* 192 | token_similarity 193 | token_similarity_adjacent 194 | cache_save/ 195 | token_similarity_adjacent 196 | token 197 | results 198 | grid_images -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dLLM-Cache: Accelerating Diffusion Large Language Models with Adaptive Caching 2 | 3 | Official PyTorch implementation of the paper **["dLLM-Cache: Accelerating Diffusion Large Language Models with Adaptive Caching"](https://www.researchgate.net/publication/392169456_dLLM-Cache_Accelerating_Diffusion_Large_Language_Models_with_Adaptive_Caching)** (dLLM-Cache). 4 | 5 | ## :fire: News 6 | - [2025/05/31] Our dLLM-Cache is integrated into [LLaDA-V](https://github.com/ML-GSAI/LLaDA-V). 7 | - [2025/05/23] The code of our paper has been released. 8 | - [2025/05/22] Our paper has been released. 9 | 10 | ## ✨️ Key Highlights 11 | 12 | ![radar_speed](./asset/radar.png) 13 | - **Currently supported models**: [LLaDA](https://github.com/ML-GSAI/LLaDA), [Dream](https://github.com/HKUNLP/Dream), [LLaDA-V](https://github.com/ML-GSAI/LLaDA-V). 14 | - **Speedup**: Achieves up to **9.1x** speedup over standard dLLM pipelines, with **no performance loss** on most tasks. 15 | - **Evaluation**: Evaluated on [LLaDA 8B](https://arxiv.org/abs/2502.09992) and [Dream 7B](https://hkunlp.github.io/blog/2025/dream/). 16 | - **Latency**: Approaches ARM-level inference speeds in many scenarios. 17 | 18 | ## :rocket: Pipeline 19 | 20 | Here's an overview of the process behind our **dLLM-Cache** method: 21 | ![pipeline](./asset/pipeline.png) 22 | 23 | ## 🛠️ Installation 24 | 25 | To get started with dLLM-Cache, follow the installation instructions below. 26 | 27 | 1. Clone the Repository: 28 | ```sh 29 | git clone https://github.com/maomaocun/dLLM-Cache.git 30 | cd dLLM-Cache 31 | ``` 32 | 33 | 2. Set Up the Environment: 34 | Create a Python environment with `conda` or `virtualenv` and install dependencies: 35 | ```bash 36 | bash install.sh 37 | ``` 38 | 39 | 3. Demo: 40 | 41 | ```bash 42 | python demo_{model_name}.py 43 | ``` 44 | 45 | 4. 
Running Experiments: 46 | Run experiments using the provided scripts: 47 | 48 | ```bash 49 | bash scripts/run_{model_name}_{task_name}_base.sh 50 | ``` 51 | ### :blue_book: Example Usage 52 | 1. GSM8K with LLaDA 53 | ```bash 54 | bash scripts/run_LLaDA_gsm8k_base.sh 55 | ``` 56 | 57 | 2. BBH with Dream 58 | ```bash 59 | bash scripts/run_Dream_bbh_base.sh 60 | ``` 61 | 62 | 63 | ## :postbox: Contact 64 | If you have any questions, please email [yangyicun187@gmail.com](mailto:yangyicun187@gmail.com). 65 | 66 | 67 | ## 🎉 Acknowledgements 68 | This repository was built off of [LLaDA](https://github.com/ML-GSAI/LLaDA), [Dream](https://github.com/HKUNLP/Dream), [LLaDA-V](https://github.com/ML-GSAI/LLaDA-V) and [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). 69 | 70 | ## :pushpin: Citation 71 | If you find dLLM-Cache useful for your research and applications, please cite using this BibTeX: 72 | 73 | ```bibtex 74 | @misc{liu2025dllm, 75 | title={dLLM-Cache: Accelerating Diffusion Large Language Models with Adaptive Caching}, 76 | author={Zhiyuan Liu and Yicun Yang and Yaojie Zhang and Junjie Chen and Chang Zou and Qingyan Wei and Shaobo Wang and Linfeng Zhang}, 77 | year={2025}, 78 | url={https://github.com/maomaocun/dLLM-cache}, 79 | } 80 | ``` 81 | 82 | ## :star2: Star History 83 | 84 | [![Star History Chart](https://api.star-history.com/svg?repos=maomaocun/dLLM-cache&type=Timeline)](https://www.star-history.com/#maomaocun/dLLM-cache&Timeline) 85 | 86 | -------------------------------------------------------------------------------- /accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | gpu_ids: '0,1,2,3,4,5,6,7' 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: 'no' 10 | num_machines: 1 11 | num_processes: 8 12 | use_cpu: false 13 | main_process_port: 20658 14 | -------------------------------------------------------------------------------- /asset/paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maomaocun/dLLM-cache/9a905b0d219c8ab587838cafeb3c813265520b5b/asset/paper.pdf -------------------------------------------------------------------------------- /asset/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maomaocun/dLLM-cache/9a905b0d219c8ab587838cafeb3c813265520b5b/asset/pipeline.png -------------------------------------------------------------------------------- /asset/radar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maomaocun/dLLM-cache/9a905b0d219c8ab587838cafeb3c813265520b5b/asset/radar.png -------------------------------------------------------------------------------- /asset/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maomaocun/dLLM-cache/9a905b0d219c8ab587838cafeb3c813265520b5b/asset/test.jpg -------------------------------------------------------------------------------- /demo_Dream.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime 3 | from dllm_cache.cache import dLLMCache, dLLMCacheConfig 4 | from dllm_cache.hooks import register_cache_Dream, 
logout_cache_Dream 5 | from dataclasses import asdict 6 | from transformers import AutoModel, AutoTokenizer 7 | import torch 8 | 9 | # Configuration parameters 10 | prompt_interval_steps = 100 11 | gen_interval_steps = 7 12 | transfer_ratio = 0.25 13 | use_cache = True 14 | device = "cuda" if torch.cuda.is_available() else "cpu" 15 | max_new_tokens = 256 16 | steps = 256 17 | max_tokens = 2048 18 | 19 | # Load model and tokenizer 20 | model = ( 21 | AutoModel.from_pretrained( 22 | "Dream-org/Dream-v0-Instruct-7B", 23 | trust_remote_code=True, 24 | torch_dtype=torch.bfloat16, 25 | ) 26 | .to(device) 27 | .eval() 28 | ) 29 | tokenizer = AutoTokenizer.from_pretrained( 30 | "Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True 31 | ) 32 | 33 | # Initialize cache 34 | if use_cache: 35 | dLLMCache.new_instance( 36 | **asdict( 37 | dLLMCacheConfig( 38 | prompt_interval_steps=prompt_interval_steps, 39 | gen_interval_steps=gen_interval_steps, 40 | transfer_ratio=transfer_ratio, 41 | ) 42 | ) 43 | ) 44 | register_cache_Dream(model, "model.layers") 45 | 46 | # Store conversation history 47 | conversation_history = [] 48 | 49 | def format_time(): 50 | """Return current time in formatted string""" 51 | return datetime.now().strftime("%Y-%m-%d %H:%M:%S") 52 | 53 | def truncate_conversation(history, max_tokens): 54 | """Truncate conversation history to ensure total tokens do not exceed max_tokens""" 55 | total_tokens = 0 56 | truncated_history = [] 57 | for msg in reversed(history): 58 | tokens = len(tokenizer(msg["content"])["input_ids"]) 59 | if total_tokens + tokens <= max_tokens: 60 | truncated_history.insert(0, msg) 61 | total_tokens += tokens 62 | else: 63 | break 64 | return truncated_history 65 | 66 | def print_help(): 67 | """Print available commands""" 68 | print("\nAvailable commands:") 69 | print(" <help> : Show this help message") 70 | print(" <use_cache> : Enable cache") 71 | print(" <no_cache> : Disable cache") 72 | print(" <clear> : Clear conversation history") 73 | print(" <exit> : Exit the program") 74 | print() 75 | 76 | print("*" * 66) 77 | print( 78 | f"** Answer Length: {max_new_tokens} | Sampling Steps: {steps} | Cache Enabled: {use_cache}" 79 | ) 80 | print("*" * 66) 81 | print("Type '<help>' for available commands.") 82 | 83 | while True: 84 | print("\n" + "=" * 70) 85 | user_input = input(f"Enter your question (Cache is {'enabled' if use_cache else 'disabled'}, Type '<help>' for available commands): ") 86 | 87 | if user_input.lower() == '<exit>': 88 | print("Conversation ended.") 89 | break 90 | 91 | if user_input == "<help>": 92 | print_help() 93 | continue 94 | 95 | if user_input == "<no_cache>": 96 | logout_cache_Dream(model, "model.layers") 97 | use_cache = False 98 | print("Cache disabled. Please continue with your question.") 99 | continue 100 | 101 | if user_input == "<use_cache>": 102 | dLLMCache.new_instance( 103 | **asdict( 104 | dLLMCacheConfig( 105 | prompt_interval_steps=prompt_interval_steps, 106 | gen_interval_steps=gen_interval_steps, 107 | transfer_ratio=transfer_ratio, 108 | ) 109 | ) 110 | ) 111 | register_cache_Dream(model, "model.layers") 112 | use_cache = True 113 | print("Cache enabled. Please continue with your question.") 114 | continue 115 | 116 | if user_input == "<clear>": 117 | conversation_history = [] 118 | print("Conversation history cleared. 
Please continue with your question.") 119 | continue 120 | 121 | # Record user input time 122 | input_time = format_time() 123 | conversation_history.append({"role": "user", "content": user_input, "time": input_time}) 124 | 125 | # Truncate conversation history to ensure it does not exceed max token limit 126 | conversation_history = truncate_conversation(conversation_history, max_tokens) 127 | 128 | # Apply chat template 129 | formatted_input = tokenizer.apply_chat_template( 130 | conversation_history, add_generation_prompt=True, tokenize=False 131 | ) 132 | 133 | # Encode input 134 | input_ids = tokenizer(formatted_input)["input_ids"] 135 | attention_mask = tokenizer(formatted_input)["attention_mask"] 136 | input_ids = torch.tensor(input_ids).to(device).unsqueeze(0) 137 | attention_mask = torch.tensor(attention_mask).to(device).unsqueeze(0) 138 | 139 | # Reset cache 140 | feature_cache = dLLMCache() 141 | feature_cache.reset_cache(input_ids.shape[1]) 142 | 143 | # Generate response 144 | start_time = time.time() 145 | generation_ids = model.diffusion_generate( 146 | input_ids, 147 | attention_mask=attention_mask, 148 | max_new_tokens=max_new_tokens, 149 | output_history=False, 150 | return_dict_in_generate=True, 151 | steps=steps, 152 | temperature=0.2, 153 | top_p=0.95, 154 | ).sequences[:, input_ids.shape[1]:] 155 | end_time = time.time() 156 | 157 | # Decode response 158 | answer = tokenizer.batch_decode(generation_ids, skip_special_tokens=True)[0] 159 | reply_time = format_time() 160 | 161 | # Store assistant response 162 | conversation_history.append({"role": "assistant", "content": answer, "time": reply_time}) 163 | 164 | # Print conversation 165 | print(f"Dream ({reply_time}): {answer}") 166 | print(f"Generation Time: {end_time - start_time:.2f} seconds") -------------------------------------------------------------------------------- /demo_LLaDA.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime 3 | from dllm_cache.cache import dLLMCache, dLLMCacheConfig 4 | from dllm_cache.hooks import register_cache_LLaDA, logout_cache_LLaDA 5 | from dataclasses import asdict 6 | from transformers import AutoModel, AutoTokenizer 7 | import torch 8 | from utils import generate 9 | 10 | # Configuration parameters 11 | prompt_interval_steps = 100 12 | gen_interval_steps = 7 13 | transfer_ratio = 0.25 14 | use_cache = True 15 | device = "cuda" if torch.cuda.is_available() else "cpu" 16 | gen_length = 256 17 | steps = 256 18 | max_tokens = 2048 19 | 20 | # Load model and tokenizer 21 | model = ( 22 | AutoModel.from_pretrained( 23 | "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16 24 | ) 25 | .to(device) 26 | .eval() 27 | ) 28 | tokenizer = AutoTokenizer.from_pretrained( 29 | "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True 30 | ) 31 | 32 | # Initialize cache 33 | if use_cache: 34 | dLLMCache.new_instance( 35 | **asdict( 36 | dLLMCacheConfig( 37 | prompt_interval_steps=prompt_interval_steps, 38 | gen_interval_steps=gen_interval_steps, 39 | transfer_ratio=transfer_ratio, 40 | ) 41 | ) 42 | ) 43 | register_cache_LLaDA(model, "model.transformer.blocks") 44 | 45 | # Store conversation history 46 | conversation_history = [] 47 | 48 | def format_time(): 49 | """Return current time in formatted string""" 50 | return datetime.now().strftime("%Y-%m-%d %H:%M:%S") 51 | 52 | def truncate_conversation(history, max_tokens): 53 | """Truncate conversation history to ensure total tokens do not exceed 
max_tokens""" 54 | total_tokens = 0 55 | truncated_history = [] 56 | for msg in reversed(history): 57 | tokens = len(tokenizer(msg["content"])["input_ids"]) 58 | if total_tokens + tokens <= max_tokens: 59 | truncated_history.insert(0, msg) 60 | total_tokens += tokens 61 | else: 62 | break 63 | return truncated_history 64 | 65 | def print_help(): 66 | """Print available commands""" 67 | print("\nAvailable commands:") 68 | print(" <help> : Show this help message") 69 | print(" <use_cache> : Enable cache") 70 | print(" <no_cache> : Disable cache") 71 | print(" <clear> : Clear conversation history") 72 | print(" <exit> : Exit the program") 73 | print() 74 | 75 | print("*" * 66) 76 | print( 77 | f"** Answer Length: {gen_length} | Sampling Steps: {steps} | Cache Enabled: {use_cache}" 78 | ) 79 | print("*" * 66) 80 | print("Type '<help>' for available commands.") 81 | 82 | while True: 83 | print("\n" + "=" * 70) 84 | user_input = input(f"Enter your question (Cache is {'enabled' if use_cache else 'disabled'}, Type '<help>' for available commands): ") 85 | 86 | 87 | if user_input.lower() == '<exit>': 88 | print("Conversation ended.") 89 | break 90 | 91 | if user_input == "<help>": 92 | print_help() 93 | continue 94 | 95 | if user_input == "<no_cache>": 96 | logout_cache_LLaDA(model, "model.transformer.blocks") 97 | use_cache = False 98 | print("Cache disabled. Please continue with your question.") 99 | continue 100 | 101 | if user_input == "<use_cache>": 102 | dLLMCache.new_instance( 103 | **asdict( 104 | dLLMCacheConfig( 105 | prompt_interval_steps=prompt_interval_steps, 106 | gen_interval_steps=gen_interval_steps, 107 | transfer_ratio=transfer_ratio, 108 | ) 109 | ) 110 | ) 111 | register_cache_LLaDA(model, "model.transformer.blocks") 112 | use_cache = True 113 | print("Cache enabled. Please continue with your question.") 114 | continue 115 | 116 | if user_input == "<clear>": 117 | conversation_history = [] 118 | print("Conversation history cleared. 
Please continue with your question.") 119 | continue 120 | 121 | # Record user input time 122 | input_time = format_time() 123 | conversation_history.append({"role": "user", "content": user_input, "time": input_time}) 124 | 125 | # Truncate conversation history to ensure it does not exceed max token limit 126 | conversation_history = truncate_conversation(conversation_history, max_tokens) 127 | 128 | # Apply chat template 129 | formatted_input = tokenizer.apply_chat_template( 130 | conversation_history, add_generation_prompt=True, tokenize=False 131 | ) 132 | 133 | # Encode input 134 | input_ids = tokenizer(formatted_input)["input_ids"] 135 | attention_mask = tokenizer(formatted_input)["attention_mask"] 136 | input_ids = torch.tensor(input_ids).to(device).unsqueeze(0) 137 | attention_mask = torch.tensor(attention_mask).to(device).unsqueeze(0) 138 | 139 | # Reset cache 140 | feature_cache = dLLMCache() 141 | feature_cache.reset_cache(input_ids.shape[1]) 142 | 143 | # Generate response 144 | start_time = time.time() 145 | generation_ids = generate( 146 | input_ids=input_ids, 147 | attention_mask=attention_mask, 148 | model=model, 149 | steps=steps, 150 | gen_length=gen_length, 151 | block_length=steps, 152 | ) 153 | end_time = time.time() 154 | 155 | # Decode response 156 | answer = tokenizer.batch_decode(generation_ids, skip_special_tokens=True)[0] 157 | reply_time = format_time() 158 | 159 | # Store assistant response 160 | conversation_history.append({"role": "assistant", "content": answer, "time": reply_time}) 161 | 162 | # Print conversation 163 | print(f"LLaDA ({reply_time}): {answer}") 164 | print(f"Generation Time: {end_time - start_time:.2f} seconds") -------------------------------------------------------------------------------- /demo_LLaDA_V.py: -------------------------------------------------------------------------------- 1 | from transformers.generation import stopping_criteria 2 | from llava.model.builder import load_pretrained_model 3 | from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | 7 | from dllm_cache.cache import dLLMCache, dLLMCacheConfig 8 | from dllm_cache.hooks import register_cache_LLaDA_V 9 | from dataclasses import asdict 10 | 11 | from PIL import Image 12 | import requests 13 | import copy 14 | import torch 15 | import time 16 | 17 | import sys 18 | import warnings 19 | 20 | prompt_interval_steps = 25 21 | gen_interval_steps = 7 22 | transfer_ratio = 0.25 23 | use_cache = True # In this demo, we consider using dLLM-Cache(https://github.com/maomaocun/dLLM-cache) to speed up generation. Set to True to enable caching or False to test without it. 
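# How the three knobs above behave (as implemented in dllm_cache/cache/Cache.py): cached prompt
# features are fully refreshed every `prompt_interval_steps` denoising steps, and cached response
# features every `gen_interval_steps` steps. `transfer_ratio` controls what fraction of the cached
# response features is additionally re-computed between those full refreshes; the exact selection
# rule lives in the dllm_cache/hooks implementations rather than here.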
24 | 25 | warnings.filterwarnings("ignore") 26 | pretrained = "GSAI-ML/LLaDA-V" 27 | 28 | model_name = "llava_llada" 29 | device = "cpu" 30 | device_map = "cpu" 31 | tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, attn_implementation="sdpa", device_map=device_map) # Add any other thing you want to pass in llava_model_args 32 | 33 | model.eval() 34 | image = Image.open("./asset/test.jpg") 35 | image_tensor = process_images([image], image_processor, model.config) 36 | image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor] 37 | 38 | conv_template = "llava_llada" 39 | question = DEFAULT_IMAGE_TOKEN + "\nPlease describe the image in detail." 40 | conv = copy.deepcopy(conv_templates[conv_template]) 41 | conv.append_message(conv.roles[0], question) 42 | conv.append_message(conv.roles[1], None) 43 | prompt_question = conv.get_prompt() 44 | 45 | model.eval() 46 | if use_cache: 47 | dLLMCache.new_instance( 48 | **asdict( 49 | dLLMCacheConfig( 50 | prompt_interval_steps=prompt_interval_steps, 51 | gen_interval_steps=gen_interval_steps, 52 | transfer_ratio=transfer_ratio, 53 | ) 54 | ) 55 | ) 56 | register_cache_LLaDA_V(model, "model.layers") 57 | print("Testing with cache enabled") 58 | else: 59 | print("Testing without cache") 60 | 61 | input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) 62 | image_sizes = [image.size] 63 | 64 | start_time = time.time() 65 | cont = model.generate( 66 | input_ids, 67 | images=image_tensor, 68 | image_sizes=image_sizes, 69 | steps=128, gen_length=128, block_length=128, tokenizer=tokenizer, stopping_criteria=['<|eot_id|>'] 70 | ) 71 | end_time = time.time() 72 | generation_time = end_time - start_time 73 | print(f"Generation time: {generation_time:.4f} seconds") 74 | 75 | print(cont) 76 | text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=False) 77 | print(text_outputs) 78 | -------------------------------------------------------------------------------- /dllm_cache/__init__.py: -------------------------------------------------------------------------------- 1 | from .cache import Cache 2 | from .hooks import register_cache_LLaDA, logout_cache_LLaDA 3 | from .hooks import register_cache_Dream, logout_cache_Dream 4 | __all__ = ["Cache", "register_cache_LLaDA", "logout_cache_LLaDA","register_cache_Dream", "logout_cache_Dream"] -------------------------------------------------------------------------------- /dllm_cache/cache/Cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | 4 | 5 | class Singleton(type): 6 | _instances = {} 7 | 8 | def __call__(cls, *args, **kwargs): 9 | if cls not in cls._instances: 10 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 11 | return cls._instances[cls] 12 | 13 | 14 | class dLLMCache(metaclass=Singleton): 15 | gen_interval_steps: int 16 | prompt_interval_steps: int 17 | cfg_interval_steps: int 18 | prompt_length: int 19 | transfer_ratio: float 20 | __cache: defaultdict 21 | __step_counter: defaultdict 22 | 23 | @classmethod 24 | def new_instance( 25 | cls, 26 | prompt_interval_steps: int = 1, 27 | gen_interval_steps: int = 1, 28 | cfg_interval_steps: int = 1, 29 | transfer_ratio: float = 0.0, 30 | ) -> "dLLMCache": 31 | ins = cls() 32 | setattr(ins, "prompt_interval_steps", prompt_interval_steps) 33 | setattr(ins, "gen_interval_steps", 
gen_interval_steps) 34 | setattr(ins, "cfg_interval_steps", cfg_interval_steps) 35 | setattr(ins, "transfer_ratio", transfer_ratio) 36 | ins.init() 37 | return ins 38 | 39 | def init(self) -> None: 40 | self.__cache = defaultdict( 41 | lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) 42 | ) 43 | self.__step_counter = defaultdict(lambda: defaultdict(lambda: 0)) 44 | 45 | def reset_cache(self, prompt_length: int = 0) -> None: 46 | self.init() 47 | torch.cuda.empty_cache() 48 | self.prompt_length = prompt_length 49 | self.cache_type = "no_cfg" 50 | 51 | def set_cache( 52 | self, layer_id: int, feature_name: str, features: torch.Tensor, cache_type: str 53 | ) -> None: 54 | self.__cache[self.cache_type][cache_type][layer_id][feature_name] = { 55 | 0: features 56 | } 57 | 58 | def get_cache( 59 | self, layer_id: int, feature_name: str, cache_type: str 60 | ) -> torch.Tensor: 61 | output = self.__cache[self.cache_type][cache_type][layer_id][feature_name][0] 62 | return output 63 | 64 | def update_step(self, layer_id: int) -> None: 65 | self.__step_counter[self.cache_type][layer_id] += 1 66 | 67 | def refresh_gen(self, layer_id: int = 0) -> bool: 68 | return (self.current_step - 1) % self.gen_interval_steps == 0 69 | 70 | def refresh_prompt(self, layer_id: int = 0) -> bool: 71 | return (self.current_step - 1) % self.prompt_interval_steps == 0 72 | 73 | def refresh_cfg(self, layer_id: int = 0) -> bool: 74 | return ( 75 | self.current_step - 1 76 | ) % self.cfg_interval_steps == 0 or self.current_step <= 5 77 | 78 | @property 79 | def current_step(self) -> int: 80 | return max(list(self.__step_counter[self.cache_type].values()), default=1) 81 | 82 | def __repr__(self): 83 | return f"USE dLLMCache" 84 | -------------------------------------------------------------------------------- /dllm_cache/cache/Config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class dLLMCacheConfig: 6 | prompt_interval_steps: int = 1 7 | gen_interval_steps: int = 1 8 | transfer_ratio: float = 0.0 9 | cfg_interval_steps: int = 1 10 | -------------------------------------------------------------------------------- /dllm_cache/cache/__init__.py: -------------------------------------------------------------------------------- 1 | from .Cache import dLLMCache 2 | from .Config import dLLMCacheConfig 3 | 4 | __all__ = ["dLLMCache", "dLLMCacheConfig"] 5 | -------------------------------------------------------------------------------- /dllm_cache/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .cache_hook_LLaDA import register_cache_LLaDA, logout_cache_LLaDA 2 | from .cache_hook_Dream import register_cache_Dream, logout_cache_Dream 3 | from .cache_hook_LLaDA_V import register_cache_LLaDA_V 4 | __all__ = [ 5 | "register_cache_LLaDA", 6 | "logout_cache_LLaDA", 7 | "register_cache_Dream", 8 | "logout_cache_Dream", 9 | "register_cache_LLaDA_V", 10 | "logout_cache_LLaDA_V", 11 | ] 12 | -------------------------------------------------------------------------------- /eval_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .Dream import Dream 2 | from .LLaDA import LLaDA 3 | __all__ = ["Dream", "LLaDA"] -------------------------------------------------------------------------------- /evaluation_script.py: -------------------------------------------------------------------------------- 1 | from 
eval_model import LLaDA,Dream 2 | from utils import set_seed 3 | import os 4 | from lm_eval.__main__ import cli_evaluate 5 | if __name__ == "__main__": 6 | os.environ["HF_ALLOW_CODE_EVAL"] = "1" 7 | os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1" 8 | set_seed(1234) 9 | cli_evaluate() -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | conda create -n dllm_cache python=3.12 2 | conda activate dllm_cache 3 | pip install -r requirements.txt 4 | huggingface-cli download --resume-download GSAI-ML/LLaDA-8B-Instruct 5 | huggingface-cli download --resume-download GSAI-ML/LLaDA-8B-Base 6 | huggingface-cli download --resume-download Dream-org/Dream-v0-Instruct-7B 7 | huggingface-cli download --resume-download Dream-org/Dream-v0-Base-7B 8 | -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maomaocun/dLLM-cache/9a905b0d219c8ab587838cafeb3c813265520b5b/llava/__init__.py -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | AVAILABLE_MODELS = { 4 | "llava_llada": "LlavaLLaDAModelLM, LlavaLLaDAConfig", 5 | } 6 | 7 | for model_name, model_classes in AVAILABLE_MODELS.items(): 8 | try: 9 | exec(f"from .language_model.{model_name} import {model_classes}") 10 | except Exception as e: 11 | print(f"Failed to import {model_name} from llava.language_model.{model_name}. 
Error: {e}") 12 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava import LlavaLlamaForCausalLM 12 | 13 | 14 | def apply_delta(base_model_path, target_model_path, delta_path): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 33 | 34 | print("Saving target model") 35 | delta.save_pretrained(target_model_path) 36 | delta_tokenizer.save_pretrained(target_model_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--base-model-path", type=str, required=True) 42 | parser.add_argument("--target-model-path", type=str, required=True) 43 | parser.add_argument("--delta-path", type=str, required=True) 44 | 45 | args = parser.parse_args() 46 | 47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 48 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model import * 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def consolidate_ckpt(src_path, dst_path): 15 | print("Loading model") 16 | auto_upgrade(src_path) 17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 19 | src_model.save_pretrained(dst_path) 20 | src_tokenizer.save_pretrained(dst_path) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--src", type=str, required=True) 26 | parser.add_argument("--dst", type=str, required=True) 27 | 28 | args = parser.parse_args() 29 | 30 | consolidate_ckpt(args.src, args.dst) 31 | -------------------------------------------------------------------------------- /llava/model/language_model/configuration_llada.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LLaDA model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLaDA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class LLaDAConfig(PretrainedConfig): 32 | r""" 33 | This is the configuration class to store the configuration of a [`LLaDAModel`]. It is used to instantiate an LLaDA 34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 35 | defaults will yield a similar configuration to that of the LLaDA-8B. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information. 39 | 40 | 41 | Args: 42 | vocab_size (`int`, *optional*, defaults to 32000): 43 | Vocabulary size of the LLaDA model. Defines the number of different tokens that can be represented by the 44 | `inputs_ids` passed when calling [`LLaDAModel`] 45 | hidden_size (`int`, *optional*, defaults to 4096): 46 | Dimension of the hidden representations. 47 | intermediate_size (`int`, *optional*, defaults to 11008): 48 | Dimension of the MLP representations. 49 | num_hidden_layers (`int`, *optional*, defaults to 32): 50 | Number of hidden layers in the Transformer decoder. 51 | num_attention_heads (`int`, *optional*, defaults to 32): 52 | Number of attention heads for each attention layer in the Transformer decoder. 53 | num_key_value_heads (`int`, *optional*): 54 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 55 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 56 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 57 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 58 | by meanpooling all the original heads within that group. For more details checkout [this 59 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 60 | `num_attention_heads`. 61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 62 | The non-linear activation function (function or string) in the decoder. 
63 | max_position_embeddings (`int`, *optional*, defaults to 2048): 64 | The maximum sequence length that this model might ever be used with. 65 | initializer_range (`float`, *optional*, defaults to 0.02): 66 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 67 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 68 | The epsilon used by the rms normalization layers. 69 | use_cache (`bool`, *optional*, defaults to `True`): 70 | Whether or not the model should return the last key/values attentions (not used by all models). Only 71 | relevant if `config.is_decoder=True`. 72 | pad_token_id (`int`, *optional*): 73 | Padding token id. 74 | bos_token_id (`int`, *optional*, defaults to 1): 75 | Beginning of stream token id. 76 | eos_token_id (`int`, *optional*, defaults to 2): 77 | End of stream token id. 78 | pretraining_tp (`int`, *optional*, defaults to 1): 79 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 80 | document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is 81 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 82 | issue](https://github.com/pytorch/pytorch/issues/76232). 83 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 84 | Whether to tie weight embeddings 85 | rope_theta (`float`, *optional*, defaults to 10000.0): 86 | The base period of the RoPE embeddings. 87 | rope_scaling (`Dict`, *optional*): 88 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 89 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is 90 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 91 | `max_position_embeddings` to the expected new maximum. 92 | attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 93 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 94 | attention_dropout (`float`, *optional*, defaults to 0.0): 95 | The dropout ratio for the attention probabilities. 
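    Example (an illustrative sketch; calling `LLaDAConfig()` with no arguments simply uses the defaults listed above):

    ```python
    >>> from llava.model.language_model.configuration_llada import LLaDAConfig

    >>> # Instantiate a configuration with the default, LLaDA-8B-like settings
    >>> configuration = LLaDAConfig()
    >>> configuration.hidden_size, configuration.num_hidden_layers
    (4096, 32)
    ```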
96 | """ 97 | 98 | model_type = "llada" 99 | keys_to_ignore_at_inference = ["past_key_values"] 100 | 101 | def __init__( 102 | self, 103 | vocab_size=32000, 104 | hidden_size=4096, 105 | intermediate_size=11008, 106 | num_hidden_layers=32, 107 | num_attention_heads=32, 108 | num_key_value_heads=None, 109 | hidden_act="silu", 110 | max_position_embeddings=2048, 111 | initializer_range=0.02, 112 | rms_norm_eps=1e-6, 113 | use_cache=True, 114 | pad_token_id=None, 115 | bos_token_id=1, 116 | eos_token_id=2, 117 | pretraining_tp=1, 118 | tie_word_embeddings=False, 119 | rope_theta=10000.0, 120 | rope_scaling=None, 121 | attention_bias=False, 122 | attention_dropout=0.0, 123 | **kwargs, 124 | ): 125 | self.vocab_size = vocab_size 126 | self.max_position_embeddings = max_position_embeddings 127 | self.hidden_size = hidden_size 128 | self.intermediate_size = intermediate_size 129 | self.num_hidden_layers = num_hidden_layers 130 | self.num_attention_heads = num_attention_heads 131 | 132 | # for backward compatibility 133 | if num_key_value_heads is None: 134 | num_key_value_heads = num_attention_heads 135 | 136 | self.num_key_value_heads = num_key_value_heads 137 | self.hidden_act = hidden_act 138 | self.initializer_range = initializer_range 139 | self.rms_norm_eps = rms_norm_eps 140 | self.pretraining_tp = pretraining_tp 141 | self.use_cache = use_cache 142 | self.rope_theta = rope_theta 143 | self.rope_scaling = rope_scaling 144 | self._rope_scaling_validation() 145 | self.attention_bias = attention_bias 146 | self.attention_dropout = attention_dropout 147 | 148 | super().__init__( 149 | pad_token_id=pad_token_id, 150 | bos_token_id=bos_token_id, 151 | eos_token_id=eos_token_id, 152 | tie_word_embeddings=tie_word_embeddings, 153 | **kwargs, 154 | ) 155 | 156 | def _rope_scaling_validation(self): 157 | """ 158 | Validate the `rope_scaling` configuration. 159 | """ 160 | if self.rope_scaling is None: 161 | return 162 | 163 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 164 | raise ValueError( 165 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 166 | f"got {self.rope_scaling}" 167 | ) 168 | rope_scaling_type = self.rope_scaling.get("type", None) 169 | rope_scaling_factor = self.rope_scaling.get("factor", None) 170 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 171 | raise ValueError( 172 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 173 | ) 174 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 175 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") 176 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llada.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Zebin You 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union, Dict 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | import transformers 22 | from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | 25 | from torch.nn import CrossEntropyLoss 26 | 27 | from transformers.modeling_outputs import CausalLMOutputWithPast 28 | from transformers.generation.utils import GenerateOutput 29 | 30 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 31 | from llava.model.language_model.configuration_llada import LLaDAConfig 32 | from llava.model.language_model.modeling_llada import LLaDAModel, LLaDAModelLM 33 | 34 | 35 | class LlavaLLaDAConfig(LLaDAConfig): 36 | model_type = "llava_llada" 37 | temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna 38 | max_new_tokens: int = 1024 39 | do_sample: bool = False 40 | top_p: Optional[float] = None 41 | # rope_scaling: Optional[dict] = {} 42 | 43 | 44 | class LlavaLLaDAModel(LlavaMetaModel, LLaDAModel): 45 | config_class = LlavaLLaDAConfig 46 | 47 | def __init__(self, config: LLaDAConfig): 48 | super(LlavaLLaDAModel, self).__init__(config) 49 | 50 | 51 | class LlavaLLaDAModelLM(LLaDAModelLM, LlavaMetaForCausalLM): 52 | config_class = LlavaLLaDAConfig 53 | 54 | def __init__(self, config): 55 | LLaDAModelLM.__init__(self, config) 56 | 57 | # configure default generation settings 58 | config.model_type = "llava_llada" 59 | # config.rope_scaling = None 60 | 61 | self.model = LlavaLLaDAModel(config) 62 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 63 | # Initialize weights and apply final processing 64 | self.post_init() 65 | 66 | def get_model(self): 67 | return self.model 68 | 69 | def forward( 70 | self, 71 | input_ids: torch.LongTensor = None, 72 | attention_mask: Optional[torch.Tensor] = None, 73 | position_ids: Optional[torch.LongTensor] = None, 74 | past_key_values: Optional[List[torch.FloatTensor]] = None, 75 | inputs_embeds: Optional[torch.FloatTensor] = None, 76 | labels: Optional[torch.LongTensor] = None, 77 | use_cache: Optional[bool] = None, 78 | output_attentions: Optional[bool] = None, 79 | output_hidden_states: Optional[bool] = None, 80 | images: Optional[torch.FloatTensor] = None, 81 | image_sizes: Optional[List[List[int]]] = None, 82 | return_dict: Optional[bool] = None, 83 | modalities: Optional[List[str]] = ["image"], 84 | dpo_forward: Optional[bool] = None, 85 | cache_position=None, 86 | ) -> Union[Tuple, CausalLMOutputWithPast]: 87 | 88 | if inputs_embeds is None and attention_mask is not None: 89 | # donate multi-dialogue 90 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels, conversation_ids) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes, is_llada=True) 91 | elif inputs_embeds is None: 92 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) 93 | conversation_ids = None 94 | if dpo_forward: 95 | outputs = self.model( 96 | input_ids=input_ids, 97 | attention_mask=attention_mask, 98 | position_ids=position_ids, 99 | past_key_values=past_key_values, 100 | inputs_embeds=inputs_embeds, 101 | use_cache=use_cache, 
102 | output_attentions=output_attentions, 103 | output_hidden_states=output_hidden_states, 104 | return_dict=return_dict, 105 | ) 106 | 107 | hidden_states = outputs[0] 108 | logits = self.lm_head(hidden_states) 109 | return logits, labels 110 | 111 | else: 112 | return super().forward( 113 | input_ids=input_ids, 114 | attention_mask=attention_mask, 115 | position_ids=position_ids, 116 | past_key_values=past_key_values, 117 | inputs_embeds=inputs_embeds, 118 | labels=labels, 119 | use_cache=use_cache, 120 | output_attentions=output_attentions, 121 | output_hidden_states=output_hidden_states, 122 | return_dict=return_dict, 123 | conversation_ids=conversation_ids, 124 | ) 125 | 126 | @torch.no_grad() 127 | def generate( 128 | self, 129 | inputs: Optional[torch.Tensor] = None, 130 | images: Optional[torch.Tensor] = None, 131 | image_sizes: Optional[torch.Tensor] = None, 132 | modalities: Optional[List[str]] = ["image"], 133 | **kwargs, 134 | ) -> Union[GenerateOutput, torch.LongTensor]: 135 | modalities = kwargs.pop("modalities", None) if "modalities" in kwargs and modalities is None else modalities 136 | position_ids = kwargs.pop("position_ids", None) 137 | attention_mask = kwargs.pop("attention_mask", None) 138 | if "inputs_embeds" in kwargs: 139 | raise NotImplementedError("`inputs_embeds` is not supported") 140 | 141 | if images is not None: 142 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 143 | else: 144 | inputs_embeds = self.get_model().embed_tokens(inputs) 145 | 146 | return super().generate_with_embeds(inputs_embeds=inputs_embeds, **kwargs) 147 | 148 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 149 | images = kwargs.pop("images", None) 150 | image_sizes = kwargs.pop("image_sizes", None) 151 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 152 | if images is not None: 153 | inputs["images"] = images 154 | if image_sizes is not None: 155 | inputs["image_sizes"] = image_sizes 156 | return inputs 157 | 158 | 159 | AutoConfig.register("llava_llada", LlavaLLaDAConfig) 160 | AutoModelForCausalLM.register(LlavaLLaDAConfig, LlavaLLaDAModelLM) 161 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in 
base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .imagebind import ImageBindWrapper 4 | from .open_clip_encoder import OpenCLIPVisionTower 5 | from .hf_vision import HFVisionTower 6 | from .siglip_encoder import SigLipVisionTower 7 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 8 | 9 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower 10 | # from .dev_eva_clip.eva_vit import EvaViTWrapper 11 | 12 | 13 | def build_vision_tower(vision_tower_cfg, **kwargs): 14 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) 15 | is_absolute_path_exists = os.path.exists(vision_tower) 16 | use_s2 = getattr(vision_tower_cfg, "s2", False) 17 | if vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: # is_absolute_path_exists or 18 | if use_s2: 19 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 20 | else: 21 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 22 | elif "siglip" in vision_tower: 23 | return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) 24 | elif vision_tower.startswith("hf:"): 25 | return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 26 | elif vision_tower in ["imagebind_huge"]: 27 | return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 28 | elif vision_tower.startswith("open_clip_hub"): 29 | return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 30 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): 31 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 32 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: 33 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 34 | 35 | raise ValueError(f"Unknown vision tower: {vision_tower}") 36 | -------------------------------------------------------------------------------- 
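# Usage sketch (illustrative, not part of the repository): `build_vision_tower` above selects an encoder
# class from the tower name stored on the model config (`mm_vision_tower`, falling back to `vision_tower`).
# The config object and HF model id below are assumptions for the example, and the optional encoder
# dependencies imported at the top of builder.py must be installed for the import to succeed.
from types import SimpleNamespace

from llava.model.multimodal_encoder.builder import build_vision_tower

example_cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",  # "openai..." names route to CLIPVisionTower
    mm_vision_select_layer=-2,           # index into hidden_states used by feature_select()
    mm_vision_select_feature="patch",    # keep patch tokens, drop the CLS token
)
# With delay_load=True the CLIP branch only fetches the vision config (cfg_only), not the tower weights.
vision_tower = build_vision_tower(example_cfg, delay_load=True)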
/llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from llava.utils import rank0_print 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | try: 7 | from s2wrapper import forward as multiscale_forward 8 | except: 9 | pass 10 | 11 | 12 | class CLIPVisionTower(nn.Module): 13 | def __init__(self, vision_tower, args, delay_load=False): 14 | super().__init__() 15 | 16 | self.is_loaded = False 17 | 18 | self.vision_tower_name = vision_tower 19 | self.select_layer = args.mm_vision_select_layer 20 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 21 | 22 | if not delay_load: 23 | rank0_print(f"Loading vision tower: {vision_tower}") 24 | self.load_model() 25 | elif getattr(args, "unfreeze_mm_vision_tower", False): 26 | # TODO: better detector is needed. 27 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 28 | self.load_model() 29 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 30 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 31 | self.load_model() 32 | else: 33 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 34 | 35 | def load_model(self, device_map=None): 36 | if self.is_loaded: 37 | rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) 38 | return 39 | 40 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 41 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 42 | self.vision_tower.requires_grad_(False) 43 | 44 | self.is_loaded = True 45 | 46 | def feature_select(self, image_forward_outs): 47 | select_feature_type = self.select_feature 48 | 49 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 50 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4 51 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) 52 | select_feature_type = select_feature_type.replace("slicefour_", "") 53 | elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: 54 | select_layers = [-2, -5, -8, -11, 6] 55 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1) 56 | select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") 57 | else: 58 | image_features = image_forward_outs.hidden_states[self.select_layer] 59 | 60 | if select_feature_type == "patch": 61 | image_features = image_features[:, 1:] 62 | elif select_feature_type == "cls_patch": 63 | image_features = image_features 64 | else: 65 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 66 | return image_features 67 | 68 | def forward(self, images): 69 | if type(images) is list: 70 | image_features = [] 71 | for image in images: 72 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 73 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 74 | image_features.append(image_feature) 75 | else: 76 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), 
output_hidden_states=True) 77 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 78 | 79 | return image_features 80 | 81 | @property 82 | def dummy_feature(self): 83 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 84 | 85 | @property 86 | def dtype(self): 87 | return self.vision_tower.dtype 88 | 89 | @property 90 | def device(self): 91 | return self.vision_tower.device 92 | 93 | @property 94 | def config(self): 95 | if self.is_loaded: 96 | return self.vision_tower.config 97 | else: 98 | return self.cfg_only 99 | 100 | @property 101 | def hidden_size(self): 102 | _hidden_size = self.config.hidden_size 103 | if "slicefour" in self.select_feature: 104 | _hidden_size *= 4 105 | if "slice_m25811_f6" in self.select_feature: 106 | _hidden_size *= 5 107 | return _hidden_size 108 | 109 | @property 110 | def num_patches_per_side(self): 111 | return self.config.image_size // self.config.patch_size 112 | 113 | @property 114 | def num_patches(self): 115 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2 116 | if "cls_patch" in self.select_feature: 117 | _num_patches += 1 118 | return _num_patches 119 | 120 | @property 121 | def image_size(self): 122 | return self.config.image_size 123 | 124 | 125 | class CLIPVisionTowerS2(CLIPVisionTower): 126 | def __init__(self, vision_tower, args, delay_load=False): 127 | 128 | self.s2_scales = getattr(args, "s2_scales", "336,672,1008") 129 | self.s2_scales = list(map(int, self.s2_scales.split(","))) 130 | self.s2_scales.sort() 131 | self.s2_split_size = self.s2_scales[0] 132 | self.s2_image_size = self.s2_scales[-1] 133 | 134 | super().__init__(vision_tower, args, delay_load) 135 | 136 | # change resize/crop size in preprocessing to the largest image size in s2_scale 137 | if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False): 138 | self.image_processor.size["shortest_edge"] = self.s2_image_size 139 | self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size 140 | 141 | def load_model(self, device_map=None): 142 | if self.is_loaded: 143 | rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) 144 | return 145 | 146 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 147 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 148 | self.vision_tower.requires_grad_(False) 149 | 150 | self.image_processor.size["shortest_edge"] = self.s2_image_size 151 | self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size 152 | 153 | self.is_loaded = True 154 | 155 | def forward_feature(self, images): 156 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 157 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 158 | return image_features 159 | 160 | def forward(self, images): 161 | if type(images) is list: 162 | image_features = [] 163 | for image in images: 164 | image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) 165 | image_features.append(image_feature) 166 | else: 167 | image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) 168 | 169 | return image_features 170 | 171 | @property 172 | 
def hidden_size(self): 173 | return self.config.hidden_size * len(self.s2_scales) 174 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/hf_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor 5 | from llava.utils import rank0_print 6 | 7 | 8 | class HFVisionTower(nn.Module): 9 | def __init__(self, vision_tower, args, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | 14 | self.vision_tower_name = vision_tower.replace("hf:", "", 1) 15 | self.select_layer = args.mm_vision_select_layer 16 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 17 | 18 | if not delay_load: 19 | self.load_model() 20 | else: 21 | self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) 22 | 23 | def load_model(self): 24 | try: 25 | self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name) 26 | except Exception as e: 27 | if "448" in self.vision_tower_name: 28 | image_size = 448 29 | # use image processor with conig 30 | self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size) 31 | else: 32 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 33 | rank0_print(f"Loaded image processor: {self.image_processor}") 34 | self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda") 35 | self.device = self.vision_tower.device 36 | self.dtype = self.vision_tower.dtype 37 | self.config = self.vision_tower.config 38 | 39 | if hasattr(self.vision_tower, "vision_model"): 40 | self.vision_tower = self.vision_tower.vision_model 41 | self.vision_tower.requires_grad_(False) 42 | # self.vision_tower.eval() 43 | self.is_loaded = True 44 | 45 | def feature_select(self, image_forward_outs): 46 | select_feature_type = self.select_feature 47 | 48 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 49 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4 50 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) 51 | select_feature_type = select_feature_type.replace("slicefour_", "") 52 | else: 53 | image_features = image_forward_outs.hidden_states[self.select_layer] 54 | 55 | if select_feature_type == "patch": 56 | image_features = image_features[:, 1:] 57 | elif select_feature_type == "cls_patch": 58 | image_features = image_features 59 | else: 60 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 61 | return image_features 62 | 63 | def forward(self, images): 64 | if type(images) is list: 65 | image_features = [] 66 | for image in images: 67 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 68 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 69 | image_features.append(image_feature) 70 | else: 71 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 72 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 73 | 74 | return image_features 75 | 76 | @property 77 | def 
dummy_feature(self): 78 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 79 | 80 | # @property 81 | # def dtype(self): 82 | # return self.vision_tower.dtype 83 | 84 | # @property 85 | # def device(self): 86 | # return self.vision_tower.device 87 | 88 | @property 89 | def hidden_size(self): 90 | try: 91 | _hidden_size = self.config.hidden_size 92 | except: 93 | _hidden_size = self.config.vision_config.hidden_size 94 | if "slicefour" in self.select_feature: 95 | _hidden_size *= 4 96 | return _hidden_size 97 | 98 | @property 99 | def num_patches(self): 100 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2 101 | if "cls_patch" in self.select_feature: 102 | _num_patches += 1 103 | return _num_patches 104 | 105 | @property 106 | def num_patches_per_side(self): 107 | return self.config.image_size // self.config.patch_size 108 | 109 | @property 110 | def image_size(self): 111 | return self.config.image_size 112 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/imagebind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPImageProcessor 5 | 6 | try: 7 | from imagebind.models import imagebind_model 8 | from imagebind.models.imagebind_model import ModalityType 9 | from imagebind.data import load_and_transform_audio_data 10 | except ImportError: 11 | pass 12 | 13 | 14 | class ImageBindWrapper(nn.Module): 15 | def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): 16 | super().__init__() 17 | 18 | self.is_loaded = False 19 | 20 | self.vision_tower_name = vision_tower 21 | self.select_layer = select_layer 22 | self.select_feature = select_feature 23 | 24 | if not delay_load: 25 | self.load_model() 26 | 27 | def load_model(self): 28 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 29 | self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) 30 | for p in self.vision_tower.parameters(): 31 | p.requires_grad = False 32 | self.vision_tower.eval() 33 | self.is_loaded = True 34 | 35 | def train(self, mode=True): 36 | self.training = mode 37 | 38 | if self.is_loaded: 39 | self.vision_tower.eval() 40 | 41 | @torch.no_grad() 42 | def forward(self, x): 43 | if type(x) == dict: 44 | if x["audios"] is not None: 45 | inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} 46 | embeddings = self.vision_tower(inputs) 47 | audio_embedding = embeddings[ModalityType.AUDIO] 48 | return audio_embedding.unsqueeze(1) 49 | else: 50 | inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} 51 | embeddings = self.vision_tower(inputs) 52 | vision_embedding = embeddings[ModalityType.VISION] 53 | if vision_embedding.ndim == 2: 54 | return vision_embedding.unsqueeze(1) 55 | if vision_embedding.shape[1] == 257: 56 | return vision_embedding[:, 1:] 57 | raise ValueError(f"Unexpected shape: {vision_embedding.shape}") 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.modality_preprocessors.vision.cls_token.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.modality_preprocessors.vision.cls_token.device 70 | 71 | @property 72 | def hidden_size(self): 73 | return 1024 74 | 
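The CLIP and HF towers above share one feature-selection convention: index a single hidden state with `select_layer`, then drop the leading CLS token for "patch" or keep it for "cls_patch". A self-contained sketch of that indexing, assuming made-up layer counts and tensor shapes purely for illustration:

import torch

# Stand-in for ViT hidden states: 25 layers of [batch, 1 + num_patches, dim].
hidden_states = [torch.randn(2, 577, 1024) for _ in range(25)]

select_layer = -2           # what the towers read from mm_vision_select_layer
select_feature = "patch"    # "patch" drops the CLS token, "cls_patch" keeps it

features = hidden_states[select_layer]
if select_feature == "patch":
    features = features[:, 1:]
elif select_feature != "cls_patch":
    raise ValueError(f"Unexpected select feature: {select_feature}")

print(features.shape)  # torch.Size([2, 576, 1024])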
-------------------------------------------------------------------------------- /llava/model/multimodal_encoder/open_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPImageProcessor 4 | from llava.utils import rank0_print 5 | 6 | try: 7 | import open_clip 8 | import torchvision 9 | from open_clip.transformer import _expand_token 10 | except ImportError: 11 | print("OpenCLIP not installed") 12 | open_clip = None 13 | 14 | HIDDEN_SIZE_DICT = { 15 | "ViT-H-14-378-quickgelu": 1280, 16 | } 17 | 18 | 19 | class OpenCLIPVisionTower(nn.Module): 20 | def __init__(self, vision_tower, args, delay_load=False): 21 | super().__init__() 22 | 23 | self.is_loaded = False 24 | self.model_name = vision_tower.replace("open_clip_hub:", "") 25 | self.pretrained = args.vision_tower_pretrained 26 | self.select_layer = args.mm_vision_select_layer 27 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 28 | 29 | if not delay_load: 30 | rank0_print(f"Loading vision tower: {vision_tower}") 31 | self.load_model() 32 | elif getattr(args, "unfreeze_mm_vision_tower", False): 33 | # TODO: better detector is needed. 34 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 35 | self.load_model() 36 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 37 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 38 | self.load_model() 39 | 40 | def load_model(self, device_map="auto"): 41 | rank0_print(f"Loading OpenCLIP model: {self.model_name}") 42 | rank0_print(f"Pretrained: {self.pretrained}") 43 | vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda") 44 | 45 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] 46 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] 47 | self.resize_transform_size = resize_transform.size # 224 or 384 48 | self.patch_size = vision_tower.visual.conv1.kernel_size[0] # 14 or 16 49 | 50 | self.image_processor = CLIPImageProcessor.from_pretrained( 51 | "openai/clip-vit-large-patch14", 52 | crop_size=resize_transform.size, 53 | size={"shortest_edge": resize_transform.size}, 54 | image_mean=list(normalize_transform.mean), 55 | image_std=list(normalize_transform.std), 56 | ) 57 | rank0_print(f"Loaded image processor: {self.image_processor}") 58 | self.vision_tower = vision_tower.visual 59 | self.vision_tower.requires_grad_(False) 60 | 61 | self.is_loaded = True 62 | 63 | def feature_select(self, image_forward_outs): 64 | image_features = image_forward_outs[self.select_layer] 65 | if self.select_feature == "patch": 66 | image_features = image_features[:, 1:] 67 | elif self.select_feature == "cls_patch": 68 | image_features = image_features 69 | elif self.select_feature == "conv_flatten": 70 | image_features = image_features.flatten(2).transpose(1, 2) 71 | else: 72 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 73 | return image_features 74 | 75 | def forward_visual(self, x, output_hidden_states=False): 76 | if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"): 77 | return 
self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer)) 78 | else: 79 | 80 | def forward_openclip(self, x: torch.Tensor): 81 | features = [] 82 | x = self.conv1(x) # shape = [*, width, grid, grid] 83 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 84 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 85 | 86 | # class embeddings and positional embeddings 87 | x = torch.cat( 88 | [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], 89 | dim=1, 90 | ) 91 | # shape = [*, grid ** 2 + 1, width] 92 | x = x + self.positional_embedding.to(x.dtype) 93 | 94 | x = self.patch_dropout(x) 95 | x = self.ln_pre(x) 96 | 97 | x = x.permute(1, 0, 2) # NLD -> LND 98 | for r in self.transformer.resblocks: 99 | x = r(x, attn_mask=None) 100 | features.append(x) 101 | return features 102 | 103 | return forward_openclip(self.vision_tower, x) 104 | 105 | def forward(self, images): 106 | if type(images) is list: 107 | image_features = [] 108 | for image in images: 109 | image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True) 110 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 111 | image_features.append(image_feature) 112 | else: 113 | image_forward_outs = self.forward_visual(images.to(self.dtype), output_hidden_states=True) 114 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 115 | 116 | return image_features 117 | 118 | @property 119 | def dummy_feature(self): 120 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 121 | 122 | @property 123 | def dtype(self): 124 | if hasattr(self.vision_tower, "conv1"): 125 | return self.vision_tower.conv1.weight.dtype 126 | if hasattr(self.vision_tower, "trunk"): 127 | return self.vision_tower.trunk.patch_embed.proj.weight.dtype 128 | raise NotImplementedError 129 | 130 | @property 131 | def device(self): 132 | if hasattr(self.vision_tower, "conv1"): 133 | return self.vision_tower.conv1.weight.device 134 | if hasattr(self.vision_tower, "trunk"): 135 | return self.vision_tower.trunk.patch_embed.proj.weight.device 136 | raise NotImplementedError 137 | 138 | @property 139 | def config(self): 140 | return None 141 | 142 | @property 143 | def hidden_size(self): 144 | if self.model_name in HIDDEN_SIZE_DICT: 145 | return HIDDEN_SIZE_DICT[self.model_name] 146 | else: 147 | raise NotImplementedError 148 | 149 | @property 150 | def num_patches(self): 151 | image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0] 152 | _num_patches = (image_size // self.patch_size) ** 2 153 | if "cls_patch" in self.select_feature: 154 | _num_patches += 1 155 | return _num_patches 156 | 157 | @property 158 | def image_size(self): 159 | return self.resize_transform_size 160 | 161 | @property 162 | def num_patches_per_side(self): 163 | return self.resize_transform_size // self.patch_size 164 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | from .pooler_projector import PoolerProjector 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": "identity"} 18 | 19 | 20 | 
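# `build_vision_projector` below parses `mm_projector_type` strings such as
# "mlp2x_gelu" with a regex and expands them into a small MLP. A minimal sketch of
# what "mlp2x_gelu" yields, assuming illustrative sizes rather than values from
# any config in this repo:
#
#     import torch.nn as nn
#     mm_hidden_size, hidden_size = 1024, 4096   # assumed dimensions
#     projector = nn.Sequential(
#         nn.Linear(mm_hidden_size, hidden_size),
#         nn.GELU(),
#         nn.Linear(hidden_size, hidden_size),
#     )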
class SimpleResBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.pre_norm = nn.LayerNorm(channels) 24 | 25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) 26 | 27 | def forward(self, x): 28 | x = self.pre_norm(x) 29 | return x + self.proj(x) 30 | 31 | 32 | def build_vision_projector(config, delay_load=False, **kwargs): 33 | projector_type = getattr(config, "mm_projector_type", "linear") 34 | 35 | if projector_type == "linear": 36 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | 38 | if projector_type == "pooler": 39 | return PoolerProjector(config, kwargs["vision_cfg"]) 40 | 41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) 42 | if mlp_gelu_match: 43 | mlp_depth = int(mlp_gelu_match.group(1)) 44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 45 | for _ in range(1, mlp_depth): 46 | modules.append(nn.GELU()) 47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 48 | return nn.Sequential(*modules) 49 | 50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) 51 | if mlp_gelu_resnet_match: 52 | mlp_depth = int(mlp_gelu_resnet_match.group(1)) 53 | res_depth = int(mlp_gelu_resnet_match.group(2)) 54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 55 | for _ in range(1, mlp_depth): 56 | modules.append(nn.GELU()) 57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 58 | for _ in range(res_depth): 59 | modules.append(SimpleResBlock(config.hidden_size)) 60 | return nn.Sequential(*modules) 61 | 62 | if projector_type == "identity": 63 | return IdentityMap() 64 | 65 | raise ValueError(f"Unknown projector type: {projector_type}") 66 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/pooler_projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | from transformers.models.clip.modeling_clip import CLIPVisionModel 7 | 8 | 9 | class PoolerProjector(nn.Module): 10 | def __init__(self, config, vision_cfg): 11 | super().__init__() 12 | self._config = config 13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size 14 | 15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) 16 | 17 | self.proj = nn.Sequential( 18 | nn.GELU(), 19 | nn.Linear(config.hidden_size, config.hidden_size), 20 | ) 21 | 22 | def forward(self, x, *args, **kwargs): 23 | height = width = self.hw 24 | assert height * width == x.shape[1] 25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) 26 | x = self.conv_pool(x) 27 | x = x.flatten(2).transpose(1, 2) 28 | x = self.proj(x) 29 | return x 30 | 31 | @property 32 | def config(self): 33 | return {"mm_projector_type": "pooler"} 34 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .masked_drop import MaskedDrop 4 | from .spatial_pool import SpatialPool 5 | from .perceiver import PerceiverResampler 6 | from .qformer import Qformer 7 | 8 | 9 | class IdentityMap(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | @property 17 | def config(self): 18 | return 
{"mm_resampler_type": None} 19 | 20 | 21 | def build_vision_resampler(model_args, delay_load=False, **kwargs): 22 | resampler_type = getattr(model_args, "mm_resampler_type", None) 23 | if resampler_type == "masked_drop": 24 | return MaskedDrop(model_args) 25 | elif resampler_type == "spatial_pool": 26 | return SpatialPool(model_args, **kwargs) 27 | elif resampler_type == "perceiver": 28 | return PerceiverResampler(model_args, **kwargs) 29 | elif resampler_type == "qformer": 30 | return Qformer(model_args, **kwargs) 31 | elif resampler_type is None: 32 | return IdentityMap() 33 | 34 | raise ValueError(f"Unknown resampler type: {resampler_type}") 35 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/masked_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import random 5 | 6 | 7 | class MaskedDrop(nn.Module): 8 | def __init__(self, model_args): 9 | super().__init__() 10 | 11 | self.mode = model_args.mm_mask_drop_mode 12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage 13 | self.ratio = model_args.mm_mask_drop_ratio 14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper 15 | self.ratio_lower = model_args.mm_mask_drop_ratio_lower 16 | 17 | def forward(self, image_features, *args, **kwargs): 18 | 19 | if not self.training: 20 | return image_features 21 | 22 | if self.skip_percentage > random.random(): 23 | return image_features 24 | 25 | masked_features = [] 26 | 27 | for image_feature in image_features: 28 | num_tokens = image_feature.shape[0] 29 | if self.mode == "fixed": 30 | num_keep = int(num_tokens * self.ratio) 31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) 32 | elif self.mode == "range": 33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) 34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) 35 | elif self.mode == "cls_only": 36 | masked_features.append(image_feature[0:1]) 37 | else: 38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}") 39 | 40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): 41 | masked_features = torch.stack(masked_features, dim=0) 42 | 43 | return masked_features 44 | 45 | @property 46 | def config(self): 47 | return { 48 | "mm_resampler_type": "masked_drop", 49 | "mm_mask_drop_mode": self.mode, 50 | "mm_mask_drop_skip_percentage": self.skip_percentage, 51 | "mm_mask_drop_ratio": self.ratio, 52 | "mm_mask_drop_ratio_upper": self.ratio_upper, 53 | "mm_mask_drop_ratio_lower": self.ratio_lower, 54 | } 55 | 56 | def random_masking(self, x, len_keep): 57 | """ 58 | Perform per-sample random masking by per-sample shuffling. 59 | Per-sample shuffling is done by argsort random noise. 
60 | x: [N, L, D], sequence 61 | """ 62 | N, L, D = x.shape # batch, length, dim 63 | 64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 65 | 66 | # sort noise for each sample 67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 68 | ids_restore = torch.argsort(ids_shuffle, dim=1) 69 | 70 | # keep the first subset 71 | ids_keep = ids_shuffle[:, :len_keep] 72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 73 | 74 | # generate the binary mask: 0 is keep, 1 is remove 75 | mask = torch.ones([N, L], device=x.device) 76 | mask[:, :len_keep] = 0 77 | # unshuffle to get the binary mask 78 | mask = torch.gather(mask, dim=1, index=ids_restore) 79 | 80 | return x_masked, mask, ids_restore 81 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/perceiver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | try: 9 | from einops_exts import rearrange_many 10 | except: 11 | pass 12 | 13 | from torch import einsum, nn 14 | 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | 20 | def FeedForward(dim, mult=4): 21 | inner_dim = int(dim * mult) 22 | return nn.Sequential( 23 | nn.LayerNorm(dim), 24 | nn.Linear(dim, inner_dim, bias=False), 25 | nn.GELU(), 26 | nn.Linear(inner_dim, dim, bias=False), 27 | ) 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.heads = heads 35 | inner_dim = dim_head * heads 36 | 37 | self.norm_media = nn.LayerNorm(dim) 38 | self.norm_latents = nn.LayerNorm(dim) 39 | 40 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 41 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 42 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 43 | 44 | def forward(self, x, latents): 45 | """ 46 | Args: 47 | x (torch.Tensor): image features 48 | shape (b, T, n1, D) 49 | latent (torch.Tensor): latent features 50 | shape (b, T, n2, D) 51 | """ 52 | x = self.norm_media(x) 53 | latents = self.norm_latents(latents) 54 | 55 | h = self.heads 56 | 57 | q = self.to_q(latents) 58 | kv_input = torch.cat((x, latents), dim=-2) 59 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 60 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 61 | q = q * self.scale 62 | 63 | # attention 64 | sim = einsum("... i d, ... j d -> ... i j", q, k) 65 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 66 | attn = sim.softmax(dim=-1) 67 | 68 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 69 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 70 | return self.to_out(out) 71 | 72 | 73 | class PerceiverResamplerModule(nn.Module): 74 | def __init__( 75 | self, 76 | *, 77 | dim, 78 | depth=6, 79 | dim_head=64, 80 | heads=8, 81 | num_latents=64, 82 | max_num_media=None, 83 | max_num_frames=None, 84 | ff_mult=4, 85 | ): 86 | super().__init__() 87 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 88 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None 89 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None 90 | 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append( 94 | nn.ModuleList( 95 | [ 96 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 97 | FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), 98 | ] 99 | ) 100 | ) 101 | 102 | self.norm = nn.LayerNorm(dim) 103 | 104 | def forward(self, x): 105 | """ 106 | Args: 107 | x (torch.Tensor): image features 108 | shape (b, T, F, v, D) 109 | Returns: 110 | shape (b, T, n, D) where n is self.num_latents 111 | """ 112 | b, T, F, v = x.shape[:4] 113 | 114 | # frame and media time embeddings 115 | if exists(self.frame_embs): 116 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 117 | x = x + frame_embs 118 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions 119 | if exists(self.media_time_embs): 120 | x = x + self.media_time_embs[:T] 121 | 122 | # blocks 123 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 124 | for attn, ff in self.layers: 125 | latents = attn(x, latents) + latents 126 | latents = ff(latents) + latents 127 | return self.norm(latents) 128 | 129 | 130 | class PerceiverResampler(nn.Module): 131 | def __init__(self, model_args, vision_tower): 132 | super().__init__() 133 | 134 | self.depth = model_args.mm_perceiver_depth 135 | self.num_latents = model_args.mm_perceiver_latents 136 | self.ff_mult = model_args.mm_perceiver_ff_mult 137 | self.pretrained = model_args.mm_perceiver_pretrained 138 | 139 | self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) 140 | 141 | if self.pretrained is not None: 142 | self.load_state_dict(torch.load(self.pretrained)) 143 | 144 | def forward(self, image_features, *args, **kwargs): 145 | return self.perceiver(image_features[:, None, None]).squeeze(1) 146 | 147 | @property 148 | def config(self): 149 | return { 150 | "mm_resampler_type": "perceiver", 151 | "mm_perceiver_depth": self.depth, 152 | "mm_perceiver_latents": self.num_latents, 153 | "mm_perceiver_ff_mult": self.ff_mult, 154 | "mm_perceiver_pretrained": self.pretrained, 155 | } 156 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/spatial_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class SpatialPool(nn.Module): 7 | def __init__(self, model_args, vision_tower): 8 | super().__init__() 9 | 10 | self.mode = model_args.mm_spatial_pool_mode 11 | self.stride = model_args.mm_spatial_pool_stride 12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) 13 | 14 | if self.mode == "average": 15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, 
stride=self.stride) 16 | elif self.mode == "max": 17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) 18 | elif self.mode == "conv": 19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) 20 | else: 21 | raise ValueError(f"Unknown pooling mode: {self.pool}.") 22 | 23 | def forward(self, image_features, images, *args, **kwargs): 24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) 25 | ori_H = int(ori_W * images.shape[2] // images.shape[3]) 26 | 27 | B, _, F = image_features.shape 28 | 29 | image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2) 30 | image_features_spatial_pool = self.pool(image_features_spatial) 31 | 32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() 33 | 34 | @property 35 | def config(self): 36 | return { 37 | "mm_resampler_type": "spatial_pool", 38 | "mm_spatial_pool_stride": self.stride, 39 | "mm_spatial_pool_mode": self.mode, 40 | "mm_spatial_pool_out_channels": self.out_channels, 41 | } 42 | 43 | @property 44 | def hidden_size(self): 45 | return self.out_channels 46 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maomaocun/dLLM-cache/9a905b0d219c8ab587838cafeb3c813265520b5b/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith("http") or image_file.startswith("https"): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert("RGB") 22 | else: 23 | image = Image.open(image_file).convert("RGB") 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit) 33 | 34 | if "llama-2" in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ("user", "assistant") 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().cuda() 56 | 57 | while True: 58 | try: 59 | inp = input(f"{roles[0]}: ") 60 | except EOFError: 61 | inp = "" 62 | if not inp: 63 | print("exit...") 64 | break 65 | 66 | print(f"{roles[1]}: ", end="") 67 | 68 | if image is not None: 69 | # first message 70 | if model.config.mm_use_im_start_end: 71 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp 72 | else: 73 | inp = DEFAULT_IMAGE_TOKEN + "\n" + inp 74 | conv.append_message(conv.roles[0], inp) 75 | image = None 76 | else: 77 | # later messages 78 | conv.append_message(conv.roles[0], inp) 79 | conv.append_message(conv.roles[1], None) 80 | prompt = conv.get_prompt() 81 | 82 | input_ids = 
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() 83 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 84 | keywords = [stop_str] 85 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 86 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 87 | 88 | with torch.inference_mode(): 89 | output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]) 90 | 91 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip() 92 | conv.messages[-1][-1] = outputs 93 | 94 | if args.debug: 95 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 101 | parser.add_argument("--model-base", type=str, default=None) 102 | parser.add_argument("--image-file", type=str, required=True) 103 | parser.add_argument("--num-gpus", type=int, default=1) 104 | parser.add_argument("--conv-mode", type=str, default=None) 105 | parser.add_argument("--temperature", type=float, default=0.2) 106 | parser.add_argument("--max-new-tokens", type=int, default=512) 107 | parser.add_argument("--load-8bit", action="store_true") 108 | parser.add_argument("--load-4bit", action="store_true") 109 | parser.add_argument("--debug", action="store_true") 110 | args = parser.parse_args() 111 | main(args) 112 | -------------------------------------------------------------------------------- /llava/serve/controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | A controller manages distributed workers. 3 | It sends worker addresses to clients. 
4 | """ 5 | 6 | import argparse 7 | import asyncio 8 | import dataclasses 9 | from enum import Enum, auto 10 | import json 11 | import logging 12 | import time 13 | from typing import List, Union 14 | import threading 15 | 16 | from fastapi import FastAPI, Request 17 | from fastapi.responses import StreamingResponse 18 | import numpy as np 19 | import requests 20 | import uvicorn 21 | 22 | from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION 23 | from llava.utils import build_logger, server_error_msg 24 | 25 | 26 | logger = build_logger("controller", "controller.log") 27 | 28 | 29 | class DispatchMethod(Enum): 30 | LOTTERY = auto() 31 | SHORTEST_QUEUE = auto() 32 | 33 | @classmethod 34 | def from_str(cls, name): 35 | if name == "lottery": 36 | return cls.LOTTERY 37 | elif name == "shortest_queue": 38 | return cls.SHORTEST_QUEUE 39 | else: 40 | raise ValueError(f"Invalid dispatch method") 41 | 42 | 43 | @dataclasses.dataclass 44 | class WorkerInfo: 45 | model_names: List[str] 46 | speed: int 47 | queue_length: int 48 | check_heart_beat: bool 49 | last_heart_beat: str 50 | 51 | 52 | def heart_beat_controller(controller): 53 | while True: 54 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) 55 | controller.remove_stable_workers_by_expiration() 56 | 57 | 58 | class Controller: 59 | def __init__(self, dispatch_method: str): 60 | # Dict[str -> WorkerInfo] 61 | self.worker_info = {} 62 | self.dispatch_method = DispatchMethod.from_str(dispatch_method) 63 | 64 | self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self,)) 65 | self.heart_beat_thread.start() 66 | 67 | logger.info("Init controller") 68 | 69 | def register_worker(self, worker_name: str, check_heart_beat: bool, worker_status: dict): 70 | if worker_name not in self.worker_info: 71 | logger.info(f"Register a new worker: {worker_name}") 72 | else: 73 | logger.info(f"Register an existing worker: {worker_name}") 74 | 75 | if not worker_status: 76 | worker_status = self.get_worker_status(worker_name) 77 | if not worker_status: 78 | return False 79 | 80 | self.worker_info[worker_name] = WorkerInfo(worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], check_heart_beat, time.time()) 81 | 82 | logger.info(f"Register done: {worker_name}, {worker_status}") 83 | return True 84 | 85 | def get_worker_status(self, worker_name: str): 86 | try: 87 | r = requests.post(worker_name + "/worker_get_status", timeout=5) 88 | except requests.exceptions.RequestException as e: 89 | logger.error(f"Get status fails: {worker_name}, {e}") 90 | return None 91 | 92 | if r.status_code != 200: 93 | logger.error(f"Get status fails: {worker_name}, {r}") 94 | return None 95 | 96 | return r.json() 97 | 98 | def remove_worker(self, worker_name: str): 99 | del self.worker_info[worker_name] 100 | 101 | def refresh_all_workers(self): 102 | old_info = dict(self.worker_info) 103 | self.worker_info = {} 104 | 105 | for w_name, w_info in old_info.items(): 106 | if not self.register_worker(w_name, w_info.check_heart_beat, None): 107 | logger.info(f"Remove stale worker: {w_name}") 108 | 109 | def list_models(self): 110 | model_names = set() 111 | 112 | for w_name, w_info in self.worker_info.items(): 113 | model_names.update(w_info.model_names) 114 | 115 | return list(model_names) 116 | 117 | def get_worker_address(self, model_name: str): 118 | if self.dispatch_method == DispatchMethod.LOTTERY: 119 | worker_names = [] 120 | worker_speeds = [] 121 | for w_name, w_info in self.worker_info.items(): 122 | if model_name in 
w_info.model_names: 123 | worker_names.append(w_name) 124 | worker_speeds.append(w_info.speed) 125 | worker_speeds = np.array(worker_speeds, dtype=np.float32) 126 | norm = np.sum(worker_speeds) 127 | if norm < 1e-4: 128 | return "" 129 | worker_speeds = worker_speeds / norm 130 | if True: # Directly return address 131 | pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) 132 | worker_name = worker_names[pt] 133 | return worker_name 134 | 135 | # Check status before returning 136 | while True: 137 | pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) 138 | worker_name = worker_names[pt] 139 | 140 | if self.get_worker_status(worker_name): 141 | break 142 | else: 143 | self.remove_worker(worker_name) 144 | worker_speeds[pt] = 0 145 | norm = np.sum(worker_speeds) 146 | if norm < 1e-4: 147 | return "" 148 | worker_speeds = worker_speeds / norm 149 | continue 150 | return worker_name 151 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: 152 | worker_names = [] 153 | worker_qlen = [] 154 | for w_name, w_info in self.worker_info.items(): 155 | if model_name in w_info.model_names: 156 | worker_names.append(w_name) 157 | worker_qlen.append(w_info.queue_length / w_info.speed) 158 | if len(worker_names) == 0: 159 | return "" 160 | min_index = np.argmin(worker_qlen) 161 | w_name = worker_names[min_index] 162 | self.worker_info[w_name].queue_length += 1 163 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") 164 | return w_name 165 | else: 166 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") 167 | 168 | def receive_heart_beat(self, worker_name: str, queue_length: int): 169 | if worker_name not in self.worker_info: 170 | logger.info(f"Receive unknown heart beat. {worker_name}") 171 | return False 172 | 173 | self.worker_info[worker_name].queue_length = queue_length 174 | self.worker_info[worker_name].last_heart_beat = time.time() 175 | logger.info(f"Receive heart beat. {worker_name}") 176 | return True 177 | 178 | def remove_stable_workers_by_expiration(self): 179 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION 180 | to_delete = [] 181 | for worker_name, w_info in self.worker_info.items(): 182 | if w_info.check_heart_beat and w_info.last_heart_beat < expire: 183 | to_delete.append(worker_name) 184 | 185 | for worker_name in to_delete: 186 | self.remove_worker(worker_name) 187 | 188 | def worker_api_generate_stream(self, params): 189 | worker_addr = self.get_worker_address(params["model"]) 190 | if not worker_addr: 191 | logger.info(f"no worker: {params['model']}") 192 | ret = { 193 | "text": server_error_msg, 194 | "error_code": 2, 195 | } 196 | yield json.dumps(ret).encode() + b"\0" 197 | 198 | try: 199 | response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5) 200 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 201 | if chunk: 202 | yield chunk + b"\0" 203 | except requests.exceptions.RequestException as e: 204 | logger.info(f"worker timeout: {worker_addr}") 205 | ret = { 206 | "text": server_error_msg, 207 | "error_code": 3, 208 | } 209 | yield json.dumps(ret).encode() + b"\0" 210 | 211 | # Let the controller act as a worker to achieve hierarchical 212 | # management. This can be used to connect isolated sub networks. 
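# The FastAPI endpoints registered further down expose the methods above over HTTP.
# As a sketch of the client-side flow (host, port, model name, and stop string are
# assumptions; the argparse defaults at the bottom of this file are localhost:21001),
# a client resolves a model name to a worker address and then streams from the worker:
#
#     import requests
#     controller = "http://localhost:21001"
#     addr = requests.post(controller + "/get_worker_address",
#                          json={"model": "llava-model"}).json()["address"]
#     resp = requests.post(addr + "/worker_generate_stream",
#                          json={"model": "llava-model", "prompt": "...", "stop": "</s>"},
#                          stream=True)
#     for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
#         pass  # each chunk is a JSON payload terminated by a NUL byte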
213 | def worker_api_get_status(self): 214 | model_names = set() 215 | speed = 0 216 | queue_length = 0 217 | 218 | for w_name in self.worker_info: 219 | worker_status = self.get_worker_status(w_name) 220 | if worker_status is not None: 221 | model_names.update(worker_status["model_names"]) 222 | speed += worker_status["speed"] 223 | queue_length += worker_status["queue_length"] 224 | 225 | return { 226 | "model_names": list(model_names), 227 | "speed": speed, 228 | "queue_length": queue_length, 229 | } 230 | 231 | 232 | app = FastAPI() 233 | 234 | 235 | @app.post("/register_worker") 236 | async def register_worker(request: Request): 237 | data = await request.json() 238 | controller.register_worker(data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)) 239 | 240 | 241 | @app.post("/refresh_all_workers") 242 | async def refresh_all_workers(): 243 | models = controller.refresh_all_workers() 244 | 245 | 246 | @app.post("/list_models") 247 | async def list_models(): 248 | models = controller.list_models() 249 | return {"models": models} 250 | 251 | 252 | @app.post("/get_worker_address") 253 | async def get_worker_address(request: Request): 254 | data = await request.json() 255 | addr = controller.get_worker_address(data["model"]) 256 | return {"address": addr} 257 | 258 | 259 | @app.post("/receive_heart_beat") 260 | async def receive_heart_beat(request: Request): 261 | data = await request.json() 262 | exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"]) 263 | return {"exist": exist} 264 | 265 | 266 | @app.post("/worker_generate_stream") 267 | async def worker_api_generate_stream(request: Request): 268 | params = await request.json() 269 | generator = controller.worker_api_generate_stream(params) 270 | return StreamingResponse(generator) 271 | 272 | 273 | @app.post("/worker_get_status") 274 | async def worker_api_get_status(request: Request): 275 | return controller.worker_api_get_status() 276 | 277 | 278 | if __name__ == "__main__": 279 | parser = argparse.ArgumentParser() 280 | parser.add_argument("--host", type=str, default="localhost") 281 | parser.add_argument("--port", type=int, default=21001) 282 | parser.add_argument("--dispatch-method", type=str, choices=["lottery", "shortest_queue"], default="shortest_queue") 283 | args = parser.parse_args() 284 | logger.info(f"args: {args}") 285 | 286 | controller = Controller(args.dispatch_method) 287 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 288 | -------------------------------------------------------------------------------- /llava/serve/model_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 
3 | """ 4 | 5 | import argparse 6 | import asyncio 7 | import json 8 | import time 9 | import threading 10 | import uuid 11 | 12 | from fastapi import FastAPI, Request, BackgroundTasks 13 | from fastapi.responses import StreamingResponse 14 | import requests 15 | import torch 16 | import uvicorn 17 | from functools import partial 18 | 19 | from llava.constants import WORKER_HEART_BEAT_INTERVAL 20 | from llava.utils import build_logger, server_error_msg, pretty_print_semaphore 21 | from llava.model.builder import load_pretrained_model 22 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria 23 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | from transformers import TextIteratorStreamer 25 | from threading import Thread 26 | 27 | 28 | GB = 1 << 30 29 | 30 | worker_id = str(uuid.uuid4())[:6] 31 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 32 | global_counter = 0 33 | 34 | model_semaphore = None 35 | 36 | 37 | def heart_beat_worker(controller): 38 | 39 | while True: 40 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 41 | controller.send_heart_beat() 42 | 43 | 44 | class ModelWorker: 45 | def __init__(self, controller_addr, worker_addr, worker_id, no_register, model_path, model_base, model_name, load_8bit, load_4bit): 46 | self.controller_addr = controller_addr 47 | self.worker_addr = worker_addr 48 | self.worker_id = worker_id 49 | if model_path.endswith("/"): 50 | model_path = model_path[:-1] 51 | if model_name is None: 52 | model_paths = model_path.split("/") 53 | if model_paths[-1].startswith("checkpoint-"): 54 | self.model_name = model_paths[-2] + "_" + model_paths[-1] 55 | else: 56 | self.model_name = model_paths[-1] 57 | else: 58 | self.model_name = model_name 59 | 60 | logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") 61 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, model_base, self.model_name, load_8bit, load_4bit) 62 | self.is_multimodal = "llava" in self.model_name.lower() 63 | 64 | if not no_register: 65 | self.register_to_controller() 66 | self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,)) 67 | self.heart_beat_thread.start() 68 | 69 | def register_to_controller(self): 70 | logger.info("Register to controller") 71 | 72 | url = self.controller_addr + "/register_worker" 73 | data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()} 74 | r = requests.post(url, json=data) 75 | assert r.status_code == 200 76 | 77 | def send_heart_beat(self): 78 | logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" f"global_counter: {global_counter}") 79 | 80 | url = self.controller_addr + "/receive_heart_beat" 81 | 82 | while True: 83 | try: 84 | ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5) 85 | exist = ret.json()["exist"] 86 | break 87 | except requests.exceptions.RequestException as e: 88 | logger.error(f"heart beat error: {e}") 89 | time.sleep(5) 90 | 91 | if not exist: 92 | self.register_to_controller() 93 | 94 | def get_queue_length(self): 95 | if model_semaphore is None: 96 | return 0 97 | else: 98 | return args.limit_model_concurrency - model_semaphore._value + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 99 | 100 | def get_status(self): 101 | return { 102 | "model_names": [self.model_name], 103 | "speed": 1, 104 | "queue_length": self.get_queue_length(), 105 | } 106 | 107 | @torch.inference_mode() 108 | def generate_stream(self, params): 109 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor 110 | 111 | prompt = params["prompt"] 112 | ori_prompt = prompt 113 | images = params.get("images", None) 114 | num_image_tokens = 0 115 | if images is not None and len(images) > 0 and self.is_multimodal: 116 | if len(images) > 0: 117 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 118 | raise ValueError("Number of images does not match number of tokens in prompt") 119 | 120 | images = [load_image_from_base64(image) for image in images] 121 | image_sizes = [image.size for image in images] 122 | images = process_images(images, image_processor, model.config) 123 | 124 | if type(images) is list: 125 | images = [image.to(self.model.device, dtype=torch.float16) for image in images] 126 | else: 127 | images = images.to(self.model.device, dtype=torch.float16) 128 | 129 | replace_token = DEFAULT_IMAGE_TOKEN 130 | if getattr(self.model.config, "mm_use_im_start_end", False): 131 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 132 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 133 | 134 | num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches 135 | else: 136 | images = None 137 | image_sizes = None 138 | image_args = {"images": images, "image_sizes": image_sizes} 139 | else: 140 | images = None 141 | image_args = {} 142 | 143 | temperature = float(params.get("temperature", 1.0)) 144 | top_p = float(params.get("top_p", 1.0)) 145 | max_context_length = getattr(model.config, "max_position_embeddings", 2048) 146 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 147 | stop_str = params.get("stop", None) 148 | do_sample = True if temperature > 0.001 else False 149 | 150 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() 151 | keywords = [stop_str] 152 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 153 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) 154 | 155 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) 156 | 157 | if max_new_tokens < 1: 158 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. 
Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" 159 | return 160 | 161 | thread = Thread( 162 | target=model.generate, 163 | kwargs=dict( 164 | inputs=input_ids, 165 | do_sample=do_sample, 166 | temperature=temperature, 167 | top_p=top_p, 168 | max_new_tokens=max_new_tokens, 169 | streamer=streamer, 170 | # stopping_criteria=[stopping_criteria], 171 | use_cache=True, 172 | **image_args, 173 | ), 174 | ) 175 | thread.start() 176 | 177 | start_time = time.time() 178 | generated_text = ori_prompt 179 | for new_text in streamer: 180 | generated_text += new_text 181 | if generated_text.endswith(stop_str): 182 | generated_text = generated_text[: -len(stop_str)] 183 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" 184 | 185 | end_time = time.time() 186 | 187 | new_generated = generated_text[len(ori_prompt) :] 188 | new_generated_tokens = tokenizer(new_generated).input_ids 189 | token_per_second = len(new_generated_tokens) / (end_time - start_time) 190 | print(f"token_per_second: {token_per_second}") 191 | 192 | def generate_stream_gate(self, params): 193 | try: 194 | for x in self.generate_stream(params): 195 | yield x 196 | except ValueError as e: 197 | print("Caught ValueError:", e) 198 | ret = { 199 | "text": server_error_msg, 200 | "error_code": 1, 201 | } 202 | yield json.dumps(ret).encode() + b"\0" 203 | except torch.cuda.CudaError as e: 204 | print("Caught torch.cuda.CudaError:", e) 205 | ret = { 206 | "text": server_error_msg, 207 | "error_code": 1, 208 | } 209 | yield json.dumps(ret).encode() + b"\0" 210 | except Exception as e: 211 | print("Caught Unknown Error", e) 212 | ret = { 213 | "text": server_error_msg, 214 | "error_code": 1, 215 | } 216 | yield json.dumps(ret).encode() + b"\0" 217 | 218 | 219 | app = FastAPI() 220 | 221 | 222 | def release_model_semaphore(fn=None): 223 | model_semaphore.release() 224 | if fn is not None: 225 | fn() 226 | 227 | 228 | @app.post("/worker_generate_stream") 229 | async def generate_stream(request: Request): 230 | global model_semaphore, global_counter 231 | global_counter += 1 232 | params = await request.json() 233 | 234 | if model_semaphore is None: 235 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 236 | await model_semaphore.acquire() 237 | worker.send_heart_beat() 238 | generator = worker.generate_stream_gate(params) 239 | background_tasks = BackgroundTasks() 240 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 241 | return StreamingResponse(generator, background=background_tasks) 242 | 243 | 244 | @app.post("/worker_get_status") 245 | async def get_status(request: Request): 246 | return worker.get_status() 247 | 248 | 249 | if __name__ == "__main__": 250 | parser = argparse.ArgumentParser() 251 | parser.add_argument("--host", type=str, default="localhost") 252 | parser.add_argument("--port", type=int, default=21002) 253 | parser.add_argument("--worker-address", type=str, default="http://localhost:21002") 254 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 255 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 256 | parser.add_argument("--model-base", type=str, default=None) 257 | parser.add_argument("--model-name", type=str) 258 | parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") 259 | 
parser.add_argument("--limit-model-concurrency", type=int, default=5) 260 | parser.add_argument("--stream-interval", type=int, default=1) 261 | parser.add_argument("--no-register", action="store_true") 262 | parser.add_argument("--load-8bit", action="store_true") 263 | parser.add_argument("--load-4bit", action="store_true") 264 | args = parser.parse_args() 265 | logger.info(f"args: {args}") 266 | 267 | if args.multi_modal: 268 | logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") 269 | 270 | worker = ModelWorker(args.controller_address, args.worker_address, worker_id, args.no_register, args.model_path, args.model_base, args.model_name, args.load_8bit, args.load_4bit) 271 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 272 | -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/sglang_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 
3 | """ 4 | 5 | import argparse 6 | import asyncio 7 | from concurrent.futures import ThreadPoolExecutor 8 | import json 9 | import time 10 | import threading 11 | import uuid 12 | 13 | from fastapi import FastAPI, Request, BackgroundTasks 14 | from fastapi.responses import StreamingResponse 15 | import requests 16 | import re 17 | import uvicorn 18 | from functools import partial 19 | 20 | from llava.constants import WORKER_HEART_BEAT_INTERVAL 21 | from llava.utils import build_logger, server_error_msg, pretty_print_semaphore 22 | from llava.model.builder import load_pretrained_model 23 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square 24 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 25 | from transformers import AutoTokenizer 26 | 27 | import sglang as sgl 28 | from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend 29 | from sglang.backend.runtime_endpoint import RuntimeEndpoint 30 | from sglang.utils import read_jsonl, dump_state_text 31 | from sglang.lang.interpreter import ProgramState 32 | 33 | 34 | GB = 1 << 30 35 | 36 | worker_id = str(uuid.uuid4())[:6] 37 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 38 | global_counter = 0 39 | 40 | model_semaphore = None 41 | 42 | 43 | def heart_beat_worker(controller): 44 | while True: 45 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 46 | controller.send_heart_beat() 47 | 48 | 49 | @sgl.function 50 | def pipeline(s, prompt, max_tokens): 51 | for p in prompt: 52 | if type(p) is str: 53 | s += p 54 | else: 55 | s += sgl.image(p) 56 | s += sgl.gen("response", max_tokens=max_tokens) 57 | 58 | 59 | class ModelWorker: 60 | def __init__(self, controller_addr, worker_addr, sgl_endpoint, worker_id, no_register, model_name): 61 | self.controller_addr = controller_addr 62 | self.worker_addr = worker_addr 63 | self.worker_id = worker_id 64 | 65 | # Select backend 66 | backend = RuntimeEndpoint(sgl_endpoint) 67 | sgl.set_default_backend(backend) 68 | model_path = backend.model_info["model_path"] 69 | 70 | if model_path.endswith("/"): 71 | model_path = model_path[:-1] 72 | if model_name is None: 73 | model_paths = model_path.split("/") 74 | if model_paths[-1].startswith("checkpoint-"): 75 | self.model_name = model_paths[-2] + "_" + model_paths[-1] 76 | else: 77 | self.model_name = model_paths[-1] 78 | else: 79 | self.model_name = model_name 80 | 81 | logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...") 82 | 83 | if not no_register: 84 | self.register_to_controller() 85 | self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,)) 86 | self.heart_beat_thread.start() 87 | 88 | def register_to_controller(self): 89 | logger.info("Register to controller") 90 | 91 | url = self.controller_addr + "/register_worker" 92 | data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()} 93 | r = requests.post(url, json=data) 94 | assert r.status_code == 200 95 | 96 | def send_heart_beat(self): 97 | logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" f"global_counter: {global_counter}") 98 | 99 | url = self.controller_addr + "/receive_heart_beat" 100 | 101 | while True: 102 | try: 103 | ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5) 104 | exist = ret.json()["exist"] 105 | break 106 | except requests.exceptions.RequestException as e: 107 | logger.error(f"heart beat error: {e}") 108 | time.sleep(5) 109 | 110 | if not exist: 111 | self.register_to_controller() 112 | 113 | def get_queue_length(self): 114 | if model_semaphore is None: 115 | return 0 116 | else: 117 | return args.limit_model_concurrency - model_semaphore._value + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 118 | 119 | def get_status(self): 120 | return { 121 | "model_names": [self.model_name], 122 | "speed": 1, 123 | "queue_length": self.get_queue_length(), 124 | } 125 | 126 | async def generate_stream(self, params): 127 | ori_prompt = prompt = params["prompt"] 128 | images = params.get("images", None) 129 | if images is not None and len(images) > 0: 130 | if len(images) > 0: 131 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 132 | raise ValueError("Number of images does not match number of tokens in prompt") 133 | 134 | images = [load_image_from_base64(image) for image in images] 135 | # FIXME: hacky padding 136 | images = [expand2square(image, tuple(int(x * 255) for x in [0.48145466, 0.4578275, 0.40821073])) for image in images] 137 | 138 | # FIXME: for image-start/end token 139 | # replace_token = DEFAULT_IMAGE_TOKEN 140 | # if getattr(self.model.config, 'mm_use_im_start_end', False): 141 | # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 142 | # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 143 | prompt = prompt.replace(" " + DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN) 144 | prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN) 145 | prompt = [] 146 | for i in range(len(prompt_split)): 147 | prompt.append(prompt_split[i]) 148 | if i < len(images): 149 | prompt.append(images[i]) 150 | else: 151 | prompt = [prompt] 152 | 153 | temperature = float(params.get("temperature", 1.0)) 154 | top_p = float(params.get("top_p", 1.0)) 155 | # max_context_length = getattr(model.config, 'max_position_embeddings', 2048) 156 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 157 | stop_str = params.get("stop", None) 158 | stop_str = [stop_str] if stop_str is not None else None 159 | 160 | if max_new_tokens < 1: 161 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. 
Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" 162 | return 163 | 164 | # print(prompt) 165 | state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True) 166 | 167 | generated_text = ori_prompt 168 | async for text_outputs in state.text_async_iter(var_name="response"): 169 | generated_text += text_outputs 170 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" 171 | 172 | async def generate_stream_gate(self, params): 173 | try: 174 | async for x in self.generate_stream(params): 175 | yield x 176 | except ValueError as e: 177 | print("Caught ValueError:", e) 178 | ret = { 179 | "text": server_error_msg, 180 | "error_code": 1, 181 | } 182 | yield json.dumps(ret).encode() + b"\0" 183 | except Exception as e: 184 | print("Caught Unknown Error", e) 185 | ret = { 186 | "text": server_error_msg, 187 | "error_code": 1, 188 | } 189 | yield json.dumps(ret).encode() + b"\0" 190 | 191 | 192 | app = FastAPI() 193 | 194 | 195 | def release_model_semaphore(fn=None): 196 | model_semaphore.release() 197 | if fn is not None: 198 | fn() 199 | 200 | 201 | @app.post("/worker_generate_stream") 202 | async def generate_stream(request: Request): 203 | global model_semaphore, global_counter 204 | global_counter += 1 205 | params = await request.json() 206 | 207 | if model_semaphore is None: 208 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 209 | await model_semaphore.acquire() 210 | worker.send_heart_beat() 211 | generator = worker.generate_stream_gate(params) 212 | background_tasks = BackgroundTasks() 213 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 214 | return StreamingResponse(generator, background=background_tasks) 215 | 216 | 217 | @app.post("/worker_get_status") 218 | async def get_status(request: Request): 219 | return worker.get_status() 220 | 221 | 222 | if __name__ == "__main__": 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument("--host", type=str, default="localhost") 225 | parser.add_argument("--port", type=int, default=21002) 226 | parser.add_argument("--worker-address", type=str, default="http://localhost:21002") 227 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 228 | parser.add_argument("--model-name", type=str) 229 | parser.add_argument("--sgl-endpoint", type=str) 230 | parser.add_argument("--limit-model-concurrency", type=int, default=5) 231 | parser.add_argument("--stream-interval", type=int, default=1) 232 | parser.add_argument("--no-register", action="store_true") 233 | args = parser.parse_args() 234 | logger.info(f"args: {args}") 235 | 236 | worker = ModelWorker(args.controller_address, args.worker_address, args.sgl_endpoint, worker_id, args.no_register, args.model_name) 237 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 238 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | 
models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name}) 21 | worker_addr = ret.json()["address"] 22 | print(f"worker_addr: {worker_addr}") 23 | 24 | if worker_addr == "": 25 | return 26 | 27 | conv = default_conversation.copy() 28 | conv.append_message(conv.roles[0], args.message) 29 | prompt = conv.get_prompt() 30 | 31 | headers = {"User-Agent": "LLaVA Client"} 32 | pload = { 33 | "model": args.model_name, 34 | "prompt": prompt, 35 | "max_new_tokens": args.max_new_tokens, 36 | "temperature": 0.7, 37 | "stop": conv.sep, 38 | } 39 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True) 40 | 41 | print(prompt.replace(conv.sep, "\n"), end="") 42 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 43 | if chunk: 44 | data = json.loads(chunk.decode("utf-8")) 45 | output = data["text"].split(conv.sep)[-1] 46 | print(output, end="\r") 47 | print("") 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 53 | parser.add_argument("--worker-address", type=str) 54 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 55 | parser.add_argument("--max-new-tokens", type=int, default=32) 56 | parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.") 57 | args = parser.parse_args() 58 | 59 | main() 60 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | padding_mask: Optional[torch.Tensor] = None, 25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 26 | if output_attentions: 27 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") 28 | 29 | bsz, q_len, _ = hidden_states.size() 30 | 31 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 32 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 33 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim) 34 | 35 | kv_seq_len = key_states.shape[-2] 36 | if past_key_value is not None: 37 | kv_seq_len += past_key_value[0].shape[-2] 38 | 39 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 40 | query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 41 | 42 | if past_key_value is not None: 43 | # reuse k, v 44 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 45 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 46 | 47 | past_key_value = (key_states, value_states) if use_cache else None 48 | 49 | # repeat k/v heads if n_kv_heads < n_heads 50 | key_states = repeat_kv(key_states, self.num_key_value_groups) 51 | value_states = repeat_kv(value_states, self.num_key_value_groups) 52 | 53 | # Transform the data into the format required by flash attention 54 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 55 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 56 | key_padding_mask = attention_mask 57 | 58 | if key_padding_mask is None: 59 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 60 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) 61 | max_s = q_len 62 | output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 63 | output = output.view(bsz, q_len, -1) 64 | else: 65 | qkv = qkv.reshape(bsz, q_len, -1) 66 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 67 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 68 | output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 69 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 70 | output = pad_input(output_unpad, indices, bsz, q_len) 71 | 72 | return self.o_proj(output), None, past_key_value 73 | 74 | 75 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 76 | # requires the attention mask to be the same as the key_padding_mask 77 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 78 | # [bsz, seq_len] 79 | return attention_mask 80 | 81 | 82 | def replace_llama_attn_with_flash_attn(): 83 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 84 | if cuda_major < 8: 85 | warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593") 86 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 87 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 88 | -------------------------------------------------------------------------------- /llava/train/llava_trainer_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | from llava.train.llava_trainer import LLaVATrainer 5 | 6 | 7 | class LLaVAEvalTrainer(LLaVATrainer): 8 | def evaluate(self, evaluate_args): 9 | cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \ 10 | --model {evaluate_args.model} \ 11 | --model_args {evaluate_args.model_args} \ 12 | --tasks {evaluate_args.task_names} \ 13 | --batch_size {evaluate_args.batch_size} \ 14 | --log_samples_suffix {evaluate_args.log_samples_suffix} \ 15 | --output_path {evaluate_args.output_path}" 16 | if evaluate_args.limit: 17 | cmd += f" --limit {evaluate_args.limit}" 18 | if evaluate_args.num_fewshot: 19 | cmd += f" --num_fewshot {evaluate_args.num_fewshot}" 20 | if evaluate_args.gen_kwargs != "": 21 | cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}" 22 | if evaluate_args.log_samples: 23 | cmd += f" --log_samples" 24 | else: 25 | assert False, "Please log samples so that the result can be parsed" 26 | results = subprocess.run([cmd], shell=True, capture_output=True, text=True) 27 | try: 28 | result_file_index_start = results.stdout.index("Saved samples to ") 29 | result_file_index_end = results.stdout.index(f".json") 30 | result_file_index_start += len("Saved samples to ") 31 | file = results.stdout[result_file_index_start:result_file_index_end] 32 | except: 33 | result_file_index_start = results.stderr.index("Saved samples to ") 34 | result_file_index_end = results.stderr.index(f".json") 35 | result_file_index_start += len("Saved samples to ") 36 | file = results.stderr[result_file_index_start:result_file_index_end] 37 | file = file.split("/")[:-1] 38 | file = "/".join(file) + "/results.json" 39 | with open(file, "r") as f: 40 | lmms_eval_results = json.load(f) 41 | result_dict = {} 42 | tasks_list = evaluate_args.task_names.split(",") 43 | for task in tasks_list: 44 | task_results = lmms_eval_results["results"][task] 45 | for k, v in task_results.items(): 46 | if k != "alias" and "stderr" not in k: 47 | metric = k.split(",")[0] 48 | result_dict[f"{task}_{metric}"] = v 49 | return result_dict 50 | 51 | """def evaluate(self, evaluate_args): 52 | initialize_tasks() 53 | tasks_list = evaluate_args.task_names.split(",") 54 | result_dict = {} 55 | results = evaluator.simple_evaluate( 56 | model=evaluate_args.model, 57 | model_args=evaluate_args.model_args, 58 | tasks=tasks_list, 59 | num_fewshot=evaluate_args.num_fewshot, 60 | batch_size=evaluate_args.batch_size, 61 | device=evaluate_args.device, 62 | limit=evaluate_args.limit, 63 | check_integrity=evaluate_args.check_integrity, 64 | show_task_to_terminal=evaluate_args.show_task_to_terminal, 65 | log_samples=evaluate_args.log_samples, 66 | gen_kwargs=evaluate_args.gen_kwargs, 67 | cli_args=evaluate_args, 68 | ) 69 | for task in tasks_list: 70 | task_results = results["results"][task] 71 | for k,v in task_results.items(): 72 | if k != "alias" and "stderr" not in k: 73 | metric = k.split(",")[0] 74 | result_dict[f"{task}_{metric}"] = v 75 | 76 | return result_dict""" 77 | 
-------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train import train 2 | 3 | if __name__ == "__main__": 4 | train() 5 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | import numpy as np 7 | 8 | import requests 9 | 10 | from llava.constants import LOGDIR 11 | 12 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 13 | moderation_msg = "I am sorry. Your input may violate our content moderation guidelines. Please avoid using harmful or offensive content." 14 | 15 | handler = None 16 | 17 | import torch.distributed as dist 18 | 19 | try: 20 | import av 21 | from decord import VideoReader, cpu 22 | except ImportError: 23 | print("Please install pyav to use video processing functions.") 24 | 25 | def process_video_with_decord(video_file, data_args): 26 | vr = VideoReader(video_file, ctx=cpu(0), num_threads=1) 27 | total_frame_num = len(vr) 28 | video_time = total_frame_num / vr.get_avg_fps() 29 | avg_fps = round(vr.get_avg_fps() / data_args.video_fps) 30 | frame_idx = [i for i in range(0, total_frame_num, avg_fps)] 31 | frame_time = [i/avg_fps for i in frame_idx] 32 | 33 | 34 | if data_args.frames_upbound > 0: 35 | if len(frame_idx) > data_args.frames_upbound or data_args.force_sample: 36 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int) 37 | frame_idx = uniform_sampled_frames.tolist() 38 | frame_time = [i/vr.get_avg_fps() for i in frame_idx] 39 | 40 | video = vr.get_batch(frame_idx).asnumpy() 41 | frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) 42 | 43 | num_frames_to_sample = num_frames = len(frame_idx) 44 | # https://github.com/dmlc/decord/issues/208 45 | vr.seek(0) 46 | return video, video_time, frame_time, num_frames_to_sample 47 | 48 | def process_video_with_pyav(video_file, data_args): 49 | container = av.open(video_file) 50 | # !!! This is the only difference. 
Using auto threading 51 | container.streams.video[0].thread_type = "AUTO" 52 | 53 | video_frames = [] 54 | for packet in container.demux(): 55 | if packet.stream.type == 'video': 56 | for frame in packet.decode(): 57 | video_frames.append(frame) 58 | total_frame_num = len(video_frames) 59 | video_time = video_frames[-1].time 60 | avg_fps = round(total_frame_num / video_time / data_args.video_fps) 61 | frame_idx = [i for i in range(0, total_frame_num, avg_fps)] 62 | 63 | if data_args.frames_upbound > 0: 64 | if len(frame_idx) > data_args.frames_upbound: 65 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int) 66 | frame_idx = uniform_sampled_frames.tolist() 67 | 68 | 69 | frames = [video_frames[i] for i in frame_idx] 70 | return np.stack([x.to_ndarray(format="rgb24") for x in frames]) 71 | 72 | 73 | def rank0_print(*args): 74 | if dist.is_initialized(): 75 | if dist.get_rank() == 0: 76 | print(f"Rank {dist.get_rank()}: ", *args) 77 | else: 78 | print(*args) 79 | 80 | 81 | def rank_print(*args): 82 | if dist.is_initialized(): 83 | print(f"Rank {dist.get_rank()}: ", *args) 84 | else: 85 | print(*args) 86 | 87 | def build_logger(logger_name, logger_filename): 88 | global handler 89 | 90 | formatter = logging.Formatter( 91 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 92 | datefmt="%Y-%m-%d %H:%M:%S", 93 | ) 94 | 95 | # Set the format of root handlers 96 | if not logging.getLogger().handlers: 97 | logging.basicConfig(level=logging.INFO) 98 | logging.getLogger().handlers[0].setFormatter(formatter) 99 | 100 | # Redirect stdout and stderr to loggers 101 | stdout_logger = logging.getLogger("stdout") 102 | stdout_logger.setLevel(logging.INFO) 103 | sl = StreamToLogger(stdout_logger, logging.INFO) 104 | sys.stdout = sl 105 | 106 | stderr_logger = logging.getLogger("stderr") 107 | stderr_logger.setLevel(logging.ERROR) 108 | sl = StreamToLogger(stderr_logger, logging.ERROR) 109 | sys.stderr = sl 110 | 111 | # Get logger 112 | logger = logging.getLogger(logger_name) 113 | logger.setLevel(logging.INFO) 114 | 115 | # Add a file handler for all loggers 116 | if handler is None: 117 | os.makedirs(LOGDIR, exist_ok=True) 118 | filename = os.path.join(LOGDIR, logger_filename) 119 | handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True) 120 | handler.setFormatter(formatter) 121 | 122 | for name, item in logging.root.manager.loggerDict.items(): 123 | if isinstance(item, logging.Logger): 124 | item.addHandler(handler) 125 | 126 | return logger 127 | 128 | 129 | class StreamToLogger(object): 130 | """ 131 | Fake file-like stream object that redirects writes to a logger instance. 132 | """ 133 | 134 | def __init__(self, logger, log_level=logging.INFO): 135 | self.terminal = sys.stdout 136 | self.logger = logger 137 | self.log_level = log_level 138 | self.linebuf = "" 139 | 140 | def __getattr__(self, attr): 141 | return getattr(self.terminal, attr) 142 | 143 | def write(self, buf): 144 | temp_linebuf = self.linebuf + buf 145 | self.linebuf = "" 146 | for line in temp_linebuf.splitlines(True): 147 | # From the io.TextIOWrapper docs: 148 | # On output, if newline is None, any '\n' characters written 149 | # are translated to the system default line separator. 150 | # By default sys.stdout.write() expects '\n' newlines and then 151 | # translates them so this is still cross platform. 
152 | if line[-1] == "\n": 153 | self.logger.log(self.log_level, line.rstrip()) 154 | else: 155 | self.linebuf += line 156 | 157 | def flush(self): 158 | if self.linebuf != "": 159 | self.logger.log(self.log_level, self.linebuf.rstrip()) 160 | self.linebuf = "" 161 | 162 | 163 | def disable_torch_init(): 164 | """ 165 | Disable the redundant torch default initialization to accelerate model creation. 166 | """ 167 | import torch 168 | 169 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 170 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 171 | 172 | 173 | def violates_moderation(text): 174 | """ 175 | Check whether the text violates OpenAI moderation API. 176 | """ 177 | url = "https://api.openai.com/v1/moderations" 178 | headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 179 | text = text.replace("\n", "") 180 | data = "{" + '"input": ' + f'"{text}"' + "}" 181 | data = data.encode("utf-8") 182 | try: 183 | ret = requests.post(url, headers=headers, data=data, timeout=5) 184 | flagged = ret.json()["results"][0]["flagged"] 185 | except requests.exceptions.RequestException as e: 186 | print(f"######################### Moderation Error: {e} #########################") 187 | flagged = False 188 | except KeyError as e: 189 | print(f"######################### Moderation Error: {e} #########################") 190 | flagged = False 191 | 192 | return flagged 193 | 194 | 195 | def pretty_print_semaphore(semaphore): 196 | if semaphore is None: 197 | return "None" 198 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 199 | -------------------------------------------------------------------------------- /metrics/get_mmlu_acc.py: -------------------------------------------------------------------------------- 1 | import json 2 | from statistics import mean 3 | 4 | 5 | def load_json(file_path): 6 | with open(file_path, "r") as file: 7 | return json.load(file) 8 | 9 | 10 | def extract_exact_match_values(data): 11 | exact_match_values = [] 12 | 13 | def traverse_dict(d): 14 | for key, value in d.items(): 15 | if isinstance(value, dict): 16 | if "exact_match,get_response" in value: 17 | exact_match_values.append(value["exact_match,get_response"]) 18 | traverse_dict(value) 19 | 20 | traverse_dict(data) 21 | return exact_match_values 22 | 23 | 24 | def main(): 25 | file_path = "./mmlu_log/Path_to_results.json" 26 | data = load_json(file_path) 27 | exact_match_values = extract_exact_match_values(data) 28 | 29 | if exact_match_values: 30 | average = mean(exact_match_values) 31 | print(f"Average of exact_match,get_response: {average:.4f}") 32 | print(f"Number of values: {len(exact_match_values)}") 33 | else: 34 | print("No exact_match,get_response values found.") 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /metrics/humaneval_pass@1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import ast 5 | import traceback 6 | import evaluate as hf_evaluate 7 | from typing import Dict, List, Optional, Set, Tuple 8 | 9 | ROOT = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.extend([os.path.dirname(ROOT), os.path.dirname(os.path.dirname(ROOT))]) 11 | 12 | 13 | def refine_text(text: str) -> str: 14 | text = text.replace("\t", " ") 15 | text = text.replace("\r\n", "\n").replace("\r", "\n") 16 | return text.strip() + "\n" 17 
| 18 | 19 | def syntax_check(code, verbose=False): 20 | try: 21 | ast.parse(code) 22 | return True 23 | except (SyntaxError, MemoryError): 24 | if verbose: 25 | traceback.print_exc() 26 | return False 27 | 28 | 29 | def extract_longest_valid_code(text: str) -> str: 30 | lines = text.splitlines() 31 | if len(lines) > 100: 32 | lines = lines[:100] 33 | max_valid_lines = 0 34 | max_valid_snippet = "" 35 | 36 | for i in range(len(lines)): 37 | for j in range(i, len(lines)): 38 | current_snippet = "\n".join(lines[i : j + 1]) 39 | if syntax_check(current_snippet): 40 | valid_line_count = sum(1 for line in lines[i : j + 1] if line.strip()) 41 | if valid_line_count > max_valid_lines: 42 | max_valid_lines = valid_line_count 43 | max_valid_snippet = current_snippet 44 | return max_valid_snippet 45 | 46 | 47 | def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]: 48 | name2deps = {} 49 | for name, node in nodes: 50 | deps = set() 51 | stack = [node] 52 | while stack: 53 | current = stack.pop() 54 | for child in ast.iter_child_nodes(current): 55 | if isinstance(child, ast.Name): 56 | deps.add(child.id) 57 | elif isinstance(child, ast.Attribute): 58 | deps.add(child.attr) 59 | else: 60 | stack.append(child) 61 | name2deps[name] = deps 62 | return name2deps 63 | 64 | 65 | def get_function_dependency( 66 | entrypoint: str, call_graph: Dict[str, Set[str]] 67 | ) -> Set[str]: 68 | visited = set() 69 | to_visit = [entrypoint] 70 | while to_visit: 71 | current = to_visit.pop(0) 72 | if current not in visited: 73 | visited.add(current) 74 | to_visit.extend(call_graph.get(current, set()) - visited) 75 | return visited 76 | 77 | 78 | def get_definition_name(node: ast.AST) -> Optional[str]: 79 | if isinstance(node, (ast.FunctionDef, ast.ClassDef)): 80 | return node.name 81 | elif isinstance(node, ast.Assign): 82 | targets = node.targets 83 | if targets and isinstance(targets[0], ast.Name): 84 | return targets[0].id 85 | return None 86 | 87 | 88 | def has_return_statement(node: ast.AST) -> bool: 89 | return any(isinstance(n, ast.Return) for n in ast.walk(node)) 90 | 91 | 92 | def sanitize(text: str, entrypoint: Optional[str] = None) -> str: 93 | text = refine_text(text) 94 | code = extract_longest_valid_code(text) 95 | tree = ast.parse(code) 96 | 97 | definitions = {} 98 | imports = [] 99 | 100 | for node in tree.body: 101 | if isinstance(node, (ast.Import, ast.ImportFrom)): 102 | imports.append(node) 103 | elif isinstance(node, ast.ClassDef): 104 | name = node.name 105 | definitions[name] = ("class", node) 106 | elif isinstance(node, ast.FunctionDef): 107 | name = node.name 108 | if has_return_statement(node): 109 | definitions[name] = ("function", node) 110 | elif isinstance(node, ast.Assign): 111 | name = get_definition_name(node) 112 | if name: 113 | definitions[name] = ("variable", node) 114 | 115 | if entrypoint: 116 | name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()]) 117 | reachable = get_function_dependency(entrypoint, name2deps) 118 | 119 | sanitized_output = [] 120 | for node in imports: 121 | sanitized_output.append(ast.unparse(node)) 122 | 123 | for name, (_, node) in definitions.items(): 124 | if not entrypoint or name in reachable: 125 | sanitized_output.append(ast.unparse(node)) 126 | 127 | return "\n".join(sanitized_output) 128 | 129 | 130 | os.environ["HF_ALLOW_CODE_EVAL"] = "1" 131 | pass_at_k = hf_evaluate.load("code_eval") 132 | 133 | 134 | def pass_at_1(references, predictions): 135 | return pass_at_k.compute(references=references, 
predictions=predictions, k=[1])[0][ 136 | "pass@1" 137 | ] 138 | 139 | 140 | def read_jsonl(file_path): 141 | data = [] 142 | with open(file_path, "r") as file: 143 | for line in file: 144 | data.append(json.loads(line)) 145 | return data 146 | 147 | 148 | def main(): 149 | file_path = sys.argv[1] 150 | data = read_jsonl(file_path) 151 | 152 | references = [sample["target"] for sample in data] 153 | predictions = [ 154 | [ 155 | sanitize( 156 | sample["doc"]["prompt"] 157 | + "\n" 158 | + sample["resps"][0][0].split("```python\n", 1)[-1].split("```")[0], 159 | sample["doc"]["entry_point"], 160 | ) 161 | ] 162 | for sample in data 163 | ] 164 | pass_at_1s = [ 165 | pass_at_1([reference], [prediction]) 166 | for reference, prediction in zip(references, predictions) 167 | ] 168 | print("PASS@1:", sum(pass_at_1s) / len(pass_at_1s)) 169 | 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | datasets 3 | tqdm 4 | zeno_client 5 | peft 6 | accelerate 7 | lm-eval 8 | datasets 9 | numpy 10 | torchmetrics 11 | transformers 12 | deepspeed 13 | huggingface_hub[hf_xet] 14 | calflops 15 | lm-eval[longbench] 16 | lm-eval[math] -------------------------------------------------------------------------------- /scripts/run_Dream_bbh_Instruct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Instruct-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for bbh" 11 | 12 | # --- Task Specific Parameters for bbh --- 13 | TASK="bbh" 14 | NUM_FEWSHOT=3 # From tasks="... bbh", nshots="... 3" 15 | MAX_NEW_TOKENS=256 # From tasks="... bbh", lengths="... 512" 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... bbh", temperatures="... 
0" 18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=10,gen_interval_steps=2,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_bbh_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Base-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for bbh" 11 | 12 | # --- Task Specific Parameters for bbh --- 13 | TASK="bbh" 14 | NUM_FEWSHOT=3 # From tasks="... bbh", nshots="... 3" 15 | MAX_NEW_TOKENS=256 # From tasks="... bbh", lengths="... 512" 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... bbh", temperatures="... 
0" 18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=25,gen_interval_steps=4,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_gpqa_Instruct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Instruct-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for gpqa_main_generative_n_shot" 11 | 12 | # --- Task Specific Parameters for gpqa_main_generative_n_shot --- 13 | TASK="gpqa_main_generative_n_shot" 14 | NUM_FEWSHOT=5 # From tasks="... gpqa_main_generative_n_shot ...", nshots="... 4 ..." 15 | MAX_NEW_TOKENS=256 # From tasks="... gpqa_main_generative_n_shot ...", lengths="... 512 ..." 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... gpqa_main_generative_n_shot ...", temperatures="... 0 ..." 
18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code \ 32 | --apply_chat_template \ 33 | --fewshot_as_multiturn \ 34 | 35 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 36 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=10,gen_interval_steps=8,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 37 | --tasks ${TASK} \ 38 | --num_fewshot ${NUM_FEWSHOT} \ 39 | --batch_size 2 \ 40 | --output_path ${OUTPUT_PATH} \ 41 | --log_samples \ 42 | --confirm_run_unsafe_code \ 43 | --apply_chat_template \ 44 | --fewshot_as_multiturn \ 45 | 46 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_gpqa_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Base-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for gpqa_main_generative_n_shot" 11 | 12 | # --- Task Specific Parameters for gpqa_main_generative_n_shot --- 13 | TASK="gpqa_main_generative_n_shot" 14 | NUM_FEWSHOT=5 # From tasks="... gpqa_main_generative_n_shot ...", nshots="... 4 ..." 15 | MAX_NEW_TOKENS=256 # From tasks="... gpqa_main_generative_n_shot ...", lengths="... 512 ..." 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... gpqa_main_generative_n_shot ...", temperatures="... 0 ..." 
18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=100,gen_interval_steps=8,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_gsm8k_Instruct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Instruct-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for gsm8k_cot" 11 | 12 | # --- Task Specific Parameters for gsm8k_cot --- 13 | TASK="gsm8k_cot" 14 | NUM_FEWSHOT=8 # From tasks="gsm8k_cot ...", nshots="8 ..." 15 | MAX_NEW_TOKENS=256 # From tasks="gsm8k_cot ...", lengths="256 ..." 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="gsm8k_cot ...", temperatures="0 ..." 
18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=25,gen_interval_steps=2,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_gsm8k_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Base-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for gsm8k_cot" 11 | 12 | # --- Task Specific Parameters for gsm8k_cot --- 13 | TASK="gsm8k_cot" 14 | NUM_FEWSHOT=8 # From tasks="gsm8k_cot ...", nshots="8 ..." 15 | MAX_NEW_TOKENS=256 # From tasks="gsm8k_cot ...", lengths="256 ..." 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="gsm8k_cot ...", temperatures="0 ..." 
18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=100,gen_interval_steps=8,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_humaneval_Instruct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Instruct-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for humaneval" 11 | 12 | # --- Task Specific Parameters for humaneval --- 13 | TASK="humaneval" 14 | NUM_FEWSHOT=0 15 | MAX_NEW_TOKENS=256 16 | DIFFUSION_STEPS=256 # Note: based on original script 17 | TEMPERATURE=0.2 18 | TOP_P=0.95 19 | ADD_BOS_TOKEN="true" 20 | ESCAPE_UNTIL="true" # Note: specific to the humaneval run in original script 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=50,gen_interval_steps=1,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" 43 | 44 | ### NOTICE: use 
postprocess for humaneval 45 | # python postprocess_code.py {the samples_xxx.jsonl file under output_path} -------------------------------------------------------------------------------- /scripts/run_Dream_humaneval_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Base-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for humaneval" 11 | 12 | # --- Task Specific Parameters for humaneval --- 13 | TASK="humaneval" 14 | NUM_FEWSHOT=0 15 | MAX_NEW_TOKENS=256 16 | DIFFUSION_STEPS=256 # Note: based on original script 17 | TEMPERATURE=0.2 18 | TOP_P=0.95 19 | ADD_BOS_TOKEN="true" 20 | ESCAPE_UNTIL="true" # Note: specific to the humaneval run in original script 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=5,gen_interval_steps=1,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" 43 | 44 | ### NOTICE: use postprocess for humaneval 45 | # python postprocess_code.py {the samples_xxx.jsonl file under output_path} -------------------------------------------------------------------------------- /scripts/run_Dream_mbpp_Instruct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Instruct-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for mbpp" 11 | 12 | # --- Task Specific Parameters for mbpp --- 13 | TASK="mbpp" 14 | NUM_FEWSHOT=3 # From tasks="... mbpp ...", nshots="... 3 ..." 15 | MAX_NEW_TOKENS=256 # From tasks="... mbpp ...", lengths="... 512 ..." 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... mbpp ...", temperatures="... 0.2 ..." 
18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=10,gen_interval_steps=8,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_mbpp_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Base-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for mbpp" 11 | 12 | # --- Task Specific Parameters for mbpp --- 13 | TASK="mbpp" 14 | NUM_FEWSHOT=3 # From tasks="... mbpp ...", nshots="... 3 ..." 15 | MAX_NEW_TOKENS=256 # From tasks="... mbpp ...", lengths="... 512 ..." 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... mbpp ...", temperatures="... 0.2 ..." 
18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=25,gen_interval_steps=8,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_minerva_math_Instruct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Instruct-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for minerva_math" 11 | 12 | # --- Task Specific Parameters for minerva_math --- 13 | TASK="minerva_math" 14 | NUM_FEWSHOT=4 # From tasks="... minerva_math ...", nshots="... 4 ..." 15 | MAX_NEW_TOKENS=256 # From tasks="... minerva_math ...", lengths="... 512 ..." 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... minerva_math ...", temperatures="... 0 ..." 
18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},add_bos_token=${ADD_BOS_TOKEN},prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0.0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},add_bos_token=${ADD_BOS_TOKEN},prompt_interval_steps=50,gen_interval_steps=1,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_minerva_math_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Base-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for minerva_math" 11 | 12 | # --- Task Specific Parameters for minerva_math --- 13 | TASK="minerva_math" 14 | NUM_FEWSHOT=4 # From tasks="... minerva_math ...", nshots="... 4 ..." 15 | MAX_NEW_TOKENS=256 # From tasks="... minerva_math ...", lengths="... 512 ..." 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... minerva_math ...", temperatures="... 0 ..." 
18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},add_bos_token=${ADD_BOS_TOKEN},prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},add_bos_token=${ADD_BOS_TOKEN},prompt_interval_steps=100,gen_interval_steps=4,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_mmlu_generative_Instruct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Instruct-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for mmlu_generative" 11 | 12 | # --- Task Specific Parameters for bbh --- 13 | TASK="mmlu_generative" 14 | NUM_FEWSHOT=5 # From tasks="... bbh", nshots="... 3" 15 | MAX_NEW_TOKENS=256 # From tasks="... bbh", lengths="... 512" 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... bbh", temperatures="... 
0" 18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code \ 32 | --apply_chat_template \ 33 | --fewshot_as_multiturn \ 34 | 35 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 36 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=100,gen_interval_steps=8,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 37 | --tasks ${TASK} \ 38 | --num_fewshot ${NUM_FEWSHOT} \ 39 | --batch_size 2 \ 40 | --output_path ${OUTPUT_PATH} \ 41 | --log_samples \ 42 | --confirm_run_unsafe_code \ 43 | --apply_chat_template \ 44 | --fewshot_as_multiturn \ 45 | 46 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_mmlu_generative_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Base-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for mmlu_generative" 11 | 12 | # --- Task Specific Parameters for bbh --- 13 | TASK="mmlu_generative" 14 | NUM_FEWSHOT=5 # From tasks="... bbh", nshots="... 3" 15 | MAX_NEW_TOKENS=256 # From tasks="... bbh", lengths="... 512" 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... bbh", temperatures="... 
0" 18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,add_bos_token=${ADD_BOS_TOKEN},prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code \ 32 | 33 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 34 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,add_bos_token=${ADD_BOS_TOKEN},prompt_interval_steps=100,gen_interval_steps=2,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 35 | --tasks ${TASK} \ 36 | --num_fewshot ${NUM_FEWSHOT} \ 37 | --batch_size 2 \ 38 | --output_path ${OUTPUT_PATH} \ 39 | --log_samples \ 40 | --confirm_run_unsafe_code \ 41 | 42 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_mmlu_pro_Instruct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Instruct-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for mmlu_pro" 11 | 12 | # --- Task Specific Parameters for bbh --- 13 | TASK="mmlu_pro" 14 | NUM_FEWSHOT=0 # From tasks="... bbh", nshots="... 3" 15 | MAX_NEW_TOKENS=256 # From tasks="... bbh", lengths="... 512" 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... bbh", temperatures="... 
0" 18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=5,gen_interval_steps=1,cfg_interval_steps=1,transfer_ratio=0.0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code \ 32 | --trust_remote_code \ 33 | --apply_chat_template \ 34 | --fewshot_as_multiturn \ 35 | 36 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 37 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,prompt_interval_steps=5,gen_interval_steps=1,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 38 | --tasks ${TASK} \ 39 | --num_fewshot ${NUM_FEWSHOT} \ 40 | --batch_size 2 \ 41 | --output_path ${OUTPUT_PATH} \ 42 | --log_samples \ 43 | --confirm_run_unsafe_code \ 44 | --trust_remote_code \ 45 | --apply_chat_template \ 46 | --fewshot_as_multiturn \ 47 | 48 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_Dream_mmlu_pro_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model="Dream-org/Dream-v0-Base-7B" 4 | 5 | export HF_ALLOW_CODE_EVAL=1 6 | 7 | ACCEL_CONFIG="accelerate_config.yaml" 8 | MAIN_PORT="29510" 9 | 10 | echo "Starting evaluation for mmlu_pro" 11 | 12 | # --- Task Specific Parameters for bbh --- 13 | TASK="mmlu_pro" 14 | NUM_FEWSHOT=0 # From tasks="... bbh", nshots="... 3" 15 | MAX_NEW_TOKENS=256 # From tasks="... bbh", lengths="... 512" 16 | DIFFUSION_STEPS=256 # Note: based on original script (equal to max_new_tokens) 17 | TEMPERATURE=0.2 # From tasks="... bbh", temperatures="... 
0" 18 | TOP_P=0.95 # Constant in the original loop's model_args 19 | ADD_BOS_TOKEN="true" # Constant in the original loop's model_args 20 | # Note: original loop did NOT include escape_until=true 21 | 22 | OUTPUT_PATH="./${TASK}_log" 23 | 24 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 25 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,add_bos_token=${ADD_BOS_TOKEN},prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0.0,is_feature_cache=False,is_cfg_cache=False \ 26 | --tasks ${TASK} \ 27 | --num_fewshot ${NUM_FEWSHOT} \ 28 | --batch_size 2 \ 29 | --output_path ${OUTPUT_PATH} \ 30 | --log_samples \ 31 | --confirm_run_unsafe_code \ 32 | --trust_remote_code \ 33 | 34 | accelerate launch --config_file ${ACCEL_CONFIG} --main_process_port ${MAIN_PORT} evaluation_script.py --model dream \ 35 | --model_args pretrained=${model},max_new_tokens=${MAX_NEW_TOKENS},diffusion_steps=${DIFFUSION_STEPS},temperature=${TEMPERATURE},top_p=${TOP_P},alg="entropy",alg_temp=0.0,add_bos_token=${ADD_BOS_TOKEN},prompt_interval_steps=25,gen_interval_steps=2,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False \ 36 | --tasks ${TASK} \ 37 | --num_fewshot ${NUM_FEWSHOT} \ 38 | --batch_size 2 \ 39 | --output_path ${OUTPUT_PATH} \ 40 | --log_samples \ 41 | --confirm_run_unsafe_code \ 42 | --trust_remote_code \ 43 | 44 | 45 | echo "Completed evaluation for ${TASK}" -------------------------------------------------------------------------------- /scripts/run_LLaDA_bbh_Instruct.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks bbh --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 4 | --num_fewshot 3 \ 5 | --output_path ./bbh_log \ 6 | --log_samples \ 7 | --trust_remote_code \ 8 | --apply_chat_template \ 9 | --fewshot_as_multiturn \ 10 | 11 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks bbh --batch_size 2 \ 12 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=50,gen_interval_steps=6,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 13 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 14 | --num_fewshot 3 \ 15 | --output_path ./bbh_log \ 16 | --log_samples \ 17 | --trust_remote_code \ 18 | --apply_chat_template \ 19 | --fewshot_as_multiturn \ 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /scripts/run_LLaDA_bbh_base.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks bbh --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 4 | 
--num_fewshot 3 \ 5 | --output_path ./bbh_log \ 6 | --log_samples \ 7 | --trust_remote_code \ 8 | 9 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks bbh --batch_size 2 \ 10 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=50,gen_interval_steps=6,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 11 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 12 | --num_fewshot 3 \ 13 | --output_path ./bbh_log \ 14 | --log_samples \ 15 | --trust_remote_code \ 16 | 17 | -------------------------------------------------------------------------------- /scripts/run_LLaDA_gpqa_Instruct.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks gpqa_main_generative_n_shot --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=64,gen_length=128,steps=128,cfg_scale=0.0,remasking="low_confidence" " \ 4 | --num_fewshot 5 \ 5 | --output_path ./gpqa_log \ 6 | --log_samples \ 7 | --confirm_run_unsafe_code \ 8 | --trust_remote_code \ 9 | --apply_chat_template \ 10 | --fewshot_as_multiturn \ 11 | 12 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks gpqa_main_generative_n_shot --batch_size 2 \ 13 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=50,gen_interval_steps=6,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 14 | --gen_kwargs "block_length=64,gen_length=128,steps=128,cfg_scale=0.0,remasking="low_confidence" " \ 15 | --num_fewshot 5 \ 16 | --output_path ./gpqa_log \ 17 | --log_samples \ 18 | --confirm_run_unsafe_code \ 19 | --trust_remote_code \ 20 | --apply_chat_template \ 21 | --fewshot_as_multiturn \ 22 | 23 | -------------------------------------------------------------------------------- /scripts/run_LLaDA_gpqa_base.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks gpqa_main_generative_n_shot --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0,remasking="low_confidence" " \ 4 | --num_fewshot 5 \ 5 | --output_path ./gpqa_log \ 6 | --log_samples \ 7 | --confirm_run_unsafe_code \ 8 | --trust_remote_code \ 9 | 10 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks gpqa_main_generative_n_shot --batch_size 2 \ 11 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=100,gen_interval_steps=8,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 12 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0,remasking="low_confidence" " \ 13 | --num_fewshot 5 \ 14 | --output_path ./gpqa_log \ 15 | --log_samples \ 16 | --confirm_run_unsafe_code \ 17 | --trust_remote_code \ -------------------------------------------------------------------------------- 
/scripts/run_LLaDA_gsm8k_Instruct.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks gsm8k --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=-1,gen_interval_steps=-1,transfer_ratio=0,cache_order=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=8,gen_length=256,steps=256,cfg_scale=0.0 " \ 4 | --num_fewshot 4 \ 5 | --output_path ./gsm8k_log \ 6 | --log_samples \ 7 | --apply_chat_template \ 8 | --fewshot_as_multiturn \ 9 | 10 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks gsm8k --batch_size 2 \ 11 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=50,gen_interval_steps=7,transfer_ratio=0.25,cache_order=0,is_feature_cache=True,is_cfg_cache=False" \ 12 | --gen_kwargs "block_length=8,gen_length=256,steps=256,cfg_scale=0.0 " \ 13 | --num_fewshot 4 \ 14 | --output_path ./gsm8k_log \ 15 | --log_samples \ 16 | --apply_chat_template \ 17 | --fewshot_as_multiturn \ 18 | -------------------------------------------------------------------------------- /scripts/run_LLaDA_gsm8k_base.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks gsm8k --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=-1,gen_interval_steps=-1,transfer_ratio=0.0,cache_order=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0 " \ 4 | --num_fewshot 4 \ 5 | --output_path ./gsm8k_log \ 6 | --log_samples \ 7 | 8 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks gsm8k --batch_size 2 \ 9 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=100,gen_interval_steps=6,transfer_ratio=0.25,cache_order=0,is_feature_cache=True,is_cfg_cache=False" \ 10 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0 " \ 11 | --num_fewshot 4 \ 12 | --output_path ./gsm8k_log \ 13 | --log_samples \ 14 | 15 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks gsm8k --batch_size 2 \ 16 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=25,gen_interval_steps=5,transfer_ratio=0.25,cache_order=0,is_feature_cache=True,is_cfg_cache=False" \ 17 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0 " \ 18 | --num_fewshot 4 \ 19 | --output_path ./gsm8k_log \ 20 | --log_samples \ 21 | -------------------------------------------------------------------------------- /scripts/run_LLaDA_humaneval_Instruct.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks humaneval --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=-1,gen_interval_steps=-1,transfer_ratio=0,cache_order=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=32,gen_length=512,steps=512,cfg_scale=0.0 " \ 4 | --output_path ./humaneval_log/ \ 5 | --log_samples \ 6 | --confirm_run_unsafe_code \ 7 | 8 | accelerate launch --config_file accelerate_config.yaml 
evaluation_script.py -m lm_eval --model LLaDA --tasks humaneval --batch_size 2 \ 9 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=50,gen_interval_steps=8,transfer_ratio=0.25,cache_order=0,is_feature_cache=True,is_cfg_cache=False" \ 10 | --gen_kwargs "block_length=32,gen_length=512,steps=512,cfg_scale=0.0 " \ 11 | --output_path ./humaneval_log/ \ 12 | --log_samples \ 13 | --confirm_run_unsafe_code \ 14 | 15 | 16 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks humaneval --batch_size 2 \ 17 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=25,gen_interval_steps=5,transfer_ratio=0.25,cache_order=0,is_feature_cache=True,is_cfg_cache=False" \ 18 | --gen_kwargs "block_length=32,gen_length=512,steps=512,cfg_scale=0.0 " \ 19 | --output_path ./humaneval_log/ \ 20 | --log_samples \ 21 | --confirm_run_unsafe_code \ -------------------------------------------------------------------------------- /scripts/run_LLaDA_humaneval_base.sh: -------------------------------------------------------------------------------- 1 | export HF_ALLOW_CODE_EVAL="1" 2 | 3 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks humaneval --batch_size 2 \ 4 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=-1,gen_interval_steps=-1,transfer_ratio=0,cache_order=0,is_feature_cache=True,is_cfg_cache=False" \ 5 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 6 | --num_fewshot 0 \ 7 | --output_path ./humaneval_log/ \ 8 | --log_samples \ 9 | --confirm_run_unsafe_code 10 | 11 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks humaneval --batch_size 2 \ 12 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=50,gen_interval_steps=5,transfer_ratio=0.25,cache_order=0,is_feature_cache=True,is_cfg_cache=False" \ 13 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 14 | --num_fewshot 0 \ 15 | --output_path ./humaneval_log/ \ 16 | --log_samples \ 17 | --confirm_run_unsafe_code -------------------------------------------------------------------------------- /scripts/run_LLaDA_long_bench_Instruct.sh: -------------------------------------------------------------------------------- 1 | export HF_ALLOW_CODE_EVAL=1 2 | export HF_DATASETS_TRUST_REMOTE_CODE=true 3 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks longbench_hotpotqa --batch_size 1 \ 4 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 5 | --gen_kwargs "block_length=32,gen_length=32,steps=32,cfg_scale=0.0 " \ 6 | --num_fewshot 0 \ 7 | --output_path ./longbench_log \ 8 | --log_samples \ 9 | --apply_chat_template \ 10 | --fewshot_as_multiturn \ 11 | --trust_remote_code 12 | 13 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks longbench_hotpotqa --batch_size 1 \ 14 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=100,gen_interval_steps=8,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 15 | --gen_kwargs "block_length=32,gen_length=32,steps=32,cfg_scale=0.0 " \ 16 | --num_fewshot 0 \ 17 | --output_path ./longbench_log \ 18 | --log_samples \ 19 | 
--apply_chat_template \ 20 | --fewshot_as_multiturn \ 21 | --trust_remote_code 22 | 23 | -------------------------------------------------------------------------------- /scripts/run_LLaDA_mbpp_Instruct.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mbpp --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=-1,gen_interval_steps=-1,cache_order=0,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=32,gen_length=512,steps=512,cfg_scale=0.0,remasking="low_confidence" " \ 4 | --num_fewshot 3 \ 5 | --output_path ./mbpp_log \ 6 | --log_samples \ 7 | --apply_chat_template \ 8 | --fewshot_as_multiturn \ 9 | --confirm_run_unsafe_code \ 10 | 11 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mbpp --batch_size 2 \ 12 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=100,gen_interval_steps=5,cache_order=0,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 13 | --gen_kwargs "block_length=32,gen_length=512,steps=512,cfg_scale=0.0,remasking="low_confidence" " \ 14 | --num_fewshot 3 \ 15 | --output_path ./mbpp_log \ 16 | --log_samples \ 17 | --apply_chat_template \ 18 | --fewshot_as_multiturn \ 19 | --confirm_run_unsafe_code \ 20 | -------------------------------------------------------------------------------- /scripts/run_LLaDA_mbpp_base.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mbpp --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=-1,gen_interval_steps=-1,cache_order=0,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=32,gen_length=256,steps=256,cfg_scale=0.0,remasking="low_confidence" " \ 4 | --num_fewshot 3 \ 5 | --output_path ./mbpp_log \ 6 | --log_samples \ 7 | --confirm_run_unsafe_code \ 8 | 9 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mbpp --batch_size 2 \ 10 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=50,gen_interval_steps=4,cache_order=0,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 11 | --gen_kwargs "block_length=32,gen_length=256,steps=256,cfg_scale=0.0,remasking="low_confidence" " \ 12 | --num_fewshot 3 \ 13 | --output_path ./mbpp_log \ 14 | --log_samples \ 15 | --confirm_run_unsafe_code \ 16 | 17 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mbpp --batch_size 2 \ 18 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=25,gen_interval_steps=4,cache_order=0,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 19 | --gen_kwargs "block_length=32,gen_length=256,steps=256,cfg_scale=0.0,remasking="low_confidence" " \ 20 | --num_fewshot 3 \ 21 | --output_path ./mbpp_log \ 22 | --log_samples \ 23 | --confirm_run_unsafe_code \ -------------------------------------------------------------------------------- /scripts/run_LLaDA_minerva_math_Instruct.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA 
--tasks minerva_math --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 4 | --num_fewshot 0 \ 5 | --output_path ./minerva_math_log \ 6 | --log_samples \ 7 | --apply_chat_template \ 8 | --fewshot_as_multiturn \ 9 | 10 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks minerva_math --batch_size 2 \ 11 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=50,gen_interval_steps=1,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 12 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 13 | --num_fewshot 0 \ 14 | --output_path ./minerva_math_log \ 15 | --log_samples \ 16 | --apply_chat_template \ 17 | --fewshot_as_multiturn \ -------------------------------------------------------------------------------- /scripts/run_LLaDA_minerva_math_base.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks minerva_math --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 4 | --num_fewshot 4 \ 5 | --output_path ./minerva_math_log \ 6 | --log_samples \ 7 | 8 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks minerva_math --batch_size 2 \ 9 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=50,gen_interval_steps=8,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 10 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 11 | --num_fewshot 4 \ 12 | --output_path ./minerva_math_log \ 13 | --log_samples \ 14 | -------------------------------------------------------------------------------- /scripts/run_LLaDA_mmlu_generative_Instruct.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mmlu_generative --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0.0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 4 | --num_fewshot 5 \ 5 | --output_path ./mmlu_log \ 6 | --log_samples \ 7 | --apply_chat_template \ 8 | --fewshot_as_multiturn \ 9 | 10 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mmlu_generative --batch_size 2 \ 11 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=100,gen_interval_steps=7,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 12 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 13 | --num_fewshot 5 \ 14 | --output_path ./mmlu_log \ 15 | --log_samples \ 16 | --apply_chat_template \ 17 | --fewshot_as_multiturn \ 18 | 
-------------------------------------------------------------------------------- /scripts/run_LLaDA_mmlu_generative_base.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mmlu_generative --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=-1,gen_interval_steps=-1,cfg_interval_steps=-1,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 4 | --num_fewshot 5 \ 5 | --output_path ./mmlu_log \ 6 | --log_samples \ 7 | 8 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mmlu_generative --batch_size 2 \ 9 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=100,gen_interval_steps=6,cfg_interval_steps=1,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 10 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 11 | --num_fewshot 5 \ 12 | --output_path ./mmlu_log \ 13 | --log_samples \ -------------------------------------------------------------------------------- /scripts/run_LLaDA_mmlu_pro_Instruct.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mmlu_pro --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=-1,gen_interval_steps=-1,cache_order=0,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 4 | --num_fewshot 0 \ 5 | --output_path ./mmlu_pro_log \ 6 | --log_samples \ 7 | --apply_chat_template \ 8 | --fewshot_as_multiturn \ 9 | 10 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mmlu_pro --batch_size 2 \ 11 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,prompt_interval_steps=51,gen_interval_steps=3,cache_order=0,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 12 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 13 | --num_fewshot 0 \ 14 | --output_path ./mmlu_pro_log \ 15 | --log_samples \ 16 | --apply_chat_template \ 17 | --fewshot_as_multiturn \ 18 | -------------------------------------------------------------------------------- /scripts/run_LLaDA_mmlu_pro_base.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mmlu_pro --batch_size 2 \ 2 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=-1,gen_interval_steps=-1,cache_order=0,transfer_ratio=0,is_feature_cache=False,is_cfg_cache=False" \ 3 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 4 | --num_fewshot 0 \ 5 | --output_path ./mmlu_pro_log \ 6 | --log_samples \ 7 | 8 | accelerate launch --config_file accelerate_config.yaml evaluation_script.py -m lm_eval --model LLaDA --tasks mmlu_pro --batch_size 2 \ 9 | --model_args "pretrained=GSAI-ML/LLaDA-8B-Base,prompt_interval_steps=100,gen_interval_steps=6,cache_order=0,transfer_ratio=0.25,is_feature_cache=True,is_cfg_cache=False" \ 10 | --gen_kwargs "block_length=256,gen_length=256,steps=256,cfg_scale=0.0 " \ 11 | --num_fewshot 0 \ 12 | 
--output_path ./mmlu_pro_log \ 13 | --log_samples \ 14 | 15 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .generate_function import generate 2 | 3 | from .utils import set_seed 4 | -------------------------------------------------------------------------------- /utils/generate_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dllm_cache.cache import dLLMCache 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | def add_gumbel_noise(logits, temperature): 8 | if temperature == 0: 9 | return logits.exp() 10 | noise = torch.rand_like(logits) 11 | gumbel_noise = (-torch.log(noise)) ** temperature 12 | return logits.exp() / gumbel_noise 13 | 14 | 15 | def get_num_transfer_tokens(mask_index, steps): 16 | mask_num = mask_index.sum(dim=1, keepdim=True) 17 | base = mask_num // steps 18 | remainder = mask_num % steps 19 | num_transfer_tokens = base.expand(-1, steps).clone() 20 | if remainder.sum() > 0: 21 | indices = torch.arange(steps, device=mask_index.device) 22 | mask = indices.unsqueeze(0) < remainder 23 | num_transfer_tokens[mask] += 1 24 | return num_transfer_tokens.to(torch.int64) 25 | 26 | 27 | def generate( 28 | input_ids, 29 | attention_mask, 30 | model, 31 | steps=128, 32 | gen_length=128, 33 | block_length=128, 34 | temperature=0.0, 35 | cfg_scale=0.0, 36 | remasking="low_confidence", 37 | mask_id=126336, 38 | ): 39 | with torch.no_grad(): 40 | batch_size, prompt_length = input_ids.shape 41 | x = torch.full( 42 | (batch_size, prompt_length + gen_length), 43 | mask_id, 44 | dtype=torch.long, 45 | device=model.device, 46 | ) 47 | x[:, :prompt_length] = input_ids 48 | 49 | prompt_index = x != mask_id 50 | 51 | assert gen_length % block_length == 0 52 | num_blocks = gen_length // block_length 53 | 54 | assert steps % num_blocks == 0 55 | steps_per_block = steps // num_blocks 56 | 57 | feature_cache = dLLMCache() 58 | feature_cache.reset_cache(prompt_length) 59 | for num_block in range(num_blocks): 60 | start_idx = prompt_length + num_block * block_length 61 | end_idx = prompt_length + (num_block + 1) * block_length 62 | 63 | block_x = x[:, start_idx:end_idx] 64 | block_mask_index = block_x == mask_id 65 | num_transfer_tokens = get_num_transfer_tokens( 66 | block_mask_index, steps_per_block 67 | ) 68 | 69 | for i in range(steps_per_block): 70 | mask_index = x == mask_id 71 | if cfg_scale > 0.0: 72 | if hasattr(feature_cache, "cfg_interval_steps"): 73 | feature_cache.update_step(layer_id=33) 74 | if feature_cache.refresh_cfg(layer_id=33): 75 | cfg_x = x.clone() 76 | cfg_x[prompt_index] = mask_id 77 | logits = model(x, attention_mask=attention_mask).logits[ 78 | :, prompt_length: 79 | ] 80 | feature_cache.cache_type = "cfg" 81 | cfg_logits = model( 82 | cfg_x, attention_mask=attention_mask 83 | ).logits[:, prompt_length:] 84 | cfg_residual = logits - cfg_logits 85 | feature_cache.set_cache( 86 | layer_id=33, 87 | feature_name="cfg_residual", 88 | features=cfg_residual, 89 | cache_type="gen", 90 | ) 91 | feature_cache.cache_type = "no_cfg" 92 | else: 93 | feature_cache.cache_type = "cfg" 94 | cfg_residual = feature_cache.get_cache( 95 | layer_id=33, 96 | feature_name="cfg_residual", 97 | cache_type="gen", 98 | ) 99 | feature_cache.cache_type = "no_cfg" 100 | logits = model(x, attention_mask=attention_mask).logits[ 101 | :, prompt_length: 102 | ] 103 | else: 104 | 
cfg_x = x.clone() 105 | cfg_x[prompt_index] = mask_id 106 | logits = model(x, attention_mask=attention_mask).logits[ 107 | :, prompt_length: 108 | ] 109 | cfg_logits = model(cfg_x, attention_mask=attention_mask).logits[ 110 | :, prompt_length: 111 | ] 112 | cfg_residual = logits - cfg_logits 113 | logits = (logits - cfg_residual) + (cfg_scale + 1) * cfg_residual 114 | else: 115 | logits = model(x, attention_mask=attention_mask).logits[ 116 | :, prompt_length: 117 | ] 118 | logits_with_noise = add_gumbel_noise(logits, temperature=temperature) 119 | 120 | x0 = torch.argmax(logits_with_noise, dim=-1) 121 | 122 | if remasking == "low_confidence": 123 | p = F.softmax(logits, dim=-1) 124 | x0_p = torch.squeeze( 125 | torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1 126 | ) 127 | elif remasking == "random": 128 | x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) 129 | else: 130 | raise NotImplementedError(remasking) 131 | 132 | x0_p[:, (num_block + 1) * block_length :] = -np.inf 133 | 134 | x0 = torch.where( 135 | mask_index[:, prompt_length:], x0, x[:, prompt_length:] 136 | ) 137 | confidence = torch.where(mask_index[:, prompt_length:], x0_p, -np.inf) 138 | 139 | transfer_index = torch.zeros_like( 140 | x0, dtype=torch.bool, device=x0.device 141 | ) 142 | for j in range(confidence.shape[0]): 143 | select_index = torch.topk( 144 | confidence[j], k=num_transfer_tokens[j, i] 145 | ).indices 146 | transfer_index[j, select_index] = True 147 | x[:, prompt_length:][transfer_index] = x0[transfer_index] 148 | return x[:, prompt_length:] 149 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | def set_seed(seed): 5 | torch.manual_seed(seed) 6 | random.seed(seed) 7 | np.random.seed(seed) 8 | 9 | torch.backends.cudnn.deterministic = True 10 | torch.backends.cudnn.benchmark = False --------------------------------------------------------------------------------
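
A quick way to sanity-check the sampling trick in `add_gumbel_noise` (utils/generate_function.py): taking the argmax of `exp(logits) / (-log u)^T` with `u ~ Uniform(0, 1)` is the Gumbel-max trick with noise scaled by the temperature `T`, so it draws tokens from `softmax(logits / T)`, and `temperature == 0` reduces to greedy argmax. The snippet below is a minimal standalone sketch, not part of the repository; it assumes only PyTorch and copies the helper from utils/generate_function.py to compare empirical pick frequencies against the target softmax.

import torch

def add_gumbel_noise(logits, temperature):
    # Copied from utils/generate_function.py: Gumbel-max sampling written as a division.
    if temperature == 0:
        return logits.exp()
    noise = torch.rand_like(logits)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise

torch.manual_seed(0)
logits = torch.tensor([2.0, 1.0, 0.5])
temperature = 1.0

# Empirical frequency of the argmax pick should approach softmax(logits / temperature).
counts = torch.zeros_like(logits)
for _ in range(20_000):
    counts[add_gumbel_noise(logits, temperature).argmax()] += 1

print("empirical:", (counts / counts.sum()).tolist())
print("softmax  :", torch.softmax(logits / temperature, dim=-1).tolist())

This is the classic Gumbel-max trick rewritten with a division instead of an addition in log space, which is why the temperature appears as an exponent on the noise term rather than as a divisor on the logits.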