├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── INSTALL.md ├── README.md ├── assets ├── config1.jpg ├── config2.jpg ├── demos.png └── logo.png ├── lumina_mgpt ├── .isort.cfg ├── TRAIN.md ├── configs │ └── data │ │ └── sample.yaml ├── data │ ├── __init__.py │ ├── convertsation.py │ └── item_processor.py ├── demos │ ├── demo_freeform.py │ ├── demo_image2image.py │ └── demo_image_generation.py ├── exps │ └── 7B.sh ├── finetune_solver.py ├── generate_examples │ └── generate.py ├── inference_solver.py ├── model │ ├── __init__.py │ ├── chameleon │ │ ├── __init__.py │ │ ├── configuration_chameleon.py │ │ ├── convert_chameleon_weights_to_hf.py │ │ ├── image_processing_chameleon.py │ │ ├── modeling_chameleon.py │ │ └── processing_chameleon.py │ ├── chameleon_vae_ori │ │ ├── __init__.py │ │ ├── image_tokenizer.py │ │ ├── vocab.py │ │ └── vqgan.py │ ├── configuration_xllmx_chameleon.py │ └── modeling_xllmx_chameleon.py └── pre_tokenize │ ├── concat_record.py │ └── pre_tokenize.py ├── requirements.txt ├── setup.py └── xllmx ├── __init__.py ├── data ├── __init__.py ├── conversation │ ├── __init__.py │ └── template.py ├── data_reader.py ├── dataset.py ├── item_processor.py └── sampler.py ├── model ├── __init__.py ├── components.py └── tokenizer.py ├── solvers ├── __init__.py └── finetune │ ├── __init__.py │ └── finetune.py └── util ├── __init__.py ├── ckpt.py ├── dist.py ├── lr_sched.py ├── misc.py └── tensor_type.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | # Project-specific gitignores 163 | .output 164 | xllmx/output/ 165 | xllmx/output_dir/ 166 | tokenizer.model 167 | *.swp 168 | 169 | .asset/ 170 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | line_length = 120 4 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER 5 | no_lines_before = STDLIB,LOCALFOLDER 6 | lines_between_types = 1 7 | combine_as_imports = True 8 | force_sort_within_sections = true 9 | order_by_type = True 10 | known_first_party = xllmx 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Exclude all third-party libraries and auto-generated files globally 2 | #exclude: 3 | repos: 4 | # Common hooks 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.6.0 7 | hooks: 8 | - id: check-merge-conflict 9 | - id: check-symlinks 10 | - id: detect-private-key 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | files: (.*\.(py|bzl|md|rst|c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps|cmake|yaml|yml|hook)|BUILD|.*\.BUILD|WORKSPACE|CMakeLists\.txt)$ 14 | # For Python files 15 | - repo: https://github.com/PyCQA/isort 16 | rev: 5.13.2 17 | hooks: 18 | - id: isort 19 | - repo: https://github.com/psf/black.git 20 | rev: 24.4.2 21 | hooks: 22 | - id: black 23 | files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ 24 | args: [--line-length=120] 25 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | 2 | ### 1. Basic Setup 3 | 4 | ``` 5 | # Create a new conda environment named 'lumina_mgpt' with Python 3.10 6 | conda create -n lumina_mgpt python=3.10 -y 7 | # Activate the 'lumina_mgpt' environment 8 | conda activate lumina_mgpt 9 | # Install required packages from 'requirements.txt' 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ### 2. Install Flash-Attention 14 | ``` 15 | pip install flash-attn --no-build-isolation 16 | ``` 17 | 18 | ### 3. Install xllmx as Python Package 19 | The [xllmx](./xllmx) module is a lightweight engine designed to support the training and inference of 20 | LLM-centered Any2Any models. It is evolved from [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), undergoing comprehensive improvements to achieve higher efficiency and 21 | wider functionality, including the support for flexible arrangement and processing of interleaved media and text. 22 | 23 | The Lumina-mGPT implementation heavily relies on xllmx and requires xllmx to be installed as a python package (**so that `import xllmx` can be used anywhere in your machine, without the restriction of working directory**). 24 | The installation process is as follows: 25 | ```bash 26 | # bash 27 | # go to the root path of the project 28 | cd Lumina_mGPT 29 | # install as package 30 | pip install -e . 31 | ``` 32 | 33 | ### 4. Optional: Install Apex 34 | > [!Caution] 35 | > 36 | > If you merely run inference, there is no need to install Apex. 37 | > 38 | > For training, Apex can bring some training efficiency improvement, but it is still not a must. 39 | > 40 | > Note that training works smoothly with either: 41 | > 1. Apex not installed at all; OR 42 | > 2. 
Apex successfully installed with CUDA and C++ extensions. 43 | > 44 | > However, it will fail when: 45 | > 1. A Python-only build of Apex is installed. 46 | > 47 | > If errors like `No module named 'fused_layer_norm_cuda'` are reported, it generally means that you are 48 | using a Python-only Apex build. Please run `pip uninstall apex` to remove the build and try again. 49 | 50 | Lumina-mGPT utilizes [apex](https://github.com/NVIDIA/apex) to accelerate training, which needs to be compiled from source. Please follow the [official instructions](https://github.com/NVIDIA/apex#from-source) for installation. 51 | Here are some tips based on our experiences: 52 | 53 | **Step1**: Check the version of CUDA with which your torch is built: 54 | ```python 55 | # python 56 | import torch 57 | print(torch.version.cuda) 58 | ``` 59 | 60 | **Step2**: Check the CUDA toolkit version on your system: 61 | ```bash 62 | # bash 63 | nvcc -V 64 | ``` 65 | **Step3**: If the two aforementioned versions mismatch, or if you do not have CUDA toolkit installed on your system, 66 | please download and install CUDA toolkit from [here](https://developer.nvidia.com/cuda-toolkit-archive) with version matching the torch CUDA version. 67 | 68 | > [!Note] 69 | > 70 | > Note that multiple versions of CUDA toolkit can co-exist on the same machine, and the version can be easily switched by changing the `$PATH` and `$LD_LIBRARY_PATH` environment variables. 71 | There is thus no need to worry about your machine's environment getting messed up. 72 | 73 | **Step4**: You can now start installing apex: 74 | ```bash 75 | git clone https://github.com/NVIDIA/apex 76 | cd apex 77 | # if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... 78 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ 79 | # otherwise 80 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 81 | ``` 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 5 | # Lumina-mGPT 6 | 7 | A family of multimodal autoregressive models capable of various vision and language tasks, particularly excelling in generating flexible photorealistic images from text descriptions. 👋 join our WeChat 8 | 9 | [![Lumina-mGPT](https://img.shields.io/badge/Paper-Lumina--mGPT-2b9348.svg?logo=arXiv)](https://arxiv.org/abs/2408.02657)  10 | 11 | [![Static Badge](https://img.shields.io/badge/Official(node1)-6B88E3?logo=youtubegaming&label=Demo%20Lumina-mGPT)](http://106.14.2.150:10020/)  12 | [![Static Badge](https://img.shields.io/badge/Official(node2)-6B88E3?logo=youtubegaming&label=Demo%20Lumina-mGPT)](http://106.14.2.150:10021/)  13 | 14 |
15 | 16 | 17 | 18 | ## 📰 News 19 | 20 | - **[2024-08-11] 🎉🎉🎉 [Training codes and documents](./lumina_mgpt/TRAIN.md) are released! 🎉🎉🎉** 21 | 22 | - **[2024-07-08] 🎉🎉🎉 Lumina-mGPT is released! 🎉🎉🎉** 23 | 24 | ## ⚙️ Installation 25 | 26 | See [INSTALL.md](./INSTALL.md) for detailed instructions. 27 | 28 | Note that the Lumina-mGPT implementation heavily relies on 29 | the [xllmx](./xllmx) module, which evolved from [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory) for supporting 30 | LLM-centered multimodal tasks. Make sure it is installed correctly as a Python package before going on. 31 | 32 | ## ⛽ Training 33 | See [lumina_mgpt/TRAIN.md](lumina_mgpt/TRAIN.md). 34 | 35 | ## 📽️ Inference 36 | 37 | > [!Note] 38 | > 39 | > Before using the Lumina-mGPT model, run 40 | > 41 | > ```bash 42 | > # bash 43 | > cd lumina_mgpt 44 | > ``` 45 | > 46 | > to enter the directory of the Lumina-mGPT implementation. 47 | 48 | ### Preparation 49 | 50 | Since the Chameleon implementation in transformers does not currently contain the VQ-VAE decoder, please manually download the original VQ-VAE weights [provided by Meta](https://github.com/facebookresearch/chameleon) and 51 | put them in the following directory: 52 | 53 | ``` 54 | Lumina-mGPT 55 | - lumina_mgpt/ 56 | - ckpts/ 57 | - chameleon/ 58 | - tokenizer/ 59 | - text_tokenizer.json 60 | - vqgan.yaml 61 | - vqgan.ckpt 62 | - xllmx/ 63 | - ... 64 | ``` 65 | 66 | ### Local Gradio Demos 67 | 68 | We have prepared three different Gradio demos, each showcasing unique functionalities, to help you quickly become familiar with the capabilities of the Lumina-mGPT models. 69 | 70 | #### 1. [demos/demo_image_generation.py](./lumina_mgpt/demos/demo_image_generation.py) 71 | 72 | This demo is customized for image generation tasks, where you can input a text description and generate a corresponding image. 73 | To host this demo, run: 74 | 75 | ```bash 76 | # Note: set the `--target_size` argument to be consistent with the checkpoint 77 | python -u demos/demo_image_generation.py \ 78 | --pretrained_path Alpha-VLLM/Lumina-mGPT-7B-768 \ 79 | --target_size 768 80 | ``` 81 | 82 | #### 2. [demos/demo_image2image.py](./lumina_mgpt/demos/demo_image2image.py) 83 | 84 | This demo is designed for models trained with Omni-SFT. You can conveniently switch between the multiple downstream tasks using this demo. 85 | 86 | ```bash 87 | # Note: set the `--target_size` argument to be consistent with the checkpoint 88 | python -u demos/demo_image2image.py \ 89 | --pretrained_path Alpha-VLLM/Lumina-mGPT-7B-768-Omni \ 90 | --target_size 768 91 | ``` 92 | 93 | #### 3. [demos/demo_freeform.py](./lumina_mgpt/demos/demo_freeform.py) 94 | 95 | This is a powerful demo with minimal constraints on the input format. It supports flexible interaction and is suitable for in-depth exploration. 
96 | 97 | ```bash 98 | # Note: set the `--target_size` argument to be consistent with the checkpoint 99 | python -u demos/demo_freeform.py \ 100 | --pretrained_path Alpha-VLLM/Lumina-mGPT-7B-768-Omni \ 101 | --target_size 768 102 | ``` 103 | 104 | ### Simple Inference 105 | 106 | The simplest code for Lumina-mGPT inference: 107 | 108 | ```python 109 | from inference_solver import FlexARInferenceSolver 110 | from PIL import Image 111 | 112 | # ******************** Image Generation ******************** 113 | inference_solver = FlexARInferenceSolver( 114 | model_path="Alpha-VLLM/Lumina-mGPT-7B-768", 115 | precision="bf16", 116 | target_size=768, 117 | ) 118 | 119 | q1 = ("Generate an image of 768x768 according to the following prompt:\n" 120 | "Image of a dog playing water, and a waterfall is in the background.") 121 | 122 | # generated: tuple of (generated response, list of generated images) 123 | generated = inference_solver.generate( 124 | images=[], 125 | qas=[[q1, None]], 126 | max_gen_len=8192, 127 | temperature=1.0, 128 | logits_processor=inference_solver.create_logits_processor(cfg=4.0, image_top_k=2000), 129 | ) 130 | 131 | a1, new_image = generated[0], generated[1][0] 132 | 133 | 134 | # ******************* Image Understanding ****************** 135 | inference_solver = FlexARInferenceSolver( 136 | model_path="Alpha-VLLM/Lumina-mGPT-7B-512", 137 | precision="bf16", 138 | target_size=512, 139 | ) 140 | 141 | # the "<|image|>" symbol will be replaced with a sequence of image tokens before being fed to the LLM 142 | q1 = "Describe the image in detail. <|image|>" 143 | 144 | images = [Image.open("image.png")] 145 | qas = [[q1, None]] 146 | 147 | # `len(images)` should be equal to the number of appearances of "<|image|>" in qas 148 | generated = inference_solver.generate( 149 | images=images, 150 | qas=qas, 151 | max_gen_len=8192, 152 | temperature=1.0, 153 | logits_processor=inference_solver.create_logits_processor(cfg=4.0, image_top_k=2000), 154 | ) 155 | 156 | a1 = generated[0] 157 | # generated[1], namely the list of newly generated images, should typically be empty in this case. 158 | 159 | 160 | # ********************* Omni-Potent ********************* 161 | inference_solver = FlexARInferenceSolver( 162 | model_path="Alpha-VLLM/Lumina-mGPT-7B-768-Omni", 163 | precision="bf16", 164 | target_size=768, 165 | ) 166 | 167 | # Example: Depth Estimation 168 | # For more instructions, see demos/demo_image2image.py 169 | q1 = "Depth estimation. 
<|image|>" 170 | images = [Image.open("image.png")] 171 | qas = [[q1, None]] 172 | 173 | generated = inference_solver.generate( 174 | images=images, 175 | qas=qas, 176 | max_gen_len=8192, 177 | temperature=1.0, 178 | logits_processor=inference_solver.create_logits_processor(cfg=1.0, image_top_k=200), 179 | ) 180 | 181 | a1 = generated[0] 182 | new_image = generated[1][0] 183 | 184 | ``` 185 | 186 | ## 🤗 Checkpoints 187 | 188 | **Configurations** 189 | 190 | 191 | 192 | 193 | **7B models** 194 | 195 | | Model | Size | Huggingface | 196 | | ------------ | ---- | ---------------------------------------------------------------------------------------- | 197 | | FP-SFT@512 | 7B | [Alpha-VLLM/Lumina-mGPT-7B-512](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-7B-512) | 198 | | FP-SFT@768 | 7B | [Alpha-VLLM/Lumina-mGPT-7B-768](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-7B-768) | 199 | | Omni-SFT@768 | 7B | [Alpha-VLLM/Lumina-mGPT-7B-768-Omni](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-7B-768-Omni) | 200 | | FP-SFT@1024 | 7B | [Alpha-VLLM/Lumina-mGPT-7B-1024](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-7B-1024) | 201 | 202 | **34B models** 203 | 204 | | Model | Size | Huggingface | 205 | | ---------- | ---- | ------------------------------------------------------------------------------------ | 206 | | FP-SFT@512 | 34B | [Alpha-VLLM/Lumina-mGPT-34B-512](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-34B-512) | 207 | 208 | More checkpoints coming soon. 209 | 210 | ## 📑 Open-source Plan 211 | 212 | - [X] Inference code 213 | - [X] Training code 214 | 215 | ## 🔥 Open positions 216 | We are hiring interns, postdocs, and full-time researchers at the General Vision Group, Shanghai AI Lab, with a focus on multi-modality and vision foundation models. If you are interested, please contact gaopengcuhk@gmail.com. 
217 | 218 | ## 📄 Citation 219 | 220 | ``` 221 | @misc{liu2024lumina-mgpt, 222 | title={Lumina-mGPT: Illuminate Flexible Photorealistic Text-to-Image Generation with Multimodal Generative Pretraining}, 223 | author={Dongyang Liu and Shitian Zhao and Le Zhuo and Weifeng Lin and Yu Qiao and Hongsheng Li and Peng Gao}, 224 | year={2024}, 225 | eprint={2408.02657}, 226 | archivePrefix={arXiv}, 227 | primaryClass={cs.CV}, 228 | url={https://arxiv.org/abs/2408.02657}, 229 | } 230 | ``` 231 | -------------------------------------------------------------------------------- /assets/config1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/assets/config1.jpg -------------------------------------------------------------------------------- /assets/config2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/assets/config2.jpg -------------------------------------------------------------------------------- /assets/demos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/assets/demos.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/assets/logo.png -------------------------------------------------------------------------------- /lumina_mgpt/.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | line_length = 120 4 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER 5 | no_lines_before = STDLIB,LOCALFOLDER 6 | lines_between_types = 1 7 | combine_as_imports = True 8 | force_sort_within_sections = true 9 | order_by_type = True 10 | known_first_party = xllmx 11 | -------------------------------------------------------------------------------- /lumina_mgpt/TRAIN.md: -------------------------------------------------------------------------------- 1 | # Lumina-mGPT Training 2 | 3 | For efficiency considerations, the multi-modal datasets are pre-tokenized into sequences of token ids. This leads to significantly faster training 4 | 5 | ## Pre-tokenization 6 | 7 | 8 | ### 1. Run Tokenization 9 | 10 | This stage tokenizes each data point, consisting of interleaved image and text, into a single sequence of integer tokens. After tokenization, the sequence is saved to disk for trainining-time usage. Together with the saved tokens, a json-formatted record file is also generated for indexing all the saved token files. For faster tokenization, you may use multiple GPUs and dispatch different subsets of data to them. 
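The `--splits` and `--rank` arguments used in the command below control this dispatching: `--splits` is the number of subsets the input file is divided into, and `--rank` selects the subset handled by the current process. As a rough illustration only (this is an assumption about the partitioning idea, not a verbatim excerpt from `pre_tokenize.py`, whose exact scheme may differ), the selection behaves roughly like:

```python
# Illustrative sketch (assumption): how a process with a given rank could pick
# its subset of data points when the input list is divided into `splits` parts.
def select_subset(data_points: list, splits: int, rank: int) -> list:
    """Return the subset of data points that process `rank` (0-based) should tokenize."""
    return [item for i, item in enumerate(data_points) if i % splits == rank]


# e.g., with 8 processes, the process with rank 3 handles items 3, 11, 19, ...
subset = select_subset(list(range(100)), splits=8, rank=3)
```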
11 | 12 | #### Command: 13 | 14 | ```bash 15 | for i in {0..7} 16 | do 17 | export CUDA_VISIBLE_DEVICES=${i} 18 | python -u pre_tokenize/pre_tokenize.py \ 19 | --splits=8 \ 20 | --rank=${i} \ 21 | --in_filename /path/to/in_filename.json \ 22 | --out_dir /path/to/out_dir \ 23 | --target_size 768 &> ${i}.log & 24 | done 25 | ``` 26 | 27 | #### Format of Input File: 28 | 29 | `in_filename` is expected to be a json file with the following format: 30 | ```python 31 | [ 32 | {...}, 33 | {...}, 34 | { 35 | "conversations":[ 36 | { 37 | "from": "human", 38 | "value": "Hi, please convert this depth image <|image|> to a color image" 39 | }, 40 | { 41 | "from": "gpt", 42 | "value": "<|image|>" 43 | }, 44 | { 45 | "from": "human", 46 | "value": "Could you change its style to that of this image? <|image|>" 47 | }, 48 | { 49 | "from": "gpt", 50 | "value": "Sure, here's the image. <|image|>" 51 | } 52 | ], 53 | "image": ["/path/to/image1.png", "path/to/image2.png", "path/to/image3.png", "path/to/image4.png"] 54 | }, 55 | {...}, 56 | {...} 57 | ] 58 | ``` 59 | 60 | *Rules:* 61 | 62 | 1. The file is a list of dictionaries, and each dictionary represents a data point 63 | 2. Each dictionary contains the key "conversations" 64 | 3. If the conversation involves image(s), the data point should also contain the `image` key; otherwise the `image` key can be omitted 65 | 4. The location of each image should be explicitly specified in the conversation using the `<|image|>` symbol 66 | 1. Naturally, the number of occurrences of the `<|image|>` symbol should be equal to the number of images in the `image` key 67 | 68 | 69 | #### How to adapt to your own format: 70 | 71 | If you have your own data in a different format, you can easily adapt the code to handle it by modifying the `pre_tokenize.py` file. 72 | We have prepared the space, which is in `ItemProcessor.process_item`, for adding your logic that converts data points of your own format into the standard format (see the illustrative sketch at the end of this document). 73 | 74 | ### 2. Concat Records 75 | 76 | After tokenization, you need to concatenate the record files generated by the different processes (GPUs) into a single record file. 77 | **Note that we use the term "record file" to refer to the meta file that contains the information of all the saved token files, 78 | which is different from the token files themselves.** 79 | 80 | ```bash 81 | python -u pre_tokenize/concat_record.py \ 82 | --sub_record_dir /path/to/out_dir \ 83 | --save_path /path/to/out_dir/record.json 84 | ``` 85 | 86 | ## Training 87 | 88 | #### Command: 89 | We provide an example experiment script, [exps/7B.sh](exps/7B.sh), for training the 7B model. If you have access to a SLURM cluster, you can run the following command to start training: 90 | 91 | ```bash 92 | srun -n32 --ntasks-per-node=8 --gres=gpu:8 bash exps/7B.sh 93 | ``` 94 | 95 | Otherwise, if you want to use `torchrun` for distributed training, you can make the following change to `exps/7B.sh`: 96 | ```bash 97 | # python -u finetune_solver.py \ 98 | torchrun --torchrun_kwargs finetune_solver.py \ 99 | ``` 100 | 101 | #### About the `--data_config` argument: 102 | The `--data_config` argument should point to a `*.yaml` file, which is a meta file that gathers one or multiple record files. 103 | In other words, you may pre-tokenize multiple datasets independently and list their record files in the same data config file 104 | for joint training. 
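#### Appendix: adapting a custom data format (illustrative sketch)

As mentioned in the "How to adapt to your own format" section above, the only requirement is that each data point ends up in the standard conversation format before pre-tokenization. The sketch below is a hypothetical example of such a conversion: the input schema (`caption`, `image_path`), the output file name, and the prompt wording are assumptions for illustration and are not part of this repository.

```python
# Hypothetical conversion script: turn a simple caption/image dataset into the
# standard format expected by pre_tokenize.py (see "Format of Input File" above).
import json


def to_standard_format(custom_item: dict) -> dict:
    """Convert one custom data point into the documented standard format."""
    return {
        "conversations": [
            {
                "from": "human",
                "value": f"Generate an image according to the following caption:\n{custom_item['caption']}",
            },
            # exactly one <|image|> occurrence per entry in the "image" list
            {"from": "gpt", "value": "<|image|>"},
        ],
        "image": [custom_item["image_path"]],
    }


if __name__ == "__main__":
    custom_records = [  # assumed custom schema, for illustration only
        {"caption": "a dog playing in water", "image_path": "/path/to/dog.png"},
    ]
    standard = [to_standard_format(r) for r in custom_records]
    with open("converted_for_pretokenize.json", "w") as f:
        json.dump(standard, f, indent=2)  # pass this file to pre_tokenize.py via --in_filename
```

Alternatively, the same conversion logic can be placed directly inside `ItemProcessor.process_item` in `pre_tokenize/pre_tokenize.py`, as suggested above; the exact signature of that method is not reproduced here.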
105 | -------------------------------------------------------------------------------- /lumina_mgpt/configs/data/sample.yaml: -------------------------------------------------------------------------------- 1 | META: 2 | - path: 'path/to/record1.json' 3 | - path: 'path/to/record2.json' 4 | ratio: 0.4 # In each epoch, sample only 40% of the data from this dataset 5 | - path: 'path/to/record3.json' 6 | -------------------------------------------------------------------------------- /lumina_mgpt/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/lumina_mgpt/data/__init__.py -------------------------------------------------------------------------------- /lumina_mgpt/data/convertsation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class Conversation: 5 | sep_token = "" 6 | roles = ["Human", "Assistant"] 7 | 8 | def __init__(self, messages=None): 9 | self.messages = messages or [] 10 | 11 | def process(self): 12 | ret = "" 13 | pieces = [] 14 | for i, (role, message) in enumerate(self.messages): 15 | if message is not None: 16 | turn = message + self.sep_token 17 | ret += turn 18 | if role == self.roles[1]: 19 | pieces.append({"data": turn, "predict": True}) 20 | else: 21 | pieces.append({"data": turn, "predict": False}) 22 | else: 23 | # generation prompt 24 | assert i == len(self.messages) - 1 and role == self.roles[1], "only last assistant message can be None" 25 | 26 | result = { 27 | "conv": ret, # text involving the complete conversation 28 | "pieces": pieces, # list to help correctly mark the labels 29 | } 30 | return result 31 | 32 | def get_prompt(self): 33 | return self.process()["conv"] 34 | 35 | def append_message(self, role, message): 36 | self.messages.append([role, message]) 37 | 38 | def copy(self): 39 | return Conversation( 40 | messages=[[x, y] for x, y in self.messages], 41 | ) 42 | 43 | def load_qas(self, qas: List[List[str]]): 44 | """ 45 | convert the list of question-answer pairs to a string, which contains the conversation involving all 46 | the questions and answers. When the last answer is None, the returned string is the prompt which 47 | can be used by the model to generate the last answer. 48 | :param qas: [[question1, answer1], [question2, answer2], ..., [questionX, answerX]] 49 | note that the last answer, i.e. 
answerX, can be None 50 | :return: the prompt 51 | """ 52 | self.messages = [] 53 | for q, a in qas: 54 | self.append_message(self.roles[0], q) 55 | self.append_message(self.roles[1], a) 56 | -------------------------------------------------------------------------------- /lumina_mgpt/data/item_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import random 4 | from typing import Dict, List 5 | 6 | from PIL import Image 7 | import torch 8 | 9 | from data.convertsation import Conversation 10 | import model.chameleon_vae_ori as chameleon_vae_ori 11 | from xllmx.data.data_reader import read_general 12 | from xllmx.data.item_processor import MMConvItemProcessor 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def center_crop(pil_image, crop_size): 18 | while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]: 19 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) 20 | 21 | scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1]) 22 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) 23 | 24 | crop_left = random.randint(0, pil_image.size[0] - crop_size[0]) 25 | crop_upper = random.randint(0, pil_image.size[1] - crop_size[1]) 26 | crop_right = crop_left + crop_size[0] 27 | crop_lower = crop_upper + crop_size[1] 28 | return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower)) 29 | 30 | 31 | def var_center_crop(pil_image, crop_size_list, random_top_k=1): 32 | w, h = pil_image.size 33 | rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list] 34 | crop_size = random.choice( 35 | sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k] 36 | )[1] 37 | return center_crop(pil_image, crop_size) 38 | 39 | 40 | def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0): 41 | assert max_ratio >= 1.0 42 | crop_size_list = [] 43 | wp, hp = num_patches, 1 44 | while wp > 0: 45 | if max(wp, hp) / min(wp, hp) <= max_ratio: 46 | crop_size_list.append((wp * patch_size, hp * patch_size)) 47 | if (hp + 1) * wp <= num_patches: 48 | hp += 1 49 | else: 50 | wp -= 1 51 | return crop_size_list 52 | 53 | 54 | class FlexARItemProcessor(MMConvItemProcessor): 55 | image_start_token = "" # fixed tokens for start and end, so can hardcode 56 | image_end_token = "" 57 | full_sub_sep_token = "" 58 | sub_sub_sep_token = "" 59 | sub_skip_token = "" 60 | new_line_token = "" 61 | 62 | def __init__( 63 | self, 64 | tokenizer="Alpha-VLLM/Lumina-mGPT-7B-768", 65 | conv_template=Conversation, 66 | target_size=512, 67 | ): 68 | 69 | super().__init__( 70 | { 71 | "<|image|>": self.process_image, 72 | }, 73 | ["<|image|>"], 74 | tokenizer, 75 | conv_template, 76 | ) 77 | 78 | self.patch_size = 32 79 | self.crop_size_list = generate_crop_size_list((target_size // self.patch_size) ** 2, self.patch_size) 80 | logger.info("List of crop sizes:") 81 | for i in range(0, len(self.crop_size_list), 6): 82 | logger.info(" " + "".join([f"{f'{w} x {h}':14s}" for w, h in self.crop_size_list[i : i + 6]])) 83 | 84 | # todo 85 | # currently still use the original image tokenizer provided by Meta rather than transformers 86 | # because the transformers implementation does not contain the vae decoder 87 | self.chameleon_ori_vocab = chameleon_vae_ori.VocabInfo( 88 | json.load(open("./ckpts/chameleon/tokenizer/text_tokenizer.json", 
encoding="utf8"))["model"]["vocab"] 89 | ) 90 | self.chameleon_ori_translation = chameleon_vae_ori.VocabTranslation(self.chameleon_ori_vocab, device="cuda") 91 | self.chameleon_ori_image_tokenizer = chameleon_vae_ori.ImageTokenizer( 92 | cfg_path="./ckpts/chameleon/tokenizer/vqgan.yaml", 93 | ckpt_path="./ckpts/chameleon/tokenizer/vqgan.ckpt", 94 | device="cuda", 95 | ) 96 | 97 | @staticmethod 98 | def get_n_grids_token(n_grids): 99 | return f"" 100 | 101 | def token2id(self, token: str) -> int: 102 | return self.tokenizer.tokenizer.vocab[token] 103 | 104 | @torch.no_grad() 105 | def process_image(self, image) -> Dict: 106 | if isinstance(image, Image.Image): 107 | pass 108 | else: 109 | image = Image.open(read_general(image)) 110 | 111 | image = var_center_crop(image, crop_size_list=self.crop_size_list) 112 | 113 | w_grids, h_grids = image.size[0] // self.patch_size, image.size[1] // self.patch_size 114 | 115 | image_toks = self.chameleon_ori_translation.convert_img2bp2( 116 | self.chameleon_ori_image_tokenizer.img_tokens_from_pil(image) 117 | ).view(-1) 118 | 119 | full_image_toks = image_toks.reshape(image.size[1] // 16, image.size[0] // 16) 120 | new_line_id = self.token2id(self.new_line_token) 121 | 122 | full_image_toks = torch.cat( 123 | ( 124 | full_image_toks, 125 | torch.ones(image.size[1] // 16, 1, device=full_image_toks.device, dtype=full_image_toks.dtype) 126 | * new_line_id, 127 | ), 128 | dim=1, 129 | ).flatten() 130 | 131 | result_toks = [ 132 | self.token2id(self.image_start_token), 133 | self.token2id(self.get_n_grids_token(h_grids)), 134 | self.token2id(self.get_n_grids_token(w_grids)), 135 | *full_image_toks.tolist(), 136 | self.token2id(self.image_end_token), 137 | ] 138 | 139 | return {"input_ids": result_toks, "labels": result_toks} 140 | 141 | def process_item(self, item, training_mode=False, out_flatten=True): 142 | if not out_flatten: 143 | return super().process_item(item, training_mode=training_mode) 144 | 145 | if training_mode: 146 | tokens, labels = super().process_item(item, training_mode=training_mode) 147 | input_tokens_item = [] 148 | modified_labels_item = [] 149 | for i, (token_or_media, ori_label) in enumerate(zip(tokens, labels)): 150 | if isinstance(token_or_media, int): 151 | token = token_or_media 152 | input_tokens_item.append(token) 153 | modified_labels_item.append(ori_label) 154 | else: 155 | input_tokens_item += token_or_media["input_ids"] 156 | if ori_label <= 0: # in the prompt part 157 | modified_labels_item += [-100] * len(token_or_media["input_ids"]) 158 | else: 159 | modified_labels_item += token_or_media["labels"] 160 | 161 | return input_tokens_item, modified_labels_item 162 | else: 163 | tokens = super().process_item(item, training_mode=training_mode) 164 | input_tokens_item = [] 165 | for i, token_or_media in enumerate(tokens): 166 | if isinstance(token_or_media, int): 167 | input_tokens_item.append(token_or_media) 168 | else: 169 | input_tokens_item += token_or_media["input_ids"] 170 | 171 | return input_tokens_item 172 | 173 | def decode_image(self, tokens: List[int]) -> Image.Image: 174 | if tokens[0] == self.token2id(self.image_start_token): 175 | tokens = tokens[1:] 176 | if tokens[-1] == self.token2id(self.image_end_token): 177 | tokens = tokens[:-1] 178 | 179 | h_grids, w_grids = tokens[0] - 8804, tokens[1] - 8804 180 | tokens = tokens[2:] 181 | h, w = h_grids * self.patch_size, w_grids * self.patch_size 182 | h_latent_dim, w_latent_dim = h_grids * 2, w_grids * 2 183 | 184 | for i in range(len(tokens)): 185 | if (i + 1) % 
(w_latent_dim + 1) != 0: 186 | tokens[i] = self.chameleon_ori_translation.bpe2img[tokens[i]] 187 | 188 | assert len(tokens) == h_latent_dim * (w_latent_dim + 1) 189 | tokens = torch.tensor(tokens, dtype=torch.int64).cuda() 190 | 191 | tokens = tokens.view(h_latent_dim, w_latent_dim + 1)[:, :-1].flatten() 192 | 193 | return self.chameleon_ori_image_tokenizer.pil_from_img_toks(tokens, h_latent_dim, w_latent_dim) 194 | -------------------------------------------------------------------------------- /lumina_mgpt/demos/demo_freeform.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0]) 5 | 6 | import argparse 7 | import builtins 8 | import datetime 9 | import multiprocessing as mp 10 | import traceback 11 | from typing import List, Optional 12 | 13 | import gradio as gr 14 | import torch 15 | 16 | from inference_solver import FlexARInferenceSolver 17 | from xllmx.util.misc import random_seed 18 | 19 | 20 | class Ready: 21 | pass 22 | 23 | 24 | class ModelFailure: 25 | pass 26 | 27 | 28 | def model_worker( 29 | rank: int, 30 | args: argparse.Namespace, 31 | barrier: mp.Barrier, 32 | request_queue: mp.Queue, 33 | response_queue: Optional[mp.Queue] = None, 34 | ) -> None: 35 | """ 36 | The worker function that manipulates the GPU to run the inference. 37 | Exact n_gpu workers are started, with each one operating on a separate GPU. 38 | 39 | Args: 40 | rank (int): Distributed rank of the worker. 41 | args (argparse.Namespace): All command line arguments. 42 | barrier (multiprocessing.Barrier): A barrier used to delay the start 43 | of Web UI to be after the start of the model. 44 | """ 45 | 46 | builtin_print = builtins.print 47 | 48 | def print(*args, **kwargs): 49 | kwargs["flush"] = True 50 | now = datetime.datetime.now().time() 51 | builtin_print("[{}] ".format(now), end="") # print with time stamp 52 | builtin_print(*args, **kwargs) 53 | 54 | builtins.print = print 55 | 56 | world_size = len(args.gpu_ids) 57 | gpu_id = args.gpu_ids[rank] 58 | # dist.init_process_group( 59 | # backend="nccl", rank=rank, world_size=world_size, 60 | # init_method=f"tcp://{args.master_addr}:{args.master_port}", 61 | # ) 62 | # print(f"| distributed init on worker {rank}/{world_size}. 
" 63 | # f"using gpu: {gpu_id}") 64 | torch.cuda.set_device(gpu_id) 65 | 66 | inference_solver = FlexARInferenceSolver( 67 | model_path=args.pretrained_path, precision=args.precision, target_size=args.target_size 68 | ) 69 | 70 | barrier.wait() 71 | 72 | while True: 73 | if response_queue is not None: 74 | response_queue.put(Ready()) 75 | try: 76 | existing_images, chatbot, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k = request_queue.get() 77 | 78 | print(chatbot) 79 | 80 | random_seed(seed=seed) 81 | 82 | generated = inference_solver.generate( 83 | existing_images, 84 | chatbot, 85 | max_gen_len, 86 | gen_t, 87 | logits_processor=inference_solver.create_logits_processor( 88 | cfg=cfg, text_top_k=text_top_k, image_top_k=image_top_k 89 | ), 90 | ) 91 | 92 | stream_response = {"text": generated[0], "image": generated[1], "end_of_content": True} 93 | print(generated[1]) 94 | if response_queue is not None: 95 | response_queue.put(stream_response) 96 | 97 | except Exception: 98 | print(traceback.format_exc()) 99 | response_queue.put(ModelFailure()) 100 | 101 | 102 | def gradio_worker( 103 | request_queues: List[mp.Queue], 104 | response_queue: mp.Queue, 105 | args: argparse.Namespace, 106 | barrier: mp.Barrier, 107 | ) -> None: 108 | """ 109 | The gradio worker is responsible for displaying the WebUI and relay the 110 | requests to model workers. It should be launched only once. 111 | 112 | Args: 113 | request_queues (List[mp.Queue]): A list of request queues (one for 114 | each model worker). 115 | args (argparse.Namespace): All command line arguments. 116 | barrier (multiprocessing.Barrier): A barrier used to delay the start 117 | of Web UI to be after the start of the model. 118 | """ 119 | 120 | def check_input_sanity(text_input: str, new_images): 121 | if new_images is None: 122 | new_images = [] 123 | 124 | print(new_images) 125 | 126 | if text_input.count("<|image|>") != len(new_images): 127 | raise gr.Error("please make sure that you have the same number of image inputs and <|image|> tokens") 128 | 129 | def show_user_input(text_input, new_images, chatbot, chatbot_display, existing_images): 130 | 131 | existing_images = [] if existing_images is None else existing_images 132 | new_images = [] if new_images is None else new_images 133 | 134 | return ( 135 | "", 136 | [], 137 | chatbot + [[text_input, None]], 138 | chatbot_display + [[text_input, None]], 139 | existing_images + new_images, 140 | ) 141 | 142 | def stream_model_output( 143 | existing_images, chatbot, chatbot_display, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k 144 | ): 145 | 146 | existing_images = [] if existing_images is None else existing_images 147 | 148 | while True: 149 | content_piece = response_queue.get() 150 | if isinstance(content_piece, Ready): 151 | break 152 | for queue in request_queues: 153 | queue.put( 154 | ([_[0] for _ in existing_images], chatbot, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k) 155 | ) 156 | while True: 157 | content_piece = response_queue.get() 158 | if isinstance(content_piece, ModelFailure): 159 | raise RuntimeError 160 | chatbot_display[-1][1] = content_piece["text"].replace("<", "<").replace(">", ">") 161 | if content_piece["end_of_content"]: 162 | chatbot[-1][1] = content_piece["text"] 163 | chatbot_display[-1][1] = content_piece["text"] 164 | yield chatbot, chatbot_display, existing_images + [(_, None) for _ in content_piece["image"]] 165 | break 166 | else: 167 | yield chatbot, chatbot_display, [] 168 | 169 | def clear(): 170 | chatbot = [] 171 | 
chatbot_display = [] 172 | text_input = "" 173 | return chatbot, chatbot_display, text_input 174 | 175 | with gr.Blocks(css="#image_input {height: 100% !important}") as demo: 176 | gr.Markdown("# Lumina-mGPT Demo\n") 177 | with gr.Row() as r: 178 | with gr.Column(scale=1): 179 | existing_images = gr.Gallery(value=[], label="Existing Images", interactive=False) 180 | chatbot = gr.Chatbot(visible=False) 181 | chatbot_display = gr.Chatbot() 182 | with gr.Column(scale=1): 183 | new_images = gr.Gallery(value=[], label="Image Inputs", interactive=True) 184 | text_input = gr.Textbox() 185 | submit_button = gr.Button("Submit", variant="primary") 186 | clear_button = gr.ClearButton([existing_images, chatbot, chatbot_display, text_input, new_images]) 187 | with gr.Row(): 188 | with gr.Column(scale=1): 189 | max_gen_len = gr.Slider( 190 | minimum=1, 191 | maximum=5000, 192 | value=2048, 193 | interactive=True, 194 | label="max new tokens", 195 | ) 196 | with gr.Column(scale=1): 197 | seed = gr.Slider( 198 | minimum=0, 199 | maximum=int(1e5), 200 | value=1, 201 | step=1, 202 | interactive=True, 203 | label="Seed (0 for random)", 204 | ) 205 | with gr.Row(): 206 | with gr.Column(scale=1): 207 | gen_t = gr.Slider( 208 | minimum=0.0, 209 | maximum=4.0, 210 | value=1.0, 211 | interactive=True, 212 | label="gen_t", 213 | ) 214 | with gr.Column(scale=1): 215 | cfg = gr.Slider( 216 | minimum=0.0, 217 | maximum=16.0, 218 | value=1.0, 219 | interactive=True, 220 | label="cfg", 221 | ) 222 | with gr.Row(): 223 | with gr.Column(scale=1): 224 | image_top_k = gr.Slider( 225 | minimum=0, 226 | maximum=8192, 227 | value=2000, 228 | interactive=True, 229 | label="Image Top-k", 230 | ) 231 | with gr.Column(scale=1): 232 | text_top_k = gr.Slider( 233 | minimum=0, 234 | maximum=9999, 235 | value=5, 236 | interactive=True, 237 | label="Text Top-k", 238 | ) 239 | 240 | text_input.submit(check_input_sanity, [text_input, new_images], []).success( 241 | show_user_input, 242 | [text_input, new_images, chatbot, chatbot_display, existing_images], 243 | [text_input, new_images, chatbot, chatbot_display, existing_images], 244 | ).success( 245 | stream_model_output, 246 | [existing_images, chatbot, chatbot_display, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k], 247 | [chatbot, chatbot_display, existing_images], 248 | ) 249 | submit_button.click(check_input_sanity, [text_input, new_images], []).success( 250 | show_user_input, 251 | [text_input, new_images, chatbot, chatbot_display, existing_images], 252 | [text_input, new_images, chatbot, chatbot_display, existing_images], 253 | ).success( 254 | stream_model_output, 255 | [existing_images, chatbot, chatbot_display, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k], 256 | [chatbot, chatbot_display, existing_images], 257 | ) 258 | barrier.wait() 259 | demo.queue(api_open=True).launch( 260 | share=True, 261 | server_name="0.0.0.0", 262 | ) 263 | 264 | 265 | if __name__ == "__main__": 266 | parser = argparse.ArgumentParser("X-LLM-X Chat Demo") 267 | group = parser.add_mutually_exclusive_group() 268 | group.add_argument( 269 | "--gpu_ids", 270 | type=int, 271 | nargs="+", 272 | help="A list of space-separated gpu ids to run the model on. " 273 | "The model will span across GPUs in tensor-parallel mode.", 274 | ) 275 | group.add_argument( 276 | "--n_gpus", 277 | type=int, 278 | default=1, 279 | help="Number of GPUs to run the model on. Equivalent to " "--gpu_ids 0 1 2 ... 
n-1", 280 | ) 281 | parser.add_argument("--pretrained_path", type=str, required=True, help="Path to the model checkpoints.") 282 | parser.add_argument( 283 | "--precision", 284 | type=str, 285 | choices=["fp16", "bf16"], 286 | default="bf16", 287 | help="The dtype used for model weights and inference.", 288 | ) 289 | parser.add_argument( 290 | "--target_size", type=int, default=768, choices=[512, 768, 1024], help="The target image generation size." 291 | ) 292 | args = parser.parse_args() 293 | 294 | # check and setup gpu_ids to use 295 | if args.gpu_ids is None: 296 | if args.n_gpus is None: 297 | args.n_gpus = 1 298 | assert args.n_gpus > 0, "The demo currently must run on a positive number of GPUs." 299 | args.gpu_ids = list(range(args.n_gpus)) 300 | 301 | assert len(args.gpu_ids) == 1, "Currently only supports running on a single GPU." 302 | 303 | # using the default "fork" method messes up some imported libs (e.g., 304 | # pandas) 305 | mp.set_start_method("spawn") 306 | 307 | # setup the queues and start the model workers 308 | request_queues = [] 309 | response_queue = mp.Queue() 310 | worker_processes = [] 311 | barrier = mp.Barrier(len(args.gpu_ids) + 1) 312 | for rank, gpu_id in enumerate(args.gpu_ids): 313 | request_queue = mp.Queue() 314 | rank_response_queue = response_queue if rank == 0 else None 315 | process = mp.Process( 316 | target=model_worker, 317 | args=(rank, args, barrier, request_queue, rank_response_queue), 318 | ) 319 | process.start() 320 | worker_processes.append(process) 321 | request_queues.append(request_queue) 322 | 323 | gradio_worker(request_queues, response_queue, args, barrier) 324 | -------------------------------------------------------------------------------- /lumina_mgpt/demos/demo_image2image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0]) 5 | 6 | import argparse 7 | import builtins 8 | import datetime 9 | import multiprocessing as mp 10 | import traceback 11 | from typing import List, Optional 12 | 13 | import gradio as gr 14 | import torch 15 | 16 | from inference_solver import FlexARInferenceSolver 17 | from xllmx.util.misc import random_seed 18 | 19 | 20 | class Ready: 21 | pass 22 | 23 | 24 | class ModelFailure: 25 | pass 26 | 27 | 28 | def model_worker( 29 | rank: int, 30 | args: argparse.Namespace, 31 | barrier: mp.Barrier, 32 | request_queue: mp.Queue, 33 | response_queue: Optional[mp.Queue] = None, 34 | ) -> None: 35 | """ 36 | The worker function that manipulates the GPU to run the inference. 37 | Exact n_gpu workers are started, with each one operating on a separate GPU. 38 | 39 | Args: 40 | rank (int): Distributed rank of the worker. 41 | args (argparse.Namespace): All command line arguments. 42 | barrier (multiprocessing.Barrier): A barrier used to delay the start 43 | of Web UI to be after the start of the model. 
44 | """ 45 | 46 | builtin_print = builtins.print 47 | 48 | def print(*args, **kwargs): 49 | kwargs["flush"] = True 50 | now = datetime.datetime.now().time() 51 | builtin_print("[{}] ".format(now), end="") # print with time stamp 52 | builtin_print(*args, **kwargs) 53 | 54 | builtins.print = print 55 | 56 | world_size = len(args.gpu_ids) 57 | gpu_id = args.gpu_ids[rank] 58 | # dist.init_process_group( 59 | # backend="nccl", rank=rank, world_size=world_size, 60 | # init_method=f"tcp://{args.master_addr}:{args.master_port}", 61 | # ) 62 | # print(f"| distributed init on worker {rank}/{world_size}. " 63 | # f"using gpu: {gpu_id}") 64 | torch.cuda.set_device(gpu_id) 65 | 66 | inference_solver = FlexARInferenceSolver( 67 | model_path=args.pretrained_path, precision=args.precision, target_size=args.target_size 68 | ) 69 | 70 | barrier.wait() 71 | 72 | while True: 73 | if response_queue is not None: 74 | response_queue.put(Ready()) 75 | try: 76 | prompt, input_image, task, seed, gen_t, cfg, image_top_k = request_queue.get() 77 | 78 | random_seed(seed=seed) 79 | 80 | if task == "Semantic Segmentation": 81 | prompt = "Segmentic segmentation." 82 | elif task == "Depth Estimation": 83 | prompt = "Depth estimation." 84 | elif task == "Surface Norm Estimation": 85 | prompt = "Surface normal estimation." 86 | elif task == "Human Pose Estimation": 87 | prompt = "Human pose estimation." 88 | elif task == "Detection": 89 | prompt = "Detect: " + prompt 90 | elif task == "Editing": 91 | pass 92 | elif task == "Condition to Image": 93 | prompt = f"Generate an image according to the provided image, and according to the following caption:\n{prompt}" # noqa 94 | 95 | prompt = prompt + " <|image|>" 96 | print(prompt) 97 | 98 | generated = inference_solver.generate( 99 | [input_image], 100 | [[prompt, None]], 101 | 5000, 102 | gen_t, 103 | logits_processor=inference_solver.create_logits_processor( 104 | cfg=cfg, text_top_k=5, image_top_k=image_top_k 105 | ), 106 | ) 107 | 108 | stream_response = {"text": generated[0], "image": generated[1], "prompt": prompt, "end_of_content": True} 109 | 110 | print(generated[1]) 111 | 112 | if response_queue is not None: 113 | print("here") 114 | response_queue.put(stream_response) 115 | 116 | except Exception: 117 | print(traceback.format_exc()) 118 | response_queue.put(ModelFailure()) 119 | 120 | 121 | def gradio_worker( 122 | request_queues: List[mp.Queue], 123 | response_queue: mp.Queue, 124 | args: argparse.Namespace, 125 | barrier: mp.Barrier, 126 | ) -> None: 127 | """ 128 | The gradio worker is responsible for displaying the WebUI and relay the 129 | requests to model workers. It should be launched only once. 130 | 131 | Args: 132 | request_queues (List[mp.Queue]): A list of request queues (one for 133 | each model worker). 134 | args (argparse.Namespace): All command line arguments. 135 | barrier (multiprocessing.Barrier): A barrier used to delay the start 136 | of Web UI to be after the start of the model. 
137 | """ 138 | 139 | def check_input_sanity(text_input: str): 140 | if len(text_input) > 512: 141 | raise gr.Error("please do not send more than 1024 characters to this demo") 142 | if text_input.count("<|image|>") != 0: 143 | raise gr.Error("please do not send <|image|> tokens to this demo") 144 | 145 | def updateUIForTask(task): 146 | if task == "Detection": 147 | return gr.update(label="Object to Detect", visible=True) 148 | elif task == "Editing": 149 | return gr.update(label="Editing Prompt", visible=True) 150 | elif task == "Condition to Image": 151 | return gr.update(label="Image Prompt", visible=True) 152 | else: 153 | return gr.update(visible=False, value="") 154 | 155 | def stream_model_output(prompt, input_image, task, seed, gen_t, cfg, image_top_k): 156 | 157 | while True: 158 | content_piece = response_queue.get() 159 | if isinstance(content_piece, Ready): 160 | break 161 | for queue in request_queues: 162 | queue.put((prompt, input_image, task, seed, gen_t, cfg, image_top_k)) 163 | while True: 164 | content_piece = response_queue.get() 165 | if isinstance(content_piece, ModelFailure): 166 | raise RuntimeError 167 | if content_piece["end_of_content"]: 168 | yield content_piece["image"][0], content_piece["prompt"] 169 | break 170 | else: 171 | yield None, None 172 | 173 | with gr.Blocks(css="#image_input {height: 100% !important}") as demo: 174 | gr.Markdown("# Lumina-mGPT Image2Image Demo\n") 175 | with gr.Row() as r: 176 | with gr.Column(scale=1): 177 | with gr.Row(): 178 | input_image = gr.Image(type="pil", label="Input Image", interactive=True, elem_id="image_input") 179 | with gr.Row(): 180 | l_tasks = [ 181 | "Semantic Segmentation", 182 | "Depth Estimation", 183 | "Surface Norm Estimation", 184 | "Detection", 185 | "Human Pose Estimation", 186 | "Editing", 187 | "Condition to Image", 188 | ] 189 | task = gr.Dropdown(value=f"Semantic Segmentation", choices=l_tasks, label="tasks") 190 | with gr.Row(): 191 | prompt = gr.Textbox(visible=False) 192 | with gr.Row(): 193 | with gr.Column(scale=1): 194 | seed = gr.Slider( 195 | minimum=0, 196 | maximum=int(1e5), 197 | value=1, 198 | step=1, 199 | interactive=True, 200 | label="Seed (0 for random)", 201 | ) 202 | with gr.Column(scale=1): 203 | gen_t = gr.Slider( 204 | minimum=0.0, 205 | maximum=4.0, 206 | value=1.0, 207 | interactive=True, 208 | label="gen_t", 209 | ) 210 | with gr.Row(): 211 | with gr.Column(scale=1): 212 | cfg = gr.Slider( 213 | minimum=0.0, 214 | maximum=16.0, 215 | value=1.0, 216 | interactive=True, 217 | label="cfg", 218 | ) 219 | with gr.Column(scale=1): 220 | image_top_k = gr.Slider( 221 | minimum=0, 222 | maximum=8192, 223 | value=200, 224 | interactive=True, 225 | label="Image Top-k", 226 | ) 227 | with gr.Row(): 228 | submit_button = gr.Button("Submit", variant="primary") 229 | 230 | with gr.Column(): 231 | output_img = gr.Image( 232 | label="Generated image", 233 | interactive=False, 234 | ) 235 | real_prompt = gr.Textbox( 236 | label="Real Prompt", 237 | interactive=False, 238 | visible=False, 239 | lines=5, 240 | show_label=True, 241 | show_copy_button=True, 242 | ) 243 | 244 | task.change(updateUIForTask, task, prompt) 245 | 246 | submit_button.click(check_input_sanity, [prompt], []).success( 247 | stream_model_output, [prompt, input_image, task, seed, gen_t, cfg, image_top_k], [output_img, real_prompt] 248 | ).success(lambda: gr.update(visible=True), [], [real_prompt]) 249 | barrier.wait() 250 | demo.queue(api_open=True).launch( 251 | share=True, 252 | server_name="0.0.0.0", 253 | ) 254 | 255 | 
256 | if __name__ == "__main__": 257 | parser = argparse.ArgumentParser("X-LLM-X Chat Demo") 258 | group = parser.add_mutually_exclusive_group() 259 | group.add_argument( 260 | "--gpu_ids", 261 | type=int, 262 | nargs="+", 263 | help="A list of space-separated gpu ids to run the model on. " 264 | "The model will span across GPUs in tensor-parallel mode.", 265 | ) 266 | group.add_argument( 267 | "--n_gpus", 268 | type=int, 269 | default=1, 270 | help="Number of GPUs to run the model on. Equivalent to " "--gpu_ids 0 1 2 ... n-1", 271 | ) 272 | parser.add_argument("--pretrained_path", type=str, required=True, help="Path to the model checkpoints.") 273 | parser.add_argument( 274 | "--precision", 275 | type=str, 276 | choices=["fp16", "bf16"], 277 | default="bf16", 278 | help="The dtype used for model weights and inference.", 279 | ) 280 | parser.add_argument( 281 | "--target_size", type=int, default=768, choices=[512, 768, 1024], help="The target image generation size." 282 | ) 283 | args = parser.parse_args() 284 | 285 | # check and setup gpu_ids to use 286 | if args.gpu_ids is None: 287 | if args.n_gpus is None: 288 | args.n_gpus = 1 289 | assert args.n_gpus > 0, "The demo currently must run on a positive number of GPUs." 290 | args.gpu_ids = list(range(args.n_gpus)) 291 | 292 | assert len(args.gpu_ids) == 1, "Currently only supports running on a single GPU." 293 | 294 | # using the default "fork" method messes up some imported libs (e.g., 295 | # pandas) 296 | mp.set_start_method("spawn") 297 | 298 | # setup the queues and start the model workers 299 | request_queues = [] 300 | response_queue = mp.Queue() 301 | worker_processes = [] 302 | barrier = mp.Barrier(len(args.gpu_ids) + 1) 303 | for rank, gpu_id in enumerate(args.gpu_ids): 304 | request_queue = mp.Queue() 305 | rank_response_queue = response_queue if rank == 0 else None 306 | process = mp.Process( 307 | target=model_worker, 308 | args=(rank, args, barrier, request_queue, rank_response_queue), 309 | ) 310 | process.start() 311 | worker_processes.append(process) 312 | request_queues.append(request_queue) 313 | 314 | gradio_worker(request_queues, response_queue, args, barrier) 315 | -------------------------------------------------------------------------------- /lumina_mgpt/demos/demo_image_generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0]) 5 | 6 | import argparse 7 | import builtins 8 | import datetime 9 | import multiprocessing as mp 10 | import traceback 11 | from typing import List, Optional 12 | 13 | import gradio as gr 14 | import torch 15 | 16 | from data.item_processor import generate_crop_size_list 17 | from inference_solver import FlexARInferenceSolver 18 | from xllmx.util.misc import random_seed 19 | 20 | 21 | class Ready: 22 | pass 23 | 24 | 25 | class ModelFailure: 26 | pass 27 | 28 | 29 | @torch.no_grad() 30 | def model_worker( 31 | rank: int, 32 | args: argparse.Namespace, 33 | barrier: mp.Barrier, 34 | request_queue: mp.Queue, 35 | response_queue: Optional[mp.Queue] = None, 36 | ) -> None: 37 | """ 38 | The worker function that manipulates the GPU to run the inference. 39 | Exact n_gpu workers are started, with each one operating on a separate GPU. 40 | 41 | Args: 42 | rank (int): Distributed rank of the worker. 43 | args (argparse.Namespace): All command line arguments. 
44 | barrier (multiprocessing.Barrier): A barrier used to delay the start 45 | of Web UI to be after the start of the model. 46 | """ 47 | 48 | builtin_print = builtins.print 49 | 50 | def print(*args, **kwargs): 51 | kwargs["flush"] = True 52 | now = datetime.datetime.now().time() 53 | builtin_print("[{}] ".format(now), end="") # print with time stamp 54 | builtin_print(*args, **kwargs) 55 | 56 | builtins.print = print 57 | 58 | world_size = len(args.gpu_ids) 59 | gpu_id = args.gpu_ids[rank] 60 | # dist.init_process_group( 61 | # backend="nccl", rank=rank, world_size=world_size, 62 | # init_method=f"tcp://{args.master_addr}:{args.master_port}", 63 | # ) 64 | # print(f"| distributed init on worker {rank}/{world_size}. " 65 | # f"using gpu: {gpu_id}") 66 | torch.cuda.set_device(gpu_id) 67 | 68 | inference_solver = FlexARInferenceSolver( 69 | model_path=args.pretrained_path, 70 | precision=args.precision, 71 | ) 72 | 73 | barrier.wait() 74 | 75 | while True: 76 | if response_queue is not None: 77 | response_queue.put(Ready()) 78 | try: 79 | prompt, resolution, seed, gen_t, cfg, image_top_k = request_queue.get() 80 | 81 | random_seed(seed=seed) 82 | 83 | prompt = f"Generate an image of {resolution} according to the following prompt:\n{prompt}" 84 | print(prompt) 85 | 86 | generated = inference_solver.generate( 87 | [], 88 | [[prompt, None]], 89 | 5000, 90 | gen_t, 91 | logits_processor=inference_solver.create_logits_processor( 92 | cfg=cfg, text_top_k=5, image_top_k=image_top_k 93 | ), 94 | ) 95 | 96 | print("*" * 100) 97 | print(generated[1]) 98 | 99 | stream_response = {"text": generated[0], "image": generated[1], "prompt": prompt, "end_of_content": True} 100 | 101 | print(generated[1]) 102 | 103 | if response_queue is not None: 104 | print("here") 105 | response_queue.put(stream_response) 106 | 107 | except Exception: 108 | print(traceback.format_exc()) 109 | response_queue.put(ModelFailure()) 110 | 111 | 112 | def gradio_worker( 113 | request_queues: List[mp.Queue], 114 | response_queue: mp.Queue, 115 | args: argparse.Namespace, 116 | barrier: mp.Barrier, 117 | ) -> None: 118 | """ 119 | The gradio worker is responsible for displaying the WebUI and relay the 120 | requests to model workers. It should be launched only once. 121 | 122 | Args: 123 | request_queues (List[mp.Queue]): A list of request queues (one for 124 | each model worker). 125 | args (argparse.Namespace): All command line arguments. 126 | barrier (multiprocessing.Barrier): A barrier used to delay the start 127 | of Web UI to be after the start of the model. 
128 | """ 129 | 130 | def check_input_sanity(text_input: str): 131 | if len(text_input) > 1024: 132 | raise gr.Error("please do not send more than 1024 characters to this demo") 133 | if text_input.count("<|image|>") != 0: 134 | raise gr.Error("please do not send <|image|> tokens to this demo") 135 | 136 | def stream_model_output(prompt, resolution, seed, gen_t, cfg, image_top_k): 137 | 138 | while True: 139 | content_piece = response_queue.get() 140 | if isinstance(content_piece, Ready): 141 | break 142 | for queue in request_queues: 143 | queue.put((prompt, resolution, seed, gen_t, cfg, image_top_k)) 144 | while True: 145 | content_piece = response_queue.get() 146 | if isinstance(content_piece, ModelFailure): 147 | raise RuntimeError 148 | if content_piece["end_of_content"]: 149 | yield content_piece["image"][0], content_piece["prompt"] 150 | break 151 | else: 152 | yield None, None 153 | 154 | def show_real_prompt(): 155 | return gr.update(visible=True) 156 | 157 | with gr.Blocks(css="#image_input {height: 100% !important}") as demo: 158 | gr.Markdown("# Lumina-mGPT Image Generation Demo\n") 159 | with gr.Row() as r: 160 | with gr.Column(scale=1): 161 | prompt = gr.Textbox(lines=3, interactive=True, label="Prompt") 162 | with gr.Row(): 163 | patch_size = 32 164 | res_choices = generate_crop_size_list((args.target_size // patch_size) ** 2, patch_size) 165 | res_choices = [f"{w}x{h}" for w, h in res_choices] 166 | assert f"{args.target_size}x{args.target_size}" in res_choices 167 | resolution = gr.Dropdown( 168 | value=f"{args.target_size}x{args.target_size}", choices=res_choices, label="Resolution" 169 | ) 170 | with gr.Row(): 171 | with gr.Column(scale=1): 172 | seed = gr.Slider( 173 | minimum=0, 174 | maximum=int(1e5), 175 | value=300, 176 | step=1, 177 | interactive=True, 178 | label="Seed (0 for random)", 179 | ) 180 | with gr.Column(scale=1): 181 | gen_t = gr.Slider( 182 | minimum=0.0, 183 | maximum=4.0, 184 | value=1.0, 185 | interactive=True, 186 | label="gen_t", 187 | ) 188 | with gr.Row(): 189 | with gr.Column(scale=1): 190 | cfg = gr.Slider( 191 | minimum=0.0, 192 | maximum=16.0, 193 | value=3.0, 194 | interactive=True, 195 | label="cfg", 196 | ) 197 | with gr.Column(scale=1): 198 | image_top_k = gr.Slider( 199 | minimum=0, 200 | maximum=8192, 201 | value=4000, 202 | interactive=True, 203 | label="Image Top-k", 204 | ) 205 | submit_button = gr.Button("Submit", variant="primary") 206 | 207 | with gr.Column(): 208 | output_img = gr.Image( 209 | label="Generated image", 210 | interactive=False, 211 | ) 212 | real_prompt = gr.Textbox( 213 | label="Real Prompt", interactive=False, visible=False, show_label=True, show_copy_button=True 214 | ) 215 | 216 | prompt.submit(check_input_sanity, [prompt], []).success( 217 | stream_model_output, [prompt, resolution, seed, gen_t, cfg, image_top_k], [output_img, real_prompt] 218 | ) 219 | submit_button.click(check_input_sanity, [prompt], []).success( 220 | stream_model_output, [prompt, resolution, seed, gen_t, cfg, image_top_k], [output_img, real_prompt] 221 | ).success(show_real_prompt, [], [real_prompt]) 222 | barrier.wait() 223 | demo.queue(api_open=True).launch( 224 | share=True, 225 | server_name="0.0.0.0", 226 | ) 227 | 228 | 229 | if __name__ == "__main__": 230 | parser = argparse.ArgumentParser("X-LLM-X Chat Demo") 231 | group = parser.add_mutually_exclusive_group() 232 | group.add_argument( 233 | "--gpu_ids", 234 | type=int, 235 | nargs="+", 236 | help="A list of space-separated gpu ids to run the model on. 
" 237 | "The model will span across GPUs in tensor-parallel mode.", 238 | ) 239 | group.add_argument( 240 | "--n_gpus", 241 | type=int, 242 | default=1, 243 | help="Number of GPUs to run the model on. Equivalent to " "--gpu_ids 0 1 2 ... n-1", 244 | ) 245 | parser.add_argument("--pretrained_path", type=str, required=True, help="Path to the model checkpoints.") 246 | parser.add_argument( 247 | "--precision", 248 | type=str, 249 | choices=["fp16", "bf16"], 250 | default="bf16", 251 | help="The dtype used for model weights and inference.", 252 | ) 253 | parser.add_argument( 254 | "--target_size", type=int, default=768, choices=[512, 768, 1024], help="The target image generation size." 255 | ) 256 | args = parser.parse_args() 257 | 258 | # check and setup gpu_ids to use 259 | if args.gpu_ids is None: 260 | if args.n_gpus is None: 261 | args.n_gpus = 1 262 | assert args.n_gpus > 0, "The demo currently must run on a positive number of GPUs." 263 | args.gpu_ids = list(range(args.n_gpus)) 264 | 265 | assert len(args.gpu_ids) == 1, "Currently only supports running on a single GPU." 266 | 267 | # using the default "fork" method messes up some imported libs (e.g., 268 | # pandas) 269 | mp.set_start_method("spawn") 270 | 271 | # setup the queues and start the model workers 272 | request_queues = [] 273 | response_queue = mp.Queue() 274 | worker_processes = [] 275 | barrier = mp.Barrier(len(args.gpu_ids) + 1) 276 | for rank, gpu_id in enumerate(args.gpu_ids): 277 | request_queue = mp.Queue() 278 | rank_response_queue = response_queue if rank == 0 else None 279 | process = mp.Process( 280 | target=model_worker, 281 | args=(rank, args, barrier, request_queue, rank_response_queue), 282 | ) 283 | process.start() 284 | worker_processes.append(process) 285 | request_queues.append(request_queue) 286 | 287 | gradio_worker(request_queues, response_queue, args, barrier) 288 | -------------------------------------------------------------------------------- /lumina_mgpt/exps/7B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lr=2e-5 4 | wd=0.1 5 | dropout=0.05 6 | z_loss_weight=1e-5 7 | 8 | data_config=configs/data/sample.yaml 9 | 10 | exp_name=7B 11 | mkdir -p output/"$exp_name" 12 | 13 | 14 | python -u finetune_solver.py \ 15 | --model_size 7B \ 16 | --batch_size 8 \ 17 | --accum_iter 1 \ 18 | --epochs 2 \ 19 | --warmup_epochs 0.01 \ 20 | --lr ${lr} \ 21 | --min_lr ${lr} \ 22 | --wd ${wd} \ 23 | --clip_grad 4 \ 24 | --data_config $data_config \ 25 | --cache_ann_on_disk \ 26 | --num_workers 8 \ 27 | --output_dir output/"$exp_name" \ 28 | --save_iteration_interval 1000 \ 29 | --checkpointing \ 30 | --max_seq_len 4096 \ 31 | --unmask_image_logits \ 32 | --dropout ${dropout} \ 33 | --z_loss_weight ${z_loss_weight} \ 34 | 2>&1 | tee -a output/"$exp_name"/output.log 35 | 36 | echo "exp name: $exp_name" 37 | -------------------------------------------------------------------------------- /lumina_mgpt/finetune_solver.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import List, Tuple 3 | 4 | from accelerate import init_empty_weights 5 | import torch 6 | 7 | from model import ChameleonXLLMXConfig, ChameleonXLLMXForConditionalGeneration 8 | from xllmx.data.item_processor import ItemProcessorBase 9 | from xllmx.solvers.finetune import FinetuneSolverBase 10 | 11 | 12 | class ItemProcessor(ItemProcessorBase): 13 | def process_item(self, data_item: dict, training_mode=False) -> Tuple[List, List]: 
14 | assert training_mode 15 | 16 | if "token" in data_item and "label" in data_item: 17 | data_item = data_item 18 | else: 19 | assert "file" in data_item 20 | with open(data_item["file"], "rb") as f: 21 | data_item = pickle.load(f) 22 | 23 | tokens = data_item["token"] 24 | labels = data_item["label"] 25 | assert len(tokens) == len(labels) 26 | 27 | return tokens, labels 28 | 29 | def predict_item_token_length(self, data_item: dict) -> int: 30 | if "token" in data_item: 31 | return len(data_item["token"]) 32 | elif "len" in data_item: 33 | return data_item["len"] 34 | else: 35 | raise ValueError() 36 | 37 | 38 | class Solver(FinetuneSolverBase): 39 | @classmethod 40 | def get_args_parser(cls): 41 | parser = super().get_args_parser() 42 | # task-specific parameters 43 | parser.add_argument("--max_seq_len", default=4096, type=int, help="max token length") 44 | parser.add_argument("--mask_image_logits", default=True) 45 | parser.add_argument("--unmask_image_logits", action="store_false", dest="mask_image_logits") 46 | parser.add_argument("--dropout", type=float, default=0.0) 47 | parser.add_argument("--z_loss_weight", type=float, default=0.0) 48 | parser.add_argument("--model_size", type=str, default="7B", choices=["7B", "34B"]) 49 | return parser 50 | 51 | def _model_func( 52 | self, 53 | init_from: str, 54 | ) -> (ChameleonXLLMXForConditionalGeneration, None): 55 | 56 | # Only instantiate the model on rank0 57 | # Other ranks will receive the model weights from rank0 during FSDP wrapping (through `sync_module_states`) 58 | # See https://github.com/pytorch/pytorch/issues/105840 59 | if self.dp_rank == 0: 60 | model = ChameleonXLLMXForConditionalGeneration.from_pretrained( 61 | init_from, 62 | max_position_embeddings=self.args.max_seq_len, 63 | mask_image_logits=self.args.mask_image_logits, 64 | dropout=self.args.dropout, 65 | z_loss_weight=self.args.z_loss_weight, 66 | torch_dtype=torch.bfloat16, 67 | device_map="cpu", 68 | ) 69 | else: 70 | with init_empty_weights(): 71 | config = ChameleonXLLMXConfig.from_pretrained( 72 | init_from, 73 | max_position_embeddings=self.args.max_seq_len, 74 | mask_image_logits=self.args.mask_image_logits, 75 | dropout=self.args.dropout, 76 | z_loss_weight=self.args.z_loss_weight, 77 | torch_dtype=torch.bfloat16, 78 | ) 79 | model = ChameleonXLLMXForConditionalGeneration(config) 80 | 81 | del model.model.vqmodel 82 | 83 | return model, None 84 | 85 | def _item_processor_func(self) -> ItemProcessorBase: 86 | return ItemProcessor() 87 | 88 | def _make_and_save_starting_point(self, save_path: str) -> None: 89 | 90 | pretrained_name = { 91 | "7B": "Alpha-VLLM/Chameleon_7B_mGPT", 92 | "34B": "Alpha-VLLM/Chameleon_34B_mGPT", 93 | }[self.args.model_size] 94 | 95 | model = ChameleonXLLMXForConditionalGeneration.from_pretrained( 96 | pretrained_name, 97 | max_position_embeddings=self.args.max_seq_len, 98 | mask_image_logits=self.args.mask_image_logits, 99 | dropout=self.args.dropout, 100 | z_loss_weight=self.args.z_loss_weight, 101 | torch_dtype=torch.bfloat16, 102 | device_map="cpu", 103 | ) 104 | 105 | image_tokens = model.model.vocabulary_mapping.image_tokens 106 | model.lm_head.weight.data[image_tokens] = torch.zeros_like(model.lm_head.weight.data[image_tokens]) 107 | 108 | model.save_pretrained(save_path, max_shard_size="10GB") 109 | 110 | 111 | if __name__ == "__main__": 112 | args = Solver.get_args_parser().parse_args() 113 | solver = Solver(args) 114 | solver.run() 115 | -------------------------------------------------------------------------------- 
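# ----------------------------------------------------------------------------
# Editor's sketch (hedged; not part of the repository): the pre-tokenized item
# layout accepted by ItemProcessor.process_item and predict_item_token_length
# in finetune_solver.py above. The keys "token", "label", "file" and "len" come
# from that code; the concrete values and the path are illustrative only.

import pickle

# An item stored inline: token and label lists of equal length, as asserted in
# process_item.
inline_item = {"token": [101, 2054, 2003], "label": [-100, 2054, 2003]}

# The same item cached on disk and referenced by a lightweight record; "len"
# lets predict_item_token_length report the sequence length without loading
# the pickle.
with open("/tmp/example_item.pkl", "wb") as f:
    pickle.dump(inline_item, f)

on_disk_item = {"file": "/tmp/example_item.pkl", "len": len(inline_item["token"])}
# ----------------------------------------------------------------------------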
/lumina_mgpt/generate_examples/generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0]) 5 | import argparse 6 | 7 | from PIL import Image 8 | import torch 9 | 10 | from inference_solver import FlexARInferenceSolver 11 | from xllmx.util.misc import random_seed 12 | 13 | if __name__ == "__main__": 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model_path", type=str, required=True) 17 | parser.add_argument("--save_path", type=str, required=True) 18 | parser.add_argument("--temperature", type=float) 19 | parser.add_argument("--top_k", type=int) 20 | parser.add_argument("--cfg", type=float) 21 | parser.add_argument("-n", type=int, default=5) 22 | parser.add_argument("--width", type=int, default=512) 23 | parser.add_argument("--height", type=int, default=512) 24 | 25 | args = parser.parse_args() 26 | 27 | print("args:\n", args) 28 | 29 | select_set1 = [ 30 | "Image of a dog playing water, and a water fall is in the background.", 31 | "A family of asian people sitting around the dinner table, eating and laughing.", 32 | "A high-resolution photograph of a middle-aged woman with curly hair, wearing traditional Japanese kimono, smiling gently under a cherry blossom tree in full bloom.", # noqa 33 | "Image of a bustling downtown street in Tokyo at night, with neon signs, crowded sidewalks, and tall skyscrapers.", # noqa 34 | "Image of a quiet European village with cobblestone streets and colorful houses, under a clear blue sky.", 35 | ] 36 | 37 | l_prompts = select_set1 38 | 39 | t = args.temperature 40 | top_k = args.top_k 41 | cfg = args.cfg 42 | n = args.n 43 | w, h = args.width, args.height 44 | 45 | inference_solver = FlexARInferenceSolver( 46 | model_path=args.model_path, 47 | precision="bf16", 48 | ) 49 | 50 | with torch.no_grad(): 51 | l_generated_all = [] 52 | for i, prompt in enumerate(l_prompts): 53 | for repeat_idx in range(n): 54 | random_seed(repeat_idx) 55 | generated = inference_solver.generate( 56 | images=[], 57 | qas=[[f"Generate an image of {w}x{h} according to the following prompt:\n{prompt}", None]], 58 | max_gen_len=8192, 59 | temperature=t, 60 | logits_processor=inference_solver.create_logits_processor(cfg=cfg, image_top_k=top_k), 61 | ) 62 | try: 63 | l_generated_all.append(generated[1][0]) 64 | except: 65 | l_generated_all.append(Image.new("RGB", (w, h))) 66 | 67 | result_image = inference_solver.create_image_grid(l_generated_all, len(l_prompts), n) 68 | result_image.save(args.save_path) 69 | -------------------------------------------------------------------------------- /lumina_mgpt/inference_solver.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import math 4 | from typing import List, Optional, Union 5 | 6 | from PIL import Image 7 | import torch 8 | import transformers 9 | from transformers import GenerationConfig, TextStreamer 10 | from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList, LogitsWarper 11 | 12 | from data.item_processor import FlexARItemProcessor 13 | from model.chameleon import ChameleonForConditionalGeneration 14 | 15 | 16 | class LLMImageStartTriggeredUnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): 17 | r""" 18 | Logits processor for Classifier-Free Guidance (CFG). 
The processors computes a weighted average across scores 19 | from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`. 20 | The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch. 21 | 22 | See [the paper](https://arxiv.org/abs/2306.17806) for more information. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | guidance_scale: float, 28 | model, 29 | image_start_token_id, 30 | image_end_token_id, 31 | image_next_line_token_id, 32 | patch_size, 33 | unconditional_ids: Optional[torch.LongTensor] = None, 34 | unconditional_attention_mask: Optional[torch.LongTensor] = None, 35 | use_cache: Optional[bool] = True, 36 | ): 37 | self.guidance_scale = guidance_scale 38 | self.model = model 39 | self.unconditional_context_backup = { 40 | "input_ids": unconditional_ids, 41 | "attention_mask": unconditional_attention_mask, 42 | "use_cache": use_cache, 43 | "past_key_values": transformers.DynamicCache() if use_cache else None, 44 | "first_pass": True, 45 | } 46 | self.unconditional_context = None 47 | 48 | self.nums_image_start_tokens = None 49 | 50 | self.image_start_token_id = image_start_token_id 51 | self.image_end_token_id = image_end_token_id 52 | self.image_next_line_token_id = image_next_line_token_id 53 | self.image_start_token_id_index = None 54 | self.patch_size = patch_size 55 | self.h_latent_dim = None 56 | self.w_latent_dim = None 57 | 58 | def get_unconditional_logits(self, input_ids, image_start_token_id_index): 59 | 60 | if self.unconditional_context["first_pass"]: 61 | if self.unconditional_context["input_ids"] is None: 62 | self.unconditional_context["input_ids"] = input_ids[:, image_start_token_id_index:] 63 | if self.unconditional_context["attention_mask"] is None: 64 | self.unconditional_context["attention_mask"] = torch.ones_like( 65 | self.unconditional_context["input_ids"], dtype=torch.long 66 | ) 67 | input_ids = self.unconditional_context["input_ids"] 68 | attention_mask = self.unconditional_context["attention_mask"] 69 | self.unconditional_context["first_pass"] = False 70 | else: 71 | attention_mask = torch.cat( 72 | [ 73 | self.unconditional_context["attention_mask"], 74 | torch.ones_like(input_ids[:, -1:], dtype=torch.long), 75 | ], 76 | dim=1, 77 | ) 78 | if not self.unconditional_context["use_cache"]: 79 | input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1) 80 | else: 81 | input_ids = input_ids[:, -1:] 82 | self.unconditional_context["input_ids"] = input_ids 83 | self.unconditional_context["attention_mask"] = attention_mask 84 | 85 | out = self.model( 86 | input_ids, 87 | attention_mask=attention_mask, 88 | use_cache=self.unconditional_context["use_cache"], 89 | past_key_values=self.unconditional_context["past_key_values"], 90 | ) 91 | self.unconditional_context["past_key_values"] = out.get("past_key_values", None) 92 | 93 | return out.logits 94 | 95 | def __call__(self, input_ids, scores): 96 | num_image_start_tokens = (input_ids[0] == self.image_start_token_id).sum() 97 | num_image_end_tokens = (input_ids[0] == self.image_end_token_id).sum() 98 | 99 | if num_image_start_tokens == num_image_end_tokens: 100 | self.h_latent_dim, self.w_latent_dim = None, None 101 | self.image_start_token_id_index = None 102 | self.unconditional_context = None 103 | return scores 104 | 105 | elif num_image_start_tokens == num_image_end_tokens + 1: 106 | if self.image_start_token_id_index is None: 107 | self.image_start_token_id_index = 
torch.where(input_ids[0] == self.image_start_token_id)[0][-1].item() 108 | new_token_num = len(input_ids[0][self.image_start_token_id_index + 1 :]) 109 | if new_token_num >= 2: 110 | if self.h_latent_dim is None or self.w_latent_dim is None: 111 | h_grids, w_grids = ( 112 | input_ids[0][self.image_start_token_id_index + 1] - 8804, 113 | input_ids[0][self.image_start_token_id_index + 2] - 8804, 114 | ) 115 | self.h_latent_dim, self.w_latent_dim = h_grids * 2, w_grids * 2 116 | 117 | if self.unconditional_context is None: 118 | self.unconditional_context = copy.deepcopy(self.unconditional_context_backup) 119 | 120 | if self.guidance_scale == 1.0: 121 | return scores 122 | 123 | unconditional_logits = self.get_unconditional_logits(input_ids, self.image_start_token_id_index)[:, -1] 124 | 125 | scores_processed = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits 126 | return scores_processed 127 | 128 | else: 129 | print("Something wrong in the decoding process.") 130 | 131 | return scores 132 | 133 | 134 | class MultiModalLogitsProcessor(LogitsProcessor): 135 | 136 | def __init__( 137 | self, 138 | image_start_token_id=None, 139 | image_end_token_id=None, 140 | image_next_line_token_id=None, 141 | patch_size=None, 142 | voc_size=None, 143 | ): 144 | self.image_start_token_id = image_start_token_id 145 | self.image_end_token_id = image_end_token_id 146 | self.image_next_line_token_id = image_next_line_token_id 147 | self.image_start_token_id_index = None 148 | self.patch_size = patch_size 149 | self.h_latent_dim = None 150 | self.w_latent_dim = None 151 | 152 | self.vocab_list = [i for i in range(voc_size)] 153 | self.image_token_list = [i for i in range(4, 8195 + 1)] 154 | self.suppress_tokens = torch.tensor( 155 | [x for x in self.vocab_list if x not in self.image_token_list], device="cuda" 156 | ) 157 | 158 | self.vocab_tensor = torch.arange(voc_size, device="cuda") 159 | self.suppress_token_mask = torch.isin(self.vocab_tensor, self.suppress_tokens) 160 | self.new_line_force_token_mask = torch.isin( 161 | self.vocab_tensor, torch.tensor([self.image_next_line_token_id], device="cuda") 162 | ) 163 | self.eos_image_force_token_mask = torch.isin( 164 | self.vocab_tensor, torch.tensor([self.image_end_token_id], device="cuda") 165 | ) 166 | 167 | self.flag = False 168 | self.num_image_start_tokens = None 169 | self.num_image_end_tokens = None 170 | 171 | # @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) 172 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 173 | 174 | self.num_image_start_tokens = (input_ids[0] == self.image_start_token_id).sum() 175 | self.num_image_end_tokens = (input_ids[0] == self.image_end_token_id).sum() 176 | 177 | # print(self.num_image_start_tokens, self.num_image_end_tokens) 178 | 179 | if self.num_image_start_tokens == self.num_image_end_tokens: 180 | self.h_latent_dim, self.w_latent_dim = None, None 181 | self.image_start_token_id_index = None 182 | return scores 183 | 184 | elif self.num_image_start_tokens == self.num_image_end_tokens + 1: 185 | if self.image_start_token_id_index is None: 186 | self.image_start_token_id_index = torch.where(input_ids[0] == self.image_start_token_id)[0] 187 | print(self.image_start_token_id_index) 188 | self.image_start_token_id_index = torch.where(input_ids[0] == self.image_start_token_id)[0][-1].item() 189 | 190 | new_token_num = len(input_ids[0][self.image_start_token_id_index + 1 :]) 191 | # print(f"num new tokens: {new_token_num}") 192 | if 
new_token_num >= 2: 193 | if self.h_latent_dim is None or self.w_latent_dim is None: 194 | h_grids, w_grids = ( 195 | input_ids[0][self.image_start_token_id_index + 1] - 8804, 196 | input_ids[0][self.image_start_token_id_index + 2] - 8804, 197 | ) 198 | # print(f"h_grids: {h_grids}, w_grids: {w_grids}") 199 | self.h_latent_dim, self.w_latent_dim = h_grids * 2, w_grids * 2 200 | print(f"h_latent_dim: {self.h_latent_dim}, w_latent_dim: {self.w_latent_dim}") 201 | 202 | tokens = input_ids[0][self.image_start_token_id_index + 3 :] 203 | if (len(tokens) + 1) % (self.w_latent_dim + 1) == 0: 204 | new_line_constrained_scores = torch.full_like(scores, -math.inf) 205 | new_line_constrained_scores[:, self.image_next_line_token_id] = 0 206 | print(f"new line: {len(tokens)+1}") 207 | return new_line_constrained_scores 208 | elif (len(tokens) + 1) == (self.w_latent_dim + 1) * self.h_latent_dim + 1: 209 | eos_image_constrained_scores = torch.full_like(scores, -math.inf) 210 | eos_image_constrained_scores[:, self.image_end_token_id] = 0 211 | print(f"eos image: {len(tokens)+1}") 212 | return eos_image_constrained_scores 213 | elif (len(tokens) + 1) % (self.w_latent_dim + 1) != 0: 214 | image_constrained_scores = torch.where(self.suppress_token_mask, -float("inf"), scores) 215 | return image_constrained_scores 216 | else: 217 | print("Something wrong in the decoding process.") 218 | 219 | return scores 220 | 221 | 222 | class InterleavedTopKLogitsWarper(LogitsWarper): 223 | r""" 224 | [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together 225 | with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. 226 | """ 227 | 228 | def __init__( 229 | self, 230 | image_top_k: int, 231 | text_top_k: int, 232 | image_start_token_id=None, 233 | image_end_token_id=None, 234 | filter_value: float = -float("Inf"), 235 | min_tokens_to_keep: int = 1, 236 | ): 237 | if not isinstance(text_top_k, int) or text_top_k <= 0: 238 | raise ValueError(f"`text_top_k` has to be a strictly positive integer, but is {text_top_k}") 239 | if not isinstance(image_top_k, int) or text_top_k <= 0: 240 | raise ValueError(f"`image_top_k` has to be a strictly positive integer, but is {image_top_k}") 241 | 242 | self.image_top_k = max(image_top_k, min_tokens_to_keep) 243 | self.text_top_k = max(text_top_k, min_tokens_to_keep) 244 | self.filter_value = filter_value 245 | 246 | self.image_start_token_id = image_start_token_id 247 | self.image_end_token_id = image_end_token_id 248 | 249 | self.flag = False 250 | self.num_image_start_tokens = None 251 | self.num_image_end_tokens = None 252 | 253 | # @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) 254 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 255 | 256 | self.num_image_start_tokens = (input_ids[0] == self.image_start_token_id).sum() 257 | self.num_image_end_tokens = (input_ids[0] == self.image_end_token_id).sum() 258 | 259 | if self.num_image_start_tokens == self.num_image_end_tokens + 1: 260 | top_k = min(self.image_top_k, scores.size(-1)) 261 | else: 262 | top_k = min(self.text_top_k, scores.size(-1)) # Safety check 263 | # Remove all tokens with a probability less than the last token of the top-k 264 | indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] 265 | scores_processed = scores.masked_fill(indices_to_remove, self.filter_value) 266 | return scores_processed 267 | 268 | 269 | class FlexARInferenceSolver: 270 | @classmethod 271 | def 
get_args_parser(cls): 272 | parser = argparse.ArgumentParser("xllmx Inference", add_help=False) 273 | parser.add_argument("--model_path", type=str) 274 | parser.add_argument("--precision", type=str, choices=["fp16", "bf16", "tf32"], default="bf16") 275 | 276 | return parser 277 | 278 | def __init__(self, model_path, precision, target_size=512): 279 | self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] 280 | 281 | self.model = ChameleonForConditionalGeneration.from_pretrained( 282 | model_path, 283 | torch_dtype=self.dtype, 284 | device_map="cuda", 285 | ) 286 | self.item_processor = FlexARItemProcessor(target_size=target_size) 287 | 288 | def get_streamer(self): 289 | return TextStreamer(self.item_processor.tokenizer) 290 | 291 | @torch.no_grad() 292 | def generate( 293 | self, 294 | images: Image.Image | str | List[Union[Image.Image, str]], 295 | qas, 296 | max_gen_len, 297 | temperature, 298 | logits_processor=None, 299 | streamer=None, 300 | ): 301 | 302 | conversations = [] 303 | for q, a in qas: 304 | conversations.append( 305 | { 306 | "from": "human", 307 | "value": q, 308 | } 309 | ) 310 | conversations.append( 311 | { 312 | "from": "gpt", 313 | "value": a, 314 | } 315 | ) 316 | item = {"image": images, "conversations": conversations} 317 | 318 | _prompt = self.item_processor.process_item(item) 319 | prompt = [] 320 | for value in _prompt: 321 | if isinstance(value, int): 322 | prompt.append(value) 323 | else: 324 | prompt += value["input_ids"] 325 | prompt_len = len(prompt) 326 | prompt = torch.tensor(prompt, dtype=torch.int64, device=self.model.device).unsqueeze(0) 327 | 328 | generation_config = GenerationConfig( 329 | max_new_tokens=max_gen_len, 330 | max_length=self.model.config.max_position_embeddings, 331 | temperature=temperature, 332 | top_k=None, 333 | do_sample=True, 334 | eos_token_id=[8710], 335 | ) 336 | 337 | if logits_processor is None: 338 | logits_processor = self.create_logits_processor() 339 | 340 | with torch.cuda.amp.autocast(dtype=self.dtype): 341 | generation_result = self.model.generate( 342 | prompt, generation_config, logits_processor=logits_processor, streamer=streamer 343 | )[0][prompt_len:].tolist() 344 | if len(generation_result) > 0 and generation_result[-1] == 8710: 345 | generation_result = generation_result[:-1] 346 | 347 | return self.decode_ids(generation_result) 348 | 349 | def decode_ids(self, tokens: List[int]): 350 | generated_images = [] 351 | generation_result_processed = [] 352 | i = 0 353 | while i < len(tokens): 354 | token_id = tokens[i] 355 | if token_id == self.item_processor.token2id(self.item_processor.image_start_token): 356 | cache = [] 357 | for j in range(i + 1, len(tokens)): 358 | if tokens[j] != self.item_processor.token2id(self.item_processor.image_end_token): 359 | cache.append(tokens[j]) 360 | i = j + 1 361 | else: 362 | image = self.decode_image(cache) 363 | generated_images.append(image) 364 | generation_result_processed.append(self.item_processor.token2id("<|image|>")) 365 | i = j + 1 366 | break 367 | else: 368 | generation_result_processed.append(token_id) 369 | i += 1 370 | 371 | generated = self.item_processor.tokenizer.decode(generation_result_processed) 372 | 373 | return generated, generated_images 374 | 375 | def decode_image(self, tokens: List[int]): 376 | return self.item_processor.decode_image(tokens) 377 | 378 | @staticmethod 379 | def create_image_grid(images, rows, cols): 380 | width, height = images[0].size 381 | 382 | grid_img = Image.new("RGB", (cols * width, 
rows * height)) 383 | 384 | for i, img in enumerate(images): 385 | row = i // cols 386 | col = i % cols 387 | grid_img.paste(img, (col * width, row * height)) 388 | 389 | return grid_img 390 | 391 | def create_logits_processor(self, cfg=3.0, image_top_k=2000, text_top_k=10): 392 | logits_processor = LogitsProcessorList() 393 | 394 | cfg_processor = LLMImageStartTriggeredUnbatchedClassifierFreeGuidanceLogitsProcessor( 395 | guidance_scale=cfg, 396 | model=self.model, 397 | image_start_token_id=self.item_processor.token2id(self.item_processor.image_start_token), 398 | image_end_token_id=self.item_processor.token2id(self.item_processor.image_end_token), 399 | image_next_line_token_id=self.item_processor.token2id(self.item_processor.new_line_token), 400 | patch_size=32, 401 | ) 402 | 403 | candidate_processor = MultiModalLogitsProcessor( 404 | image_start_token_id=self.item_processor.token2id(self.item_processor.image_start_token), 405 | image_end_token_id=self.item_processor.token2id(self.item_processor.image_end_token), 406 | image_next_line_token_id=self.item_processor.token2id(self.item_processor.new_line_token), 407 | patch_size=32, 408 | voc_size=self.model.config.vocab_size, 409 | ) 410 | 411 | topk_processor = InterleavedTopKLogitsWarper( 412 | image_top_k=image_top_k, 413 | text_top_k=text_top_k, 414 | image_start_token_id=self.item_processor.token2id(self.item_processor.image_start_token), 415 | image_end_token_id=self.item_processor.token2id(self.item_processor.image_end_token), 416 | ) 417 | 418 | logits_processor.append(cfg_processor) 419 | logits_processor.append(candidate_processor) 420 | logits_processor.append(topk_processor) 421 | 422 | return logits_processor 423 | 424 | 425 | if __name__ == "__main__": 426 | parser = FlexARInferenceSolver.get_args_parser() 427 | args = parser.parse_args() 428 | solver = FlexARInferenceSolver(**vars(args)) 429 | -------------------------------------------------------------------------------- /lumina_mgpt/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_xllmx_chameleon import ChameleonXLLMXConfig 2 | from .modeling_xllmx_chameleon import ChameleonXLLMXForConditionalGeneration 3 | -------------------------------------------------------------------------------- /lumina_mgpt/model/chameleon/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available 17 | 18 | _import_structure = { 19 | "configuration_chameleon": ["ChameleonConfig", "ChameleonVQVAEConfig"], 20 | "processing_chameleon": ["ChameleonProcessor"], 21 | } 22 | 23 | 24 | try: 25 | if not is_torch_available(): 26 | raise OptionalDependencyNotAvailable() 27 | except OptionalDependencyNotAvailable: 28 | pass 29 | else: 30 | _import_structure["modeling_chameleon"] = [ 31 | "ChameleonForConditionalGeneration", 32 | "ChameleonModel", 33 | "ChameleonPreTrainedModel", 34 | "ChameleonVQVAE", 35 | ] 36 | 37 | try: 38 | if not is_vision_available(): 39 | raise OptionalDependencyNotAvailable() 40 | except OptionalDependencyNotAvailable: 41 | pass 42 | else: 43 | _import_structure["image_processing_chameleon"] = ["ChameleonImageProcessor"] 44 | 45 | 46 | if TYPE_CHECKING: 47 | from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig 48 | from .processing_chameleon import ChameleonProcessor 49 | 50 | try: 51 | if not is_torch_available(): 52 | raise OptionalDependencyNotAvailable() 53 | except OptionalDependencyNotAvailable: 54 | pass 55 | else: 56 | from .modeling_chameleon import ( 57 | ChameleonForConditionalGeneration, 58 | ChameleonModel, 59 | ChameleonPreTrainedModel, 60 | ChameleonVQVAE, 61 | ) 62 | 63 | try: 64 | if not is_vision_available(): 65 | raise OptionalDependencyNotAvailable() 66 | except OptionalDependencyNotAvailable: 67 | pass 68 | else: 69 | from .image_processing_chameleon import ChameleonImageProcessor 70 | 71 | 72 | else: 73 | import sys 74 | 75 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 76 | -------------------------------------------------------------------------------- /lumina_mgpt/model/chameleon/configuration_chameleon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """chameleon model configuration""" 16 | 17 | from typing import List 18 | 19 | from transformers.configuration_utils import PretrainedConfig 20 | from transformers.utils import logging 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | 25 | class ChameleonVQVAEConfig(PretrainedConfig): 26 | r""" 27 | This is the configuration class to store the configuration of a [`ChameleonVQModel`]. It is used to instantiate a 28 | `ChameleonVQModel` according to the specified arguments, defining the model architecture. 29 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 30 | documentation from [`PretrainedConfig`] for more information. 
Instantiating a 31 | configuration with the defaults will yield a similar configuration to the VQModel of the 32 | [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B). 33 | 34 | Args: 35 | embed_dim (`int`, *optional*, defaults to 256): 36 | Dimensionality of each embedding vector. 37 | num_embeddings (`int`, *optional*, defaults to 8192): 38 | Number of codebook embeddings. 39 | double_latent (`bool`, *optional*, defaults to `False`): 40 | Whether to use double z channels. 41 | latent_channels (`int`, *optional*, defaults to 256): 42 | Number of channels for the latent space. 43 | resolution (`int`, *optional*, defaults to 512): 44 | Resolution of the input images. 45 | in_channels (`int`, *optional*, defaults to 3): 46 | Number of input channels. 47 | base_channels (`int`, *optional*, defaults to 128): 48 | Base channel count. 49 | channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`): 50 | Channel multipliers for each resolution. 51 | num_res_blocks (`int`, *optional*, defaults to 2): 52 | Number of residual blocks. 53 | attn_resolutions (`List[int]`, *optional*): 54 | Resolutions to apply attention. 55 | dropout (`float`, *optional*, defaults to 0.0): 56 | Dropout rate. 57 | attn_type (`str`, *optional*, defaults to `"vanilla"`): 58 | Attention type used in VQ-GAN encoder. Can be "vanilla" or None. 59 | initializer_range (`float`, *optional*, defaults to 0.02): 60 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 61 | """ 62 | 63 | model_type = "chameleon_vqgan" 64 | 65 | def __init__( 66 | self, 67 | embed_dim: int = 256, 68 | num_embeddings: int = 8192, 69 | double_latent: bool = False, 70 | latent_channels: int = 256, 71 | resolution: int = 512, 72 | in_channels: int = 3, 73 | base_channels: int = 128, 74 | channel_multiplier: List[int] = [1, 1, 2, 2, 4], 75 | num_res_blocks: int = 2, 76 | attn_resolutions: List[int] = None, 77 | dropout: float = 0.0, 78 | attn_type: str = "vanilla", 79 | initializer_range=0.02, 80 | **kwargs, 81 | ): 82 | super().__init__(**kwargs) 83 | self.embed_dim = embed_dim 84 | self.num_embeddings = num_embeddings 85 | self.double_latent = double_latent 86 | self.latent_channels = latent_channels 87 | self.resolution = resolution 88 | self.in_channels = in_channels 89 | self.base_channels = base_channels 90 | self.channel_multiplier = channel_multiplier 91 | self.num_res_blocks = num_res_blocks 92 | self.attn_resolutions = attn_resolutions 93 | self.dropout = dropout 94 | self.attn_type = attn_type 95 | self.initializer_range = initializer_range 96 | 97 | 98 | class ChameleonConfig(PretrainedConfig): 99 | r""" 100 | This is the configuration class to store the configuration of a [`ChameleonModel`]. It is used to instantiate a 101 | chameleon model according to the specified arguments, defining the model architecture. Instantiating a 102 | configuration with the defaults will yield a similar configuration to that of the 103 | [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B). 104 | 105 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 106 | documentation from [`PretrainedConfig`] for more information. 107 | 108 | 109 | Args: 110 | vocab_size (`int`, *optional*, defaults to 65536): 111 | Vocabulary size of the chameleon model. Defines the number of different tokens that can be represented by the 112 | `inputs_ids` passed when calling [`ChameleonModel`]; this includes text and image tokens. 
113 | hidden_size (`int`, *optional*, defaults to 4096): 114 | Dimension of the hidden representations. 115 | intermediate_size (`int`, *optional*, defaults to 11008): 116 | Dimension of the MLP representations. 117 | num_hidden_layers (`int`, *optional*, defaults to 32): 118 | Number of hidden layers in the Transformer decoder. 119 | num_attention_heads (`int`, *optional*, defaults to 32): 120 | Number of attention heads for each attention layer in the Transformer decoder. 121 | num_key_value_heads (`int`, *optional*, defaults to 32): 122 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 123 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 124 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 125 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 126 | by meanpooling all the original heads within that group. For more details checkout [this 127 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 128 | `num_attention_heads`. 129 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 130 | The non-linear activation function (function or string) in the decoder. 131 | max_position_embeddings (`int`, *optional*, defaults to 4096): 132 | The maximum sequence length that this model might ever be used with. Chameleon supports up to 4096 tokens. 133 | initializer_range (`float`, *optional*, defaults to 0.02): 134 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 135 | rms_norm_eps (`float`, *optional*, defaults to 1e-05): 136 | The epsilon used by the rms normalization layers. 137 | use_cache (`bool`, *optional*, defaults to `True`): 138 | Whether or not the model should return the last key/values attentions (not used by all models). Only 139 | relevant if `config.is_decoder=True`. 140 | pad_token_id (`int`, *optional*): 141 | Padding token id. 142 | bos_token_id (`int`, *optional*, defaults to 1): 143 | Beginning of stream token id. 144 | eos_token_id (`int`, *optional*, defaults to 2): 145 | End of stream token id. 146 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 147 | Whether to tie weight embeddings 148 | rope_theta (`float`, *optional*, defaults to 10000.0): 149 | The base period of the RoPE embeddings. 150 | rope_scaling (`Dict`, *optional*): 151 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 152 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is 153 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 154 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 155 | these scaling strategies behave: 156 | https://www.reddit.com/r/Localchameleon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 157 | experimental feature, subject to breaking API changes in future versions. 158 | attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 159 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 160 | attention_dropout (`float`, *optional*, defaults to 0.0): 161 | The dropout ratio for the attention probabilities. 
162 | model_parallel_size (`int`, *optional*, defaults to 1): 163 | Number of shards used when training the model. This will be used in qk layernorm because the original Chameleon inference 164 | doesn't do reduction in those layers and each rank has its own biases. 165 | swin_norm (`bool`, *optional*, defaults to `False`): 166 | Use Swin Transformer normalization. 167 | vq_config (`dict`, *optional*): 168 | ChameleonVQConfig instance containing the configuration for the VQ-VAE model. 169 | vocabulary_map (`dict`, *optional*): 170 | A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs. 171 | mlp_bias (`bool`, *optional*, defaults to `False`): 172 | Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. 173 | 174 | 175 | ```python 176 | >>> from transformers import ChameleonModel, ChameleonConfig 177 | 178 | >>> # Initializing a chameleon chameleon-7b style configuration 179 | >>> configuration = ChameleonConfig() 180 | 181 | >>> # Initializing a model from the chameleon-7b style configuration 182 | >>> model = ChameleonModel(configuration) 183 | 184 | >>> # Accessing the model configuration 185 | >>> configuration = model.config 186 | ```""" 187 | 188 | model_type = "chameleon" 189 | keys_to_ignore_at_inference = ["past_key_values"] 190 | 191 | def __init__( 192 | self, 193 | vocab_size=65536, 194 | hidden_size=4096, 195 | intermediate_size=11008, 196 | num_hidden_layers=32, 197 | num_attention_heads=32, 198 | num_key_value_heads=32, 199 | hidden_act="silu", 200 | max_position_embeddings=4096, 201 | initializer_range=0.02, 202 | rms_norm_eps=1e-05, 203 | use_cache=True, 204 | pad_token_id=None, 205 | bos_token_id=1, 206 | eos_token_id=2, 207 | tie_word_embeddings=False, 208 | rope_theta=10000.0, 209 | rope_scaling=None, 210 | attention_bias=False, 211 | attention_dropout=0.0, 212 | model_parallel_size=1, 213 | swin_norm=False, 214 | vq_config=None, 215 | vocabulary_map=None, 216 | mlp_bias=False, 217 | mask_image_logits=True, 218 | dropout=0.0, 219 | **kwargs, 220 | ): 221 | self.vocab_size = vocab_size 222 | self.max_position_embeddings = max_position_embeddings 223 | self.hidden_size = hidden_size 224 | self.intermediate_size = intermediate_size 225 | self.num_hidden_layers = num_hidden_layers 226 | self.num_attention_heads = num_attention_heads 227 | self.mlp_bias = mlp_bias 228 | 229 | self.num_key_value_heads = num_key_value_heads 230 | self.hidden_act = hidden_act 231 | self.initializer_range = initializer_range 232 | self.rms_norm_eps = rms_norm_eps 233 | self.use_cache = use_cache 234 | self.rope_theta = rope_theta 235 | self.rope_scaling = rope_scaling 236 | self._rope_scaling_validation() 237 | self.attention_bias = attention_bias 238 | self.attention_dropout = attention_dropout 239 | self.model_parallel_size = model_parallel_size 240 | self.swin_norm = swin_norm 241 | self.mask_image_logits = mask_image_logits 242 | 243 | if vq_config is None: 244 | vq_config = {} 245 | logger.info("vq_config is None. 
initializing the ChameleonVQConfig with default values.") 246 | 247 | self.vq_config = ChameleonVQVAEConfig(**vq_config) 248 | 249 | self.vocabulary_map = vocabulary_map 250 | 251 | self.dropout = dropout 252 | 253 | super().__init__( 254 | pad_token_id=pad_token_id, 255 | bos_token_id=bos_token_id, 256 | eos_token_id=eos_token_id, 257 | tie_word_embeddings=tie_word_embeddings, 258 | **kwargs, 259 | ) 260 | 261 | def _rope_scaling_validation(self): 262 | """ 263 | Validate the `rope_scaling` configuration. 264 | """ 265 | if self.rope_scaling is None: 266 | return 267 | 268 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 269 | raise ValueError( 270 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 271 | f"got {self.rope_scaling}" 272 | ) 273 | rope_scaling_type = self.rope_scaling.get("type", None) 274 | rope_scaling_factor = self.rope_scaling.get("factor", None) 275 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 276 | raise ValueError( 277 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 278 | ) 279 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 280 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") 281 | -------------------------------------------------------------------------------- /lumina_mgpt/model/chameleon/image_processing_chameleon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Image processor class for Chameleon.""" 16 | 17 | from typing import Dict, List, Optional, Union 18 | 19 | import numpy as np 20 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict 21 | from transformers.image_transforms import get_resize_output_image_size, resize, to_channel_dimension_format 22 | from transformers.image_utils import ( 23 | ChannelDimension, 24 | ImageInput, 25 | PILImageResampling, 26 | infer_channel_dimension_format, 27 | is_scaled_image, 28 | is_valid_image, 29 | to_numpy_array, 30 | valid_images, 31 | validate_kwargs, 32 | validate_preprocess_arguments, 33 | ) 34 | from transformers.utils import TensorType, is_vision_available, logging 35 | 36 | logger = logging.get_logger(__name__) 37 | 38 | if is_vision_available(): 39 | import PIL 40 | 41 | 42 | def make_batched_images(images) -> List[List[ImageInput]]: 43 | """ 44 | Accepts images in list or nested list format, and makes a list of images for preprocessing. 45 | 46 | Args: 47 | images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): 48 | The input image. 49 | 50 | Returns: 51 | list: A list of images. 
52 | """ 53 | if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): 54 | return [img for img_list in images for img in img_list] 55 | 56 | elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): 57 | return images 58 | 59 | elif is_valid_image(images): 60 | return [images] 61 | 62 | raise ValueError(f"Could not make batched video from {images}") 63 | 64 | 65 | class ChameleonImageProcessor(BaseImageProcessor): 66 | r""" 67 | Constructs a Chameleon image processor. 68 | 69 | Args: 70 | do_resize (`bool`, *optional*, defaults to `True`): 71 | Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by 72 | `do_resize` in the `preprocess` method. 73 | size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 512}`): 74 | Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with 75 | the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` 76 | method. 77 | resample (`PILImageResampling`, *optional*, defaults to 1): 78 | Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. 79 | do_center_crop (`bool`, *optional*, defaults to `True`): 80 | Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the 81 | `preprocess` method. 82 | crop_size (`Dict[str, int]` *optional*, defaults to {"height": 512, "width": 512}): 83 | Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` 84 | method. 85 | do_rescale (`bool`, *optional*, defaults to `True`): 86 | Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in 87 | the `preprocess` method. 88 | rescale_factor (`int` or `float`, *optional*, defaults to 0.0078): 89 | Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` 90 | method. 91 | do_normalize (`bool`, *optional*, defaults to `True`): 92 | Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. 93 | image_mean (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`): 94 | Mean to use if normalizing the image. This is a float or list of floats the length of the number of 95 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 96 | image_std (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`): 97 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the 98 | number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 99 | Can be overridden by the `image_std` parameter in the `preprocess` method. 100 | do_convert_rgb (`bool`, *optional*, defaults to `True`): 101 | Whether to convert the image to RGB. 
102 | """ 103 | 104 | model_input_names = ["pixel_values"] 105 | 106 | def __init__( 107 | self, 108 | do_resize: bool = True, 109 | size: Dict[str, int] = None, 110 | resample: PILImageResampling = PIL.Image.LANCZOS, 111 | do_center_crop: bool = True, 112 | crop_size: Dict[str, int] = None, 113 | do_rescale: bool = True, 114 | rescale_factor: Union[int, float] = 0.0078, 115 | do_normalize: bool = True, 116 | image_mean: Optional[Union[float, List[float]]] = None, 117 | image_std: Optional[Union[float, List[float]]] = None, 118 | do_convert_rgb: bool = True, 119 | **kwargs, 120 | ) -> None: 121 | super().__init__(**kwargs) 122 | size = size if size is not None else {"shortest_edge": 512} 123 | size = get_size_dict(size, default_to_square=False) 124 | crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512} 125 | crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") 126 | 127 | self.do_resize = do_resize 128 | self.size = size 129 | self.resample = resample 130 | self.do_center_crop = do_center_crop 131 | self.crop_size = crop_size 132 | self.do_rescale = do_rescale 133 | self.rescale_factor = rescale_factor 134 | self.do_normalize = do_normalize 135 | self.image_mean = image_mean if image_mean is not None else [1.0, 1.0, 1.0] 136 | self.image_std = image_std if image_std is not None else [1.0, 1.0, 1.0] 137 | self.do_convert_rgb = do_convert_rgb 138 | self._valid_processor_keys = [ 139 | "images", 140 | "do_resize", 141 | "size", 142 | "resample", 143 | "do_center_crop", 144 | "crop_size", 145 | "do_rescale", 146 | "rescale_factor", 147 | "do_normalize", 148 | "image_mean", 149 | "image_std", 150 | "do_convert_rgb", 151 | "return_tensors", 152 | "data_format", 153 | "input_data_format", 154 | ] 155 | 156 | # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize 157 | def resize( 158 | self, 159 | image: np.ndarray, 160 | size: Dict[str, int], 161 | resample: PILImageResampling = PILImageResampling.BICUBIC, 162 | data_format: Optional[Union[str, ChannelDimension]] = None, 163 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 164 | **kwargs, 165 | ) -> np.ndarray: 166 | """ 167 | Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge 168 | resized to keep the input aspect ratio. 169 | 170 | Args: 171 | image (`np.ndarray`): 172 | Image to resize. 173 | size (`Dict[str, int]`): 174 | Size of the output image. 175 | resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): 176 | Resampling filter to use when resiizing the image. 177 | data_format (`str` or `ChannelDimension`, *optional*): 178 | The channel dimension format of the image. If not provided, it will be the same as the input image. 179 | input_data_format (`ChannelDimension` or `str`, *optional*): 180 | The channel dimension format of the input image. If not provided, it will be inferred. 
181 | """ 182 | default_to_square = True 183 | if "shortest_edge" in size: 184 | size = size["shortest_edge"] 185 | default_to_square = False 186 | elif "height" in size and "width" in size: 187 | size = (size["height"], size["width"]) 188 | else: 189 | raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") 190 | 191 | output_size = get_resize_output_image_size( 192 | image, 193 | size=size, 194 | default_to_square=default_to_square, 195 | input_data_format=input_data_format, 196 | ) 197 | return resize( 198 | image, 199 | size=output_size, 200 | resample=resample, 201 | data_format=data_format, 202 | input_data_format=input_data_format, 203 | **kwargs, 204 | ) 205 | 206 | def preprocess( 207 | self, 208 | images: ImageInput, 209 | do_resize: bool = None, 210 | size: Dict[str, int] = None, 211 | resample: PILImageResampling = None, 212 | do_center_crop: bool = None, 213 | crop_size: int = None, 214 | do_rescale: bool = None, 215 | rescale_factor: float = None, 216 | do_normalize: bool = None, 217 | image_mean: Optional[Union[float, List[float]]] = None, 218 | image_std: Optional[Union[float, List[float]]] = None, 219 | do_convert_rgb: bool = None, 220 | return_tensors: Optional[Union[str, TensorType]] = None, 221 | data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, 222 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 223 | **kwargs, 224 | ) -> PIL.Image.Image: 225 | """ 226 | Preprocess an image or batch of images. 227 | 228 | Args: 229 | images (`ImageInput`): 230 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If 231 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 232 | do_resize (`bool`, *optional*, defaults to `self.do_resize`): 233 | Whether to resize the image. 234 | size (`Dict[str, int]`, *optional*, defaults to `self.size`): 235 | Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with 236 | the longest edge resized to keep the input aspect ratio. 237 | resample (`int`, *optional*, defaults to `self.resample`): 238 | Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only 239 | has an effect if `do_resize` is set to `True`. 240 | do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): 241 | Whether to center crop the image. 242 | crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): 243 | Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. 244 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): 245 | Whether to rescale the image. 246 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): 247 | Rescale factor to rescale the image by if `do_rescale` is set to `True`. 248 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 249 | Whether to normalize the image. 250 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 251 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 252 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 253 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to 254 | `True`. 255 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): 256 | Whether to convert the image to RGB. 
257 | return_tensors (`str` or `TensorType`, *optional*): 258 | The type of tensors to return. Can be one of: 259 | - Unset: Return a list of `np.ndarray`. 260 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 261 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 262 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 263 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 264 | data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): 265 | The channel dimension format for the output image. Can be one of: 266 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 267 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 268 | - Unset: Use the channel dimension format of the input image. 269 | input_data_format (`ChannelDimension` or `str`, *optional*): 270 | The channel dimension format for the input image. If unset, the channel dimension format is inferred 271 | from the input image. Can be one of: 272 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 273 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 274 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 275 | """ 276 | do_resize = do_resize if do_resize is not None else self.do_resize 277 | size = size if size is not None else self.size 278 | size = get_size_dict(size, param_name="size", default_to_square=False) 279 | resample = resample if resample is not None else self.resample 280 | do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop 281 | crop_size = crop_size if crop_size is not None else self.crop_size 282 | crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) 283 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale 284 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor 285 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 286 | image_mean = image_mean if image_mean is not None else self.image_mean 287 | image_std = image_std if image_std is not None else self.image_std 288 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb 289 | 290 | validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) 291 | 292 | images = make_batched_images(images) 293 | 294 | if not valid_images(images): 295 | raise ValueError( 296 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 297 | "torch.Tensor, tf.Tensor or jax.ndarray." 298 | ) 299 | 300 | validate_preprocess_arguments( 301 | do_rescale=do_rescale, 302 | rescale_factor=rescale_factor, 303 | do_normalize=do_normalize, 304 | image_mean=image_mean, 305 | image_std=image_std, 306 | do_center_crop=do_center_crop, 307 | crop_size=crop_size, 308 | do_resize=do_resize, 309 | size=size, 310 | resample=resample, 311 | ) 312 | 313 | if do_convert_rgb: 314 | images = [self.blend_rgba(image) for image in images] 315 | 316 | # All transformations expect numpy arrays. 317 | images = [to_numpy_array(image) for image in images] 318 | 319 | if is_scaled_image(images[0]) and do_rescale: 320 | logger.warning_once( 321 | "It looks like you are trying to rescale already rescaled images. 
If the input" 322 | " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 323 | ) 324 | 325 | if input_data_format is None: 326 | # We assume that all images have the same channel dimension format. 327 | input_data_format = infer_channel_dimension_format(images[0]) 328 | 329 | if do_resize: 330 | images = [ 331 | self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) 332 | for image in images 333 | ] 334 | 335 | if do_center_crop: 336 | images = [ 337 | self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images 338 | ] 339 | 340 | if do_rescale: 341 | images = [ 342 | self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) for image in images 343 | ] 344 | 345 | if do_normalize: 346 | images = [ 347 | self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) 348 | for image in images 349 | ] 350 | 351 | images = [ 352 | to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images 353 | ] 354 | 355 | data = {"pixel_values": images} 356 | return BatchFeature(data=data, tensor_type=return_tensors) 357 | 358 | def blend_rgba(self, image: ImageInput) -> ImageInput: 359 | """ 360 | Convert image to RGB by blending the transparency layer if it's in RGBA format. 361 | If image is not `PIL.Image`, it si simply returned without modifications. 362 | 363 | Args: 364 | image (`ImageInput`): 365 | Image to convert. 366 | """ 367 | 368 | if not isinstance(image, PIL.Image.Image): 369 | return image 370 | elif image.mode == "RGB": 371 | return image 372 | 373 | img_rgba = np.array(image.convert("RGBA")) 374 | 375 | # If there is no transparency layer, simple convert and return. 376 | if not (img_rgba[:, :, 3] < 255).any(): 377 | return image.convert("RGB") 378 | 379 | # There is a transparency layer, blend it with a white background. 380 | # Calculate the alpha proportion for blending. 381 | alpha = img_rgba[:, :, 3] / 255.0 382 | img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3] 383 | return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB") 384 | -------------------------------------------------------------------------------- /lumina_mgpt/model/chameleon/processing_chameleon.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Processor class for Chameleon. 
17 | """ 18 | 19 | from typing import List, Optional, Union 20 | 21 | from transformers.feature_extraction_utils import BatchFeature 22 | from transformers.image_utils import ImageInput 23 | from transformers.processing_utils import ProcessorMixin 24 | from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 25 | from transformers.utils import TensorType 26 | 27 | 28 | class ChameleonProcessor(ProcessorMixin): 29 | r""" 30 | Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single 31 | processor. 32 | 33 | [`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`]. 34 | See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information. 35 | 36 | Args: 37 | image_processor ([`ChameleonImageProcessor`]): 38 | The image processor is a required input. 39 | tokenizer ([`LlamaTokenizerFast`]): 40 | The tokenizer is a required input. 41 | image_seq_length (`int`, *optional*, defaults to 1024): 42 | Sequence length of one image embedding. 43 | image_token (`str`, *optional*, defaults to `""`): 44 | The special token used to indicate image in the text. 45 | """ 46 | 47 | attributes = ["image_processor", "tokenizer"] 48 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") 49 | image_processor_class = "ChameleonImageProcessor" 50 | 51 | def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""): 52 | self.image_seq_length = image_seq_length 53 | self.image_token = image_token 54 | self.image_start_token = "" # fixed tokens for start and end, so can hardcode 55 | self.image_end_token = "" 56 | super().__init__(image_processor, tokenizer) 57 | 58 | def __call__( 59 | self, 60 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 61 | images: ImageInput = None, 62 | padding: Union[bool, str, PaddingStrategy] = False, 63 | truncation: Union[bool, str, TruncationStrategy] = None, 64 | max_length: int = None, 65 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, 66 | return_for_text_completion: bool = False, 67 | ) -> BatchFeature: 68 | """ 69 | Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` 70 | and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode 71 | the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to 72 | CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring 73 | of the above two methods for more information. 74 | 75 | Args: 76 | text (`str`, `List[str]`, `List[List[str]]`): 77 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 78 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 79 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 80 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 81 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 82 | tensor. Both channels-first and channels-last formats are supported. 
83 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): 84 | Select a strategy to pad the returned sequences (according to the model's padding side and padding 85 | index) among: 86 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 87 | sequence if provided). 88 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 89 | acceptable input length for the model if that argument is not provided. 90 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 91 | lengths). 92 | max_length (`int`, *optional*): 93 | Maximum length of the returned list and optionally padding length (see above). 94 | truncation (`bool`, *optional*): 95 | Activates truncation to cut input sequences longer than `max_length` to `max_length`. 96 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 97 | If set, will return tensors of a particular framework. Acceptable values are: 98 | 99 | - `'tf'`: Return TensorFlow `tf.constant` objects. 100 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 101 | - `'np'`: Return NumPy `np.ndarray` objects. 102 | - `'jax'`: Return JAX `jnp.ndarray` objects. 103 | 104 | Returns: 105 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 106 | 107 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 108 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 109 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 110 | `None`). 111 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 112 | """ 113 | if isinstance(text, str): 114 | text = [text] 115 | elif not isinstance(text, list) and not isinstance(text[0], str): 116 | raise TypeError("Invalid input text. Please provide a string, or a list of strings") 117 | 118 | # Replace the image token with the expanded image token sequence 119 | prompt_strings = [] 120 | one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token 121 | for sample in text: 122 | sample = sample.replace(self.image_token, one_img_tokens) 123 | if not return_for_text_completion: 124 | sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode 125 | prompt_strings.append(sample) 126 | 127 | data = self.tokenizer( 128 | prompt_strings, 129 | return_tensors=return_tensors, 130 | padding=padding, 131 | truncation=truncation, 132 | max_length=max_length, 133 | ) 134 | 135 | if images is not None: 136 | pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] 137 | data["pixel_values"] = pixel_values 138 | 139 | return BatchFeature(data=data, tensor_type=return_tensors) 140 | 141 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama 142 | def batch_decode(self, *args, **kwargs): 143 | """ 144 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please 145 | refer to the docstring of this method for more information. 
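Example (illustrative; `generated_ids` is a placeholder for a tensor of token ids returned by the model's
`generate` method):

```python
>>> texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
```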
146 | """ 147 | return self.tokenizer.batch_decode(*args, **kwargs) 148 | 149 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama 150 | def decode(self, *args, **kwargs): 151 | """ 152 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to 153 | the docstring of this method for more information. 154 | """ 155 | return self.tokenizer.decode(*args, **kwargs) 156 | 157 | @property 158 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names 159 | def model_input_names(self): 160 | tokenizer_input_names = self.tokenizer.model_input_names 161 | image_processor_input_names = self.image_processor.model_input_names 162 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 163 | -------------------------------------------------------------------------------- /lumina_mgpt/model/chameleon_vae_ori/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_tokenizer import ImageTokenizer 2 | from .vocab import VocabInfo, VocabTranslation 3 | -------------------------------------------------------------------------------- /lumina_mgpt/model/chameleon_vae_ori/image_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates 2 | # 3 | # This source code is licensed under the Chameleon License found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import PIL 7 | from PIL import Image 8 | import numpy as np 9 | import torch 10 | import yaml 11 | 12 | from .vqgan import VQModel 13 | 14 | 15 | class ImageTokenizer: 16 | def __init__( 17 | self, 18 | cfg_path: str, 19 | ckpt_path: str, 20 | device: str | torch.device | None = None, 21 | ): 22 | with open(cfg_path) as f: 23 | config = yaml.safe_load(f) 24 | 25 | params = config["model"]["params"] 26 | if "lossconfig" in params: 27 | del params["lossconfig"] 28 | params["ckpt_path"] = ckpt_path 29 | 30 | self._vq_model = VQModel(**params) 31 | self._vq_model.eval() 32 | 33 | if device is None: 34 | devices = {p.device for p in self._vq_model.parameters()} 35 | assert len(devices) == 1 36 | device = devices.pop() 37 | else: 38 | self._vq_model.to(device) 39 | self._device = device 40 | 41 | dtypes = {p.dtype for p in self._vq_model.parameters()} 42 | assert len(dtypes) == 1 43 | self._dtype = dtypes.pop() 44 | 45 | def _whiten_transparency(self, img: PIL.Image) -> PIL.Image: 46 | # Check if it's already in RGB format. 47 | if img.mode == "RGB": 48 | return img 49 | 50 | vals_rgba = np.array(img.convert("RGBA")) 51 | 52 | # If there is no transparency layer, simple convert and return. 53 | if not (vals_rgba[:, :, 3] < 255).any(): 54 | return img.convert("RGB") 55 | 56 | # There is a transparency layer, blend it with a white background. 57 | 58 | # Calculate the alpha proportion for blending. 59 | alpha = vals_rgba[:, :, 3] / 255.0 60 | # Blend with white background. 61 | vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * vals_rgba[:, :, :3] 62 | return PIL.Image.fromarray(vals_rgb.astype("uint8"), "RGB") 63 | 64 | # def _vqgan_input_from(self, img: PIL.Image, target_image_size=512) -> torch.Tensor: 65 | # # Resize with aspect ratio preservation. 
66 | # s = min(img.size) 67 | # scale = target_image_size / s 68 | # new_size = (round(scale * img.size[0]), round(scale * img.size[1])) 69 | # img = img.resize(new_size, PIL.Image.LANCZOS) 70 | # 71 | # # Center crop. 72 | # x0 = (img.width - target_image_size) // 2 73 | # y0 = (img.height - target_image_size) // 2 74 | # img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size)) 75 | # 76 | # # Convert to tensor. 77 | # np_img = np.array(img) / 255.0 # Normalize to [0, 1] 78 | # np_img = np_img * 2 - 1 # Scale to [-1, 1] 79 | # tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float() # (Channels, Height, Width) format. 80 | # 81 | # # Add batch dimension. 82 | # return tensor_img.unsqueeze(0) 83 | 84 | def img_tokens_from_pil(self, img: PIL.Image) -> list[int]: 85 | img = self._whiten_transparency(img) 86 | # Convert to tensor. 87 | np_img = np.array(img) / 255.0 # Normalize to [0, 1] 88 | np_img = np_img * 2 - 1 # Scale to [-1, 1] 89 | img = torch.from_numpy(np_img).permute(2, 0, 1).to(self._vq_model.encoder.conv_in.weight) 90 | img = img.unsqueeze(0) 91 | 92 | _, _, [_, _, img_toks] = self._vq_model.encode(img) 93 | return img_toks 94 | 95 | def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image: 96 | # Ensure detachment and move tensor to CPU. 97 | detached_chw_tensor = chw_tensor.detach().cpu() 98 | 99 | # Normalize tensor to [0, 1] range from [-1, 1] range. 100 | normalized_chw_tensor = (torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0) / 2.0 101 | 102 | # Permute CHW tensor to HWC format and convert to NumPy array. 103 | hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy() 104 | 105 | # Convert to an 8-bit unsigned integer format. 106 | image_array_uint8 = (hwc_array * 255).astype(np.uint8) 107 | 108 | # Convert NumPy array to PIL Image. 109 | pil_image = Image.fromarray(image_array_uint8) 110 | 111 | # Convert image to RGB if it is not already. 112 | if pil_image.mode != "RGB": 113 | pil_image = pil_image.convert("RGB") 114 | 115 | return pil_image 116 | 117 | def pil_from_img_toks(self, tokens: torch.Tensor, h_latent_dim=32, w_latent_dim=32) -> PIL.Image: 118 | emb_dim = self._vq_model.quantize.embedding.weight.shape[-1] 119 | codebook_entry = self._vq_model.quantize.get_codebook_entry(tokens, (1, h_latent_dim, w_latent_dim, emb_dim)) 120 | pixels = self._vq_model.decode(codebook_entry) 121 | return self._pil_from_chw_tensor(pixels[0]) 122 | 123 | def latent_embedding_from_pil(self, img: PIL.Image): 124 | img = self._whiten_transparency(img) 125 | 126 | # Convert to tensor. 127 | np_img = np.array(img) / 255.0 # Normalize to [0, 1] 128 | np_img = np_img * 2 - 1 # Scale to [-1, 1] 129 | img = torch.from_numpy(np_img).permute(2, 0, 1) # (Channels, Height, Width) format. 130 | img = img.unsqueeze(0).to(self._vq_model.encoder.conv_in.weight) 131 | latent_embedding, _, _ = self._vq_model.encode(img) 132 | return latent_embedding 133 | -------------------------------------------------------------------------------- /lumina_mgpt/model/chameleon_vae_ori/vocab.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Chameleon License found in the 4 | # LICENSE file in the root directory of this source tree. 
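# Note on the token-naming scheme assumed by VocabTranslation below: image-codebook entries appear in the BPE
# vocabulary under names of the form "IMGIMG" + letters + one trailing character, where the letters A..J encode
# the digits 0..9. bpe2img strips the "IMGIMG" prefix, drops the final character, and maps the remaining letters
# back to digits, so a token named e.g. "IMGIMGBCZ" (illustrative) translates to codebook entry 12.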
5 |
6 | from functools import cached_property
7 |
8 | import torch
9 |
10 |
11 | class VocabInfo:
12 | def __init__(self, vocab_map: dict[str, int]):
13 | self.name2val = vocab_map
14 |
15 | self.bos_id = vocab_map.get("<s>")
16 | self.eos_id = vocab_map.get("</s>")
17 | self.boi_id = vocab_map.get("<racm3:break>")
18 | self.eoi_id = vocab_map.get("<eoss>")
19 | self.pad_id = vocab_map.get("<pad>")
20 | self.eot_id = vocab_map.get("<reserved08706>")
21 |
22 | @property
23 | def begin_sequence(self) -> int:
24 | return self.bos_id
25 |
26 | @property
27 | def end_sequence(self) -> int:
28 | return self.eos_id
29 |
30 | @property
31 | def begin_image(self) -> int:
32 | return self.boi_id
33 |
34 | @property
35 | def end_image(self) -> int:
36 | return self.eoi_id
37 |
38 | @property
39 | def padding(self) -> int:
40 | return self.pad_id
41 |
42 | @property
43 | def end_turn(self) -> int:
44 | return self.eot_id
45 |
46 | @cached_property
47 | def val2name(self) -> dict[int, str]:
48 | return {v: k for k, v in self.name2val.items()}
49 |
50 | @cached_property
51 | def all_tokens(self) -> list[int]:
52 | return sorted(self.name2val.values())
53 |
54 | @cached_property
55 | def image_tokens(self) -> list[int]:
56 | return sorted([val for name, val in self.name2val.items() if name.startswith("IMGIMG")])
57 |
58 | @cached_property
59 | def special_tokens(self) -> list[int]:
60 | return sorted([val for name, val in self.name2val.items() if name.startswith("<") and name != "<"])
61 |
62 | @cached_property
63 | def text_tokens(self) -> list[int]:
64 | return sorted(set(self.all_tokens) - set(self.image_tokens) - set(self.special_tokens))
65 |
66 |
67 | class VocabTranslation:
68 | def __init__(self, vocab_info: VocabInfo, device: str | None = None):
69 | self._vocab = vocab_info
70 | self._device = device
71 |
72 | @cached_property
73 | def bpe2img(self) -> dict[int, int]:
74 | img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
75 |
76 | def remap(old_name: str) -> str:
77 | return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1])
78 |
79 | return {tok: int(remap(self._vocab.val2name[tok])) for tok in self._vocab.image_tokens}
80 |
81 | @cached_property
82 | def img2bpe(self) -> dict[int, int]:
83 | return {v: k for k, v in self.bpe2img.items()}
84 |
85 | @cached_property
86 | def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
87 | sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()), device=self._device)
88 | sorted_img = torch.tensor(sorted(self.bpe2img.values()), device=self._device)
89 | return sorted_bpe, sorted_img
90 |
91 | @cached_property
92 | def img2bpe_mapping_tensor(self) -> torch.LongTensor:
93 | mapping = torch.zeros(
94 | max(self.img2bpe.keys()) + 1,
95 | dtype=torch.int,
96 | device=self._device,
97 | )
98 | for k, v in self.img2bpe.items():
99 | mapping[k] = v
100 | return mapping
101 |
102 | def convert_bpe2img(self, bpe_batch: torch.Tensor) -> torch.Tensor:
103 | bpe_tok, img_tok = self.bpe2img_search_tensors
104 | return img_tok[torch.searchsorted(bpe_tok, bpe_batch)]
105 |
106 | def convert_img2bp2(self, img_batch: torch.Tensor) -> torch.Tensor:
107 | return self.img2bpe_mapping_tensor[img_batch]
108 |
--------------------------------------------------------------------------------
/lumina_mgpt/model/configuration_xllmx_chameleon.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List
3 |
4 | from .chameleon import ChameleonConfig
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
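# z_loss_weight controls an auxiliary "z-loss" regularizer used by ChameleonXLLMXForConditionalGeneration
# (see modeling_xllmx_chameleon.py): over positions with a valid label it computes
# z_loss = mean(logsumexp(logits, dim=-1) ** 2), which the training solver is expected to add to the main
# cross-entropy loss scaled by z_loss_weight. The default of 0.0 disables the extra term.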
9 | class ChameleonXLLMXConfig(ChameleonConfig): 10 | 11 | def __init__( 12 | self, 13 | z_loss_weight: float = 0.0, 14 | **kwargs, 15 | ): 16 | self.z_loss_weight = z_loss_weight 17 | super().__init__( 18 | **kwargs, 19 | ) 20 | -------------------------------------------------------------------------------- /lumina_mgpt/model/modeling_xllmx_chameleon.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import math 4 | from typing import List 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from .chameleon import ChameleonForConditionalGeneration 10 | from .configuration_xllmx_chameleon import ChameleonXLLMXConfig 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5)) 15 | 16 | 17 | __all__ = ["ChameleonXLLMXForConditionalGeneration"] 18 | 19 | 20 | class ChameleonXLLMXForConditionalGeneration(ChameleonForConditionalGeneration): 21 | config_class = ChameleonXLLMXConfig 22 | 23 | def __init__(self, config): 24 | super().__init__(config) 25 | 26 | def forward(self, input_ids=None, labels=None, training=True, **kwargs): 27 | 28 | max_tokens = max([len(_) for _ in input_ids]) 29 | max_tokens = min(max_tokens, self.config.max_position_embeddings) 30 | input_ids = [_[:max_tokens] for _ in input_ids] 31 | labels = [_[:max_tokens] for _ in labels] 32 | 33 | input_ids = [example + [0] * (max_tokens - len(example)) for example in input_ids] 34 | input_ids = torch.tensor(input_ids, dtype=torch.int64, device=self.device) 35 | 36 | labels = [label + [-100] * (max_tokens - len(label)) for label in labels] 37 | labels = torch.tensor(labels, dtype=torch.int64, device=self.device) 38 | 39 | # explicit use_cache=False for the following 40 | # https://github.com/Lightning-AI/pytorch-lightning/issues/19267 41 | result = ChameleonForConditionalGeneration.forward( 42 | self, input_ids=input_ids, labels=labels, use_cache=False, **kwargs 43 | ) 44 | 45 | c_loss = result[0] 46 | 47 | additional_loss_dict = {} 48 | if self.config.z_loss_weight > 0: 49 | logits: torch.Tensor = result[1] 50 | shift_logits = logits[..., :-1, :].contiguous() 51 | shift_labels = labels[..., 1:].contiguous() 52 | valid_mask = shift_labels >= 0 53 | z_loss = torch.logsumexp(shift_logits, dim=-1).pow(2)[valid_mask].mean() 54 | additional_loss_dict["z_loss"] = (z_loss, self.config.z_loss_weight) 55 | return c_loss, additional_loss_dict 56 | 57 | def get_fsdp_wrap_module_list(self) -> List: 58 | modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens] 59 | if hasattr(self.model, "vqmodel"): # may be deleted 60 | modules.append(self.model.vqmodel) 61 | return modules 62 | 63 | def get_checkpointing_wrap_module_list(self) -> List: 64 | modules = [ 65 | *list(self.model.layers), 66 | ] 67 | return modules 68 | -------------------------------------------------------------------------------- /lumina_mgpt/pre_tokenize/concat_record.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import json 3 | import os 4 | import re 5 | import warnings 6 | 7 | 8 | def find_sub_records(directory: str): 9 | pattern = re.compile(r"\d+-of-\d+-record\.json(l)?") 10 | 11 | sub_record_files = [f for f in os.listdir(directory) if pattern.match(f)] 12 | sorted_files = sorted(sub_record_files, key=lambda filename: int(filename.split("-of")[0])) 13 | return sorted_files 14 | 15 | 16 | if __name__ == "__main__": 
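    # Typical invocation (paths are placeholders): the inputs are the "{rank}-of-{splits}-record.jsonl"
    # files written by pre_tokenize.py into its --out_dir, and the merged list is saved as a single json file:
    #   python concat_record.py --sub_record_dir <pre_tokenize_out_dir> --save_path <pre_tokenize_out_dir>/record.json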
17 | parser = ArgumentParser() 18 | parser.add_argument( 19 | "--sub_record_dir", 20 | type=str, 21 | default=None, 22 | ) 23 | parser.add_argument( 24 | "--save_path", 25 | type=str, 26 | default=None, 27 | ) 28 | args = parser.parse_args() 29 | 30 | l_sub_records = find_sub_records(args.sub_record_dir) 31 | 32 | print(f"find {len(l_sub_records)} sub-records in {args.sub_record_dir}") 33 | print(str(l_sub_records) + "\n\n") 34 | 35 | complete_record = [] 36 | for sub_record in l_sub_records: 37 | with open(os.path.join(args.sub_record_dir, sub_record)) as f: 38 | lines = f.readlines() 39 | for i, l in enumerate(lines): 40 | try: 41 | l_item = json.loads(l) 42 | complete_record.append(l_item) 43 | except: 44 | if i == len(lines) - 1: 45 | print(f"{sub_record} seems still writing, skip last incomplete record") 46 | else: 47 | warnings.warn(f"read line failed: {l}") 48 | 49 | with open(args.save_path, "w") as f: 50 | json.dump(complete_record, f) 51 | -------------------------------------------------------------------------------- /lumina_mgpt/pre_tokenize/pre_tokenize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0]) 5 | 6 | from argparse import ArgumentParser 7 | import json 8 | import math 9 | import pickle 10 | 11 | from data.convertsation import Conversation 12 | from data.item_processor import FlexARItemProcessor 13 | 14 | 15 | class ItemProcessor(FlexARItemProcessor): 16 | def __init__( 17 | self, 18 | tokenizer="Alpha-VLLM/Lumina-mGPT-7B-768", 19 | conv_template=Conversation, 20 | target_size=512, 21 | ): 22 | super().__init__(tokenizer, conv_template, target_size) 23 | print(self.crop_size_list) 24 | 25 | def process_item(self, raw_item, training_mode=False, out_flatten=True): 26 | 27 | # Add custom codes here to convert raw_item to the standard format 28 | # The standard format contains the "conversations" and "image" keys 29 | 30 | # ********* Add your custom codes here ******* 31 | 32 | # ********* Add your custom codes here ******* 33 | 34 | item = { 35 | "conversations": raw_item["conversations"], 36 | "image": raw_item["image"], 37 | } 38 | 39 | return super(ItemProcessor, self).process_item(item, training_mode, out_flatten) 40 | 41 | 42 | if __name__ == "__main__": 43 | 44 | parser = ArgumentParser() 45 | parser.add_argument( 46 | "--splits", 47 | type=int, 48 | default=8, 49 | ) 50 | parser.add_argument( 51 | "--rank", 52 | type=int, 53 | default=0, 54 | ) 55 | parser.add_argument( 56 | "--in_filename", 57 | type=str, 58 | ) 59 | parser.add_argument( 60 | "--out_dir", 61 | type=str, 62 | ) 63 | parser.add_argument("--target_size", type=int, default=512) 64 | args = parser.parse_args() 65 | 66 | item_processor = ItemProcessor(target_size=args.target_size) 67 | 68 | with open(args.in_filename) as f: 69 | ori_contents = json.load(f) 70 | 71 | num = len(ori_contents) 72 | 73 | splits = args.splits 74 | rank = args.rank 75 | output_dir = args.out_dir 76 | save_dir = os.path.join(output_dir, "files") 77 | os.makedirs(save_dir, exist_ok=True) 78 | 79 | num_per_rank = math.ceil(num / splits) 80 | 81 | try: 82 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "r") as f: 83 | start_idx = int(f.read()) + 1 84 | print(f"resume from {start_idx}") 85 | except: 86 | start_idx = num_per_rank * rank 87 | print(f"start from {start_idx}") 88 | 89 | end_idx = min(num_per_rank * (rank + 1), len(ori_contents)) 90 | for i in range(start_idx, 
end_idx): 91 | if i % 10 == 0: 92 | print(f"{i}/{end_idx}") 93 | 94 | record = None 95 | pkl_path = os.path.join(save_dir, f"{i}.pkl") 96 | try: 97 | tokens, labels = item_processor.process_item(ori_contents[i], training_mode=True) 98 | new_item = {"token": tokens, "label": labels, "id": i} 99 | with open(pkl_path, "wb") as f: 100 | pickle.dump(new_item, f) 101 | 102 | record = {"file": pkl_path, "len": len(tokens), "id": i} 103 | 104 | except Exception as e: 105 | from traceback import format_exc 106 | 107 | print(f"item {i} error: \n{ori_contents[i]}") 108 | print(format_exc()) 109 | 110 | if record is not None: 111 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-record.jsonl"), "a") as f: 112 | record_str = json.dumps(record) + "\n" 113 | f.write(record_str) 114 | 115 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "w") as f: 116 | if i == end_idx - 1: 117 | f.write("finished") 118 | else: 119 | f.write(f"{i}") 120 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.0 2 | torchvision==0.18.0 3 | torchaudio==2.3.0 4 | pandas 5 | tensorboard 6 | fairscale 7 | sentencepiece 8 | gradio==4.19.0 9 | packaging 10 | transformers>=4.43.3 11 | pyyaml 12 | pathlib 13 | Ninja 14 | bitsandbytes 15 | httpx[socks] 16 | einops 17 | regex 18 | h5py 19 | accelerate 20 | pre-commit 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="xllmx", 8 | version="0.0.1", 9 | author="Alpha-VLLM", 10 | description="An Open-source Toolkit for LLM-centered Any2Any Generation", 11 | long_description=long_description, 12 | long_description_content_type="text/markdown", 13 | url="https://github.com/Alpha-VLLM/Lumina-mGPT", 14 | packages=["xllmx"], 15 | include_package_data=True, 16 | ) 17 | -------------------------------------------------------------------------------- /xllmx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/__init__.py -------------------------------------------------------------------------------- /xllmx/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/data/__init__.py -------------------------------------------------------------------------------- /xllmx/data/conversation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/data/conversation/__init__.py -------------------------------------------------------------------------------- /xllmx/data/conversation/template.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class ConversationBase: 5 | roles = ["Human", "Assistant"] 6 | 7 | def __init__(self, messages=None): 8 | self.messages = messages or [] 9 | 10 | def process(self): 11 | raise NotImplementedError 12 | 13 | def get_prompt(self): 
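        # process() is left to subclasses; it is expected to return a dict with at least
        # {"conv": <full conversation string>, "pieces": [{"data": ..., "predict": ...}, ...]},
        # which is how MMConvItemProcessor.add_speaker_and_signal in xllmx/data/item_processor.py consumes it.
        # get_prompt() therefore returns only the flattened conversation string.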
14 | return self.process()["conv"] 15 | 16 | def append_message(self, role, message): 17 | self.messages.append([role, message]) 18 | 19 | def copy(self): 20 | return ConversationBase( 21 | messages=[[x, y] for x, y in self.messages], 22 | ) 23 | 24 | def load_qas(self, qas: List[List[str]]): 25 | self.messages = [] 26 | for q, a in qas: 27 | self.append_message(self.roles[0], q) 28 | self.append_message(self.roles[1], a) 29 | -------------------------------------------------------------------------------- /xllmx/data/data_reader.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import logging 3 | import time 4 | from typing import Union 5 | 6 | from PIL import Image 7 | 8 | Image.MAX_IMAGE_PIXELS = None 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def read_general(path) -> Union[str, BytesIO]: 13 | if "s3://" in path: 14 | init_ceph_client_if_needed() 15 | file_bytes = BytesIO(client.get(path)) 16 | return file_bytes 17 | else: 18 | return path 19 | 20 | 21 | def init_ceph_client_if_needed(): 22 | global client 23 | if client is None: 24 | logger.info(f"initializing ceph client ...") 25 | st = time.time() 26 | from petrel_client.client import Client # noqa 27 | 28 | client = Client("/path/to/petreloss.conf") 29 | ed = time.time() 30 | logger.info(f"initialize client cost {ed - st:.2f} s") 31 | 32 | 33 | client = None 34 | -------------------------------------------------------------------------------- /xllmx/data/dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import logging 4 | import os 5 | from pathlib import Path 6 | import pickle 7 | from time import sleep 8 | import traceback 9 | import warnings 10 | 11 | import h5py 12 | import torch 13 | import torch.distributed as dist 14 | from torch.utils.data import Dataset 15 | import yaml 16 | 17 | from .item_processor import ItemProcessorBase 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class FinetuneConversationDataset(Dataset): 23 | def __init__(self, config_path, item_processor: ItemProcessorBase, cache_on_disk=False): 24 | 25 | self.item_processor = item_processor 26 | 27 | logger.info(f"read dataset config from {config_path}") 28 | with open(config_path, "r") as f: 29 | self.config = yaml.load(f, Loader=yaml.FullLoader) 30 | logger.info("DATASET CONFIG:") 31 | logger.info(self.config) 32 | 33 | self.cache_on_disk = cache_on_disk 34 | if self.cache_on_disk: 35 | cache_dir = self._get_cache_dir(config_path) 36 | if dist.get_rank() == 0: 37 | self._collect_annotations_and_save_to_cache(cache_dir) 38 | dist.barrier() 39 | self.meta_collection, self.annotations_collection = self._load_annotations_from_cache(cache_dir) 40 | else: 41 | cache_dir = None 42 | self.meta_collection, self.annotations_collection = self._collect_annotations() 43 | 44 | def __len__(self): 45 | return sum([_["len"] for _ in self.meta_collection]) 46 | 47 | def _collect_annotations(self): 48 | meta_collection = [] 49 | annotations_collection = [] 50 | 51 | for meta in self.config["META"]: 52 | meta, annotations = self._load_meta(meta) 53 | meta_collection.append(meta) 54 | annotations_collection.append(annotations) 55 | 56 | return meta_collection, annotations_collection 57 | 58 | def _load_meta(self, meta): 59 | if "type" not in meta: 60 | meta["type"] = "default" 61 | 62 | meta_path, meta_type = meta["path"], meta["type"] 63 | meta_ext = os.path.splitext(meta_path)[-1] 64 | if meta_ext == 
".json": 65 | with open(meta_path) as f: 66 | annotations = json.load(f) 67 | elif meta_ext == ".jsonl": 68 | annotations = [] 69 | with open(meta_path) as f: 70 | for i, line in enumerate(f): 71 | try: 72 | annotations.append(json.loads(line)) 73 | except json.decoder.JSONDecodeError as e: 74 | logger.error(f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}") 75 | raise e 76 | elif meta_ext == ".pkl": 77 | with open(meta_path, "rb") as f: 78 | annotations = pickle.load(f) 79 | assert isinstance(annotations, list) 80 | elif meta_ext == ".pth": 81 | annotations = torch.load(meta_path) 82 | assert isinstance(annotations, list) 83 | else: 84 | raise NotImplementedError( 85 | f'Unknown meta file extension: "{meta_ext}". ' 86 | f"Currently, .json, .jsonl are supported. " 87 | "If you are using a supported format, please set the file extension so that the proper parsing " 88 | "routine can be called." 89 | ) 90 | logger.info(f"{meta_path}, type{meta_type}: len {len(annotations)}") 91 | 92 | meta["len"] = len(annotations) 93 | 94 | meta["item_len_list"] = [self.item_processor.predict_item_token_length(_) for _ in annotations] 95 | 96 | return meta, annotations 97 | 98 | def _collect_annotations_and_save_to_cache(self, cache_dir): 99 | if (Path(cache_dir) / "data.h5").exists() and (Path(cache_dir) / "ready").exists(): 100 | # off-the-shelf annotation cache exists 101 | warnings.warn( 102 | f"Use existing h5 data cache: {Path(cache_dir)}\n" 103 | f"Note: if the actual data defined by the data config has changed since your last run, " 104 | f"please delete the cache manually and re-run this experiment, or the data actually used " 105 | f"will not be updated" 106 | ) 107 | return 108 | 109 | Path(cache_dir).mkdir(parents=True, exist_ok=True) 110 | meta_collection, annotations_collection = self._collect_annotations() 111 | 112 | # when cache on disk, rank0 saves items to an h5 file 113 | logger.info(f"start to build data cache to: {Path(cache_dir)}") 114 | with h5py.File(Path(cache_dir) / "data.h5", "w") as file: 115 | dt = h5py.vlen_dtype(str) 116 | for i, annotations in enumerate(annotations_collection): 117 | serialized_ann = [json.dumps(_) for _ in annotations] 118 | h5_ann = file.create_dataset(f"ann{i}", (len(serialized_ann),), dtype=dt) 119 | h5_ann[:] = serialized_ann 120 | 121 | file.create_dataset("meta_collection", data=json.dumps(meta_collection)) 122 | with open(Path(cache_dir) / "ready", "w") as f: 123 | f.write("ready") 124 | logger.info(f"data cache built") 125 | 126 | @staticmethod 127 | def _get_cache_dir(config_path): 128 | config_identifier = config_path 129 | disallowed_chars = ["/", "\\", ".", "?", "!"] 130 | for _ in disallowed_chars: 131 | config_identifier = config_identifier.replace(_, "-") 132 | cache_dir = f"./xllmx_data_cache/{config_identifier}" 133 | return cache_dir 134 | 135 | @staticmethod 136 | def _load_annotations_from_cache(cache_dir): 137 | while not (Path(cache_dir) / "ready").exists(): 138 | # cache has not yet been completed by rank 0 139 | assert dist.get_rank() != 0 140 | sleep(1) 141 | cache_file = h5py.File(Path(cache_dir) / "data.h5", "r") 142 | meta_collection = json.loads(cache_file["meta_collection"].asstr()[()]) 143 | annotations_collection = [cache_file[f"ann{i}"] for i in range(len(meta_collection))] 144 | return meta_collection, annotations_collection 145 | 146 | def get_item_func(self, meta_idx, idx_in_meta): 147 | data_item = self.annotations_collection[meta_idx][idx_in_meta] 148 | if self.cache_on_disk: 149 | data_item = 
json.loads(data_item) 150 | else: 151 | data_item = copy.deepcopy(data_item) 152 | 153 | return self.item_processor.process_item(data_item, training_mode=True) 154 | 155 | def tie_index_to_meta(self, idx: int): 156 | # Initialize the starting index 157 | start_idx = 0 158 | 159 | # Iterate through the list of dictionaries 160 | for i, meta in enumerate(self.meta_collection): 161 | # Calculate the ending index for the current collection 162 | end_idx = start_idx + meta["len"] 163 | 164 | # Check if the given index falls within the current collection 165 | if start_idx <= idx < end_idx: 166 | # Calculate the new index within the current collection 167 | new_index = idx - start_idx 168 | return i, new_index 169 | 170 | # Update the starting index for the next collection 171 | start_idx = end_idx 172 | 173 | # If the index is out of range of all collections, raise an error 174 | raise IndexError("Index out of range") 175 | 176 | def __getitem__(self, index): 177 | meta_idx, idx_in_meta = self.tie_index_to_meta(index) 178 | 179 | try: 180 | return self.get_item_func(meta_idx, idx_in_meta) 181 | except Exception as e: 182 | logger.info( 183 | f"Item {index} errored, annotation:\n" 184 | f"{self.annotations_collection[meta_idx][idx_in_meta]}\n" 185 | f"Error:\n" 186 | f"{traceback.format_exc()}" 187 | ) 188 | if idx_in_meta != 0: 189 | return self[index - 1] 190 | else: 191 | return self[index + self.meta_collection[meta_idx]["len"] - 1] 192 | -------------------------------------------------------------------------------- /xllmx/data/item_processor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import copy 3 | import logging 4 | from typing import Any, Callable, Dict, List, Tuple, Union 5 | 6 | from xllmx.model.tokenizer import Tokenizer 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class LabelAllZeroError(Exception): 12 | def __init__(self, message=None): 13 | self.message = message 14 | 15 | def __str__(self): 16 | return f"LabelAllZeroError: {self.message}" 17 | 18 | 19 | class ItemProcessorBase(ABC): 20 | @abstractmethod 21 | def process_item(self, data_item: dict, training_mode=False) -> Tuple[List, List]: 22 | raise NotImplementedError 23 | 24 | def predict_item_token_length(self, data_item: dict) -> int: 25 | """ 26 | estimate the token length of the data item for gathering items of similar lengths into a batch 27 | """ 28 | return 1 29 | 30 | 31 | class MMConvItemProcessor(ItemProcessorBase): 32 | def __init__( 33 | self, 34 | transform: Dict[str, Callable[[Any], Dict]], 35 | media_symbols: List[str], 36 | tokenizer: str | Tokenizer, 37 | conv_template, 38 | ): 39 | self.transform = transform 40 | logger.info(f"transform:\n{self.transform}") 41 | 42 | self.media_symbols = media_symbols 43 | logger.info(f"media_symbols:\n{self.media_symbols}") 44 | 45 | if isinstance(tokenizer, str): 46 | self.tokenizer = Tokenizer(model_path=tokenizer) 47 | else: 48 | self.tokenizer = copy.deepcopy(tokenizer) 49 | 50 | # todo should not already exist 51 | self.tokenizer.tokenizer.add_tokens(media_symbols) 52 | self.d_media_symbol2token = {} 53 | self.d_media_token2symbol = {} 54 | for media_symbol in media_symbols: 55 | tokenized_symbol = self.tokenizer.encode(media_symbol, bos=False, eos=False) 56 | assert len(tokenized_symbol) == 1 57 | self.d_media_symbol2token[media_symbol] = tokenized_symbol[0] 58 | self.d_media_token2symbol[tokenized_symbol[0]] = media_symbol 59 | 60 | # implicit_at_beginning means media 
without explict location specification are arranged right after bos token 61 | # if false, then these medias are arranged at the beginning of the first question 62 | self.implicit_at_beginning = False 63 | self.conv_template = conv_template 64 | 65 | def collect_and_process_media(self, data_item): 66 | """ 67 | this function receives a raw piece of data (e.g. read from `.json` data file), 68 | and returns d_media, containing the prepared media readily usable by model 69 | YOU MAY OVERRIDE THIS FUNCTION TO SUPPORT COMPLEX LOADING OF VARIOUS FORMS OF DATA 70 | """ 71 | d_media = {} 72 | for media_symbol in self.media_symbols: 73 | if media_symbol in data_item: 74 | l_media = data_item[media_symbol] # a list of media paths 75 | elif media_symbol.lstrip("<|").rstrip("|>") in data_item: 76 | l_media = data_item[media_symbol.lstrip("<|").rstrip("|>")] 77 | else: 78 | l_media = [] 79 | if not isinstance(l_media, list): # data with only one media, in format {"image": image_name, ...} 80 | l_media = [l_media] 81 | 82 | d_media[media_symbol] = [] 83 | for media in l_media: 84 | media = self.transform[media_symbol](media) 85 | assert isinstance(media, Dict) 86 | media["type"] = media_symbol 87 | d_media[media_symbol].append(media) 88 | 89 | return d_media 90 | 91 | def replace_media_token_with_media( 92 | self, tokens: List[int], labels: Union[List[int], None], d_media: Dict[str, List] 93 | ): 94 | d_media_counter = {key: 0 for key in d_media} 95 | for i, t in enumerate(tokens): 96 | if t in self.d_media_token2symbol: 97 | media_symbol = self.d_media_token2symbol[t] 98 | media = d_media[media_symbol][d_media_counter[media_symbol]] 99 | d_media_counter[media_symbol] += 1 100 | tokens[i] = media 101 | media["to_predict"] = labels[i] > 0 102 | 103 | assert all([d_media_counter[key] == len(d_media[key]) for key in d_media]) 104 | 105 | if labels is not None: 106 | return tokens, labels 107 | else: 108 | return tokens 109 | 110 | @staticmethod 111 | def insert_implicit_media_symbol_in_q1(conv_list: List[Dict], d_media: Dict): 112 | """ 113 | Add the media tokens to the beginning of the first instruction from 114 | human. This logic may be more reasonable. However, it is incompatible 115 | with old-version Accessory models, which are trained with image tokens 116 | inserted directly behind the first token (). 117 | :param conv_list: [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}, ...] 118 | :param d_media: a dict of media for all media types 119 | """ 120 | conv_list = copy.deepcopy(conv_list) 121 | 122 | for media_symbol, l_media in d_media.items(): 123 | media_symbol_count = "".join([_["value"] for _ in conv_list if _["value"] is not None]).count(media_symbol) 124 | if media_symbol_count > 0: 125 | assert media_symbol_count == len( 126 | l_media 127 | ), f"{media_symbol_count} {media_symbol} exists in text, but {len(l_media)} actual media are given" 128 | else: 129 | conv_list[0]["value"] = (media_symbol + " ") * len(l_media) + conv_list[0]["value"] 130 | 131 | return conv_list 132 | 133 | @staticmethod 134 | def insert_implicit_media_symbol_at_beginning(conv: str, d_media: Dict): 135 | """ 136 | Legacy versions of LLaMA2-Accessory handled media in a non-interleaved 137 | manner, where image tokens are inserted directly behind the first token, 138 | namely . To support interleaved media comprehension and generation, 139 | Accessory now supports the explicit specification of media occurrence, 140 | which is achieved by adding media symbols, e.g. , within the 141 | conversations. 
On the other hand, for media without explicit 142 | specification, this function realizes the legacy behavior to arrange 143 | them at the beginning of the conversation. 144 | :param conv: conversation 145 | :param d_media: a dict of media for all media types, for determining how 146 | many media tokens need to be inserted 147 | """ 148 | conv = copy.deepcopy(conv) 149 | 150 | for media_symbol, l_media in d_media.items(): 151 | media_symbol_count = conv.count(media_symbol) 152 | if media_symbol_count > 0: 153 | assert media_symbol_count == len( 154 | l_media 155 | ), f"{media_symbol_count} {media_symbol} exists in text, but {len(l_media)} actual media are given" 156 | else: 157 | conv = (media_symbol + " ") * len(l_media) + conv 158 | 159 | return conv 160 | 161 | def preprocess_item(self, data_item): 162 | return data_item 163 | 164 | def add_speaker_and_signal(self, source: List): 165 | """ 166 | Given source instruction and response pieces, return the text containing the complete conversation, 167 | and the list of values that the model should learn to predict during training 168 | :param source: [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}, ...] 169 | :return: `conversation`: string containing the complete conversation; 170 | `to_predict_list`: the list of values that the model should learn to predict during training 171 | """ 172 | conv = self.conv_template() 173 | 174 | for i, sentence in enumerate(source): 175 | from_str = sentence["from"] 176 | if i % 2 == 0: 177 | assert from_str.lower() in ["human"] 178 | role = conv.roles[0] 179 | elif i % 2 == 1: 180 | assert from_str.lower() in ["gpt", "assistant"] 181 | role = conv.roles[1] 182 | else: 183 | raise ValueError(f"unknown dialog role: {from_str.lower()}") 184 | 185 | value = sentence["value"] 186 | 187 | conv.append_message(role, value) 188 | 189 | processed = conv.process() 190 | conversation, pieces = processed["conv"], processed["pieces"] 191 | 192 | return conversation, pieces 193 | 194 | def process_item(self, data_item: dict, training_mode=False) -> Tuple[List, List]: 195 | data_item = self.preprocess_item(data_item) 196 | 197 | d_media = self.collect_and_process_media(data_item) 198 | 199 | source = data_item["conversations"] 200 | 201 | # implicit_at_beginning means media without explict location specification are arranged right after bos token 202 | # if false, then these medias are arranged at the beginning of the first question 203 | if not self.implicit_at_beginning: 204 | source = self.insert_implicit_media_symbol_in_q1(source, d_media) 205 | 206 | conversation, pieces = self.add_speaker_and_signal(source) 207 | 208 | if self.implicit_at_beginning: 209 | conversation = self.insert_implicit_media_symbol_at_beginning(conversation, d_media) 210 | 211 | # dialog does not need eos 212 | tokens = self.tokenizer.encode(conversation, bos=True, eos=False) 213 | labels = [-100 for _ in tokens] 214 | 215 | # check special token num as expected 216 | for media_symbol, l_media in d_media.items(): 217 | media_token = self.d_media_symbol2token[media_symbol] 218 | media_token_count = tokens.count(media_token) 219 | assert media_token_count == len(l_media), ( 220 | f"{media_token_count} {media_token} (for {media_symbol}) exists in tokenized conversation, " 221 | f"but {len(l_media)} actual media are given" 222 | ) 223 | 224 | check_pos = 0 225 | for i, p in enumerate(pieces): 226 | if i == 0: 227 | tokenized_value = self.tokenizer.encode(p["data"], bos=True, eos=False) 228 | else: 229 | tokenized_value = 
self.tokenizer.encode_wo_prefix_space(p["data"]) 230 | 231 | assert ( 232 | tokens[check_pos : check_pos + len(tokenized_value)] == tokenized_value 233 | ), "inconsistent complete conversation and corresponding piece after tokenization" 234 | 235 | if p["predict"]: 236 | labels[check_pos : check_pos + len(tokenized_value)] = tokenized_value 237 | 238 | check_pos = check_pos + len(tokenized_value) 239 | 240 | if training_mode and all([_ <= 0 for _ in labels]): # nothing to predict 241 | raise LabelAllZeroError() 242 | 243 | # labels will be processed later by the model 244 | tokens, labels = self.replace_media_token_with_media(tokens, labels, d_media) 245 | 246 | assert len(tokens) == len(labels) 247 | 248 | if training_mode: 249 | return tokens, labels 250 | else: 251 | return tokens 252 | 253 | def predict_item_token_length(self, data_item: dict) -> int: 254 | """ 255 | estimate the length of each item 256 | """ 257 | 258 | if "conversations" in data_item: 259 | return sum([len(_["value"]) for _ in data_item["conversations"]]) 260 | else: 261 | return 1 262 | -------------------------------------------------------------------------------- /xllmx/data/sampler.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import logging 3 | from typing import Iterator, List, Optional 4 | 5 | import numpy as np 6 | from torch.utils.data import Sampler 7 | 8 | from xllmx.data.dataset import FinetuneConversationDataset 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | # todo too slow to be used 14 | def mild_shuffle(items: List, shuffle_factor, engine: np.random.Generator): 15 | """ 16 | Perform a mild shuffle on the list of items. 17 | 18 | Args: 19 | engine: random engine 20 | items (list): The list of items to shuffle. 21 | shuffle_factor (float): max swap range is computed as len(item) * shuffle_factor. 22 | 23 | Returns: 24 | list: The mildly shuffled list. 25 | """ 26 | 27 | n = len(items) 28 | swap_range = int(shuffle_factor * n) 29 | shuffled_items = [None for _ in items] 30 | cache = list(range(swap_range)) 31 | for i in range(n): 32 | if i + swap_range < n: 33 | cache.append(i + swap_range) 34 | if len(cache) == 0 or cache[0] != i: # already swapped 35 | assert shuffled_items[i] is not None 36 | continue 37 | else: 38 | cache = cache[1:] 39 | if len(cache) == 0: 40 | shuffled_items[i] = items[i] 41 | else: 42 | cache_idx = engine.integers(low=0, high=len(cache)) 43 | j = cache[cache_idx] 44 | del cache[cache_idx] 45 | shuffled_items[i], shuffled_items[j] = items[j], items[i] 46 | 47 | return shuffled_items 48 | 49 | 50 | class FinetuneDistSampler(Sampler): 51 | def __init__( 52 | self, 53 | dataset: FinetuneConversationDataset, 54 | num_replicas: Optional[int] = None, 55 | rank: Optional[int] = None, 56 | shuffle: bool = True, 57 | seed: int = 0, 58 | batch_size=None, 59 | acc_grad=1, 60 | length_clustering=True, 61 | allow_mixed_task_among_acc=False, 62 | ): 63 | """ 64 | Distributed Sampler ensuring data in a batch are of the same type (e.g. 
text, image-text) 65 | :param dataset: 66 | :param num_replicas: 67 | :param rank: 68 | :param shuffle: 69 | :param seed: 70 | :param batch_size: 71 | :param acc_grad: 72 | :param length_clustering: 73 | :param allow_mixed_task_among_acc: 74 | """ 75 | # super().__init__() 76 | 77 | if num_replicas is None or rank is None or rank >= num_replicas or rank < 0: 78 | raise ValueError(f"Invalid num_replicas ({num_replicas}) or rank ({rank})") 79 | assert batch_size is not None 80 | 81 | self.dataset = dataset 82 | self.num_replicas = num_replicas 83 | self.rank = rank 84 | self.shuffle = shuffle 85 | self.seed = seed 86 | self.batch_size = batch_size 87 | self.acc_grad = acc_grad 88 | self.length_clustering = length_clustering 89 | self.allow_mixed_task_among_acc = allow_mixed_task_among_acc 90 | 91 | self.epoch = 0 92 | self.start_iter = 0 93 | 94 | global_bsz_acc = batch_size * num_replicas * acc_grad 95 | 96 | group_len = defaultdict(int) 97 | for i, meta in enumerate(dataset.meta_collection): 98 | group_len[meta["type"]] += int(meta["len"] * meta.get("ratio", 1.0)) 99 | 100 | group_len = {key: val // global_bsz_acc * global_bsz_acc for key, val in group_len.items()} 101 | 102 | self.total_size = sum(list(group_len.values())) 103 | assert self.total_size % num_replicas == 0 104 | self.num_samples = self.total_size // num_replicas 105 | 106 | def __iter__(self) -> Iterator: 107 | global_batch_size = self.batch_size * self.num_replicas 108 | global_bsz_acc = self.batch_size * self.num_replicas * self.acc_grad 109 | rng = np.random.default_rng(self.seed + self.epoch) 110 | 111 | group_indices_and_len = defaultdict(list) 112 | 113 | # Initialize the starting index 114 | start_idx = 0 115 | 116 | # Iterate through the list of dictionaries 117 | for i, meta in enumerate(self.dataset.meta_collection): 118 | # Calculate the ending index for the current collection 119 | end_idx = start_idx + meta["len"] 120 | indices = list(range(start_idx, end_idx)) 121 | assert len(indices) == len(meta["item_len_list"]) 122 | indices_and_len = [[idx, length] for idx, length in zip(indices, meta["item_len_list"])] 123 | if meta.get("ratio", 1.0) != 1.0: 124 | indices_and_len = list(rng.choice(indices_and_len, int(meta["len"] * meta["ratio"]), replace=False)) 125 | logger.info(f"meta{i}: sample (ratio = {meta['ratio']}) {len(indices_and_len)} items") 126 | group_indices_and_len[meta["type"]].extend(indices_and_len) 127 | 128 | # Update the starting index for the next collection 129 | start_idx = end_idx 130 | 131 | for group_name, indices_and_len in group_indices_and_len.items(): 132 | group_indices_and_len[group_name] = indices_and_len[ 133 | : len(indices_and_len) // global_bsz_acc * global_bsz_acc 134 | ] 135 | 136 | if self.shuffle: 137 | group_indices = {} 138 | if self.length_clustering: 139 | for group_name, indices_and_len in group_indices_and_len.items(): 140 | indices_and_len.sort(key=lambda x: x[1]) 141 | group_indices[group_name] = [_[0] for _ in indices_and_len] 142 | 143 | # option1: shuffle among neighboring items 144 | for group_name, indices in group_indices.items(): 145 | result = [] 146 | for pos in range(0, len(indices), global_batch_size * 500): 147 | sublist = indices[pos : pos + global_batch_size * 500] 148 | rng.shuffle(sublist) 149 | result.extend(sublist) 150 | group_indices[group_name] = result 151 | # option2: mild shuffle 152 | # group_indices[group_name] = mild_shuffle(indices, 0.1, rng) 153 | # option3: do nothing 154 | # pass 155 | else: 156 | for group_name, indices_and_len in 
group_indices_and_len.items(): 157 | rng.shuffle(indices_and_len) 158 | group_indices[group_name] = [_[0] for _ in indices_and_len] 159 | 160 | del group_indices_and_len 161 | 162 | if self.allow_mixed_task_among_acc: 163 | global_batched_indices = [ 164 | indices[i : i + global_batch_size] 165 | for group_name, indices in group_indices.items() 166 | for i in range(0, len(indices), global_batch_size) 167 | ] 168 | else: 169 | global_batched_indices = [] 170 | for group_name, indices in group_indices.items(): 171 | group_batched_indices = [ 172 | indices[i : i + global_batch_size] for i in range(0, len(indices), global_batch_size) 173 | ] 174 | rng.shuffle(group_batched_indices) 175 | group_batched_indices = [ 176 | sum(group_batched_indices[i : i + self.acc_grad], start=[]) 177 | for i in range(0, len(group_batched_indices), self.acc_grad) 178 | ] 179 | global_batched_indices.extend(group_batched_indices) 180 | rng.shuffle(global_batched_indices) 181 | indices = [_ for batch_indices in global_batched_indices for _ in batch_indices] 182 | else: 183 | raise NotImplementedError() 184 | 185 | assert len(indices) == self.total_size 186 | 187 | own_indices = [] 188 | for start_pos in range(self.rank * self.batch_size, len(indices), self.num_replicas * self.batch_size): 189 | own_indices += indices[start_pos : start_pos + self.batch_size] 190 | # subsample 191 | assert len(own_indices) == self.num_samples 192 | 193 | if self.start_iter * self.batch_size > len(own_indices): 194 | own_indices = [] 195 | else: 196 | own_indices = own_indices[self.start_iter * self.batch_size :] 197 | 198 | return iter(own_indices) 199 | 200 | def __len__(self) -> int: 201 | return self.num_samples 202 | 203 | def set_epoch(self, epoch: int, start_iter: int = 0) -> None: 204 | r""" 205 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 206 | use a different random ordering for each epoch. Otherwise, the next iteration of this 207 | sampler will yield the same ordering. 208 | 209 | Args: 210 | epoch (int): Epoch number. 211 | start_iter (int): start iter number. 212 | """ 213 | self.epoch = epoch 214 | self.start_iter = start_iter 215 | -------------------------------------------------------------------------------- /xllmx/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/model/__init__.py -------------------------------------------------------------------------------- /xllmx/model/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 
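                Example (a minimal usage sketch of this fallback implementation; the tensor shapes are illustrative only):
                    >>> norm = RMSNorm(dim=4096, eps=1e-6)
                    >>> x = torch.randn(2, 16, 4096)
                    >>> y = norm(x)  # same shape as x, normalized over the last dimension and scaled by weight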
23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /xllmx/model/tokenizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | from typing import List, Optional 5 | 6 | from sentencepiece import SentencePieceProcessor 7 | from transformers import AutoTokenizer 8 | 9 | __all__ = ["Tokenizer", "probe_tokenizer_path_from_pretrained"] 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class Tokenizer: 16 | def __init__(self, model_path: str): 17 | """ 18 | Create a tokenizer, with inner implementation either spm or HF transformers tokenzier 19 | :param model_path: 20 | - when using spm tokenizer, should be path to a sentencepiece model with suffix `.model` 21 | - when using huggingface transformers tokenizer, should be an HF model repo or a local directory, 22 | containing tokenizer.json and tokenizer_config.json. 23 | """ 24 | if model_path.endswith(".model"): # spm tokenizer 25 | self.tokenizer_type = "spm" 26 | # reload tokenizer 27 | assert os.path.isfile(model_path), model_path 28 | self.tokenizer = SentencePieceProcessor(model_file=model_path) 29 | logger.info(f"Reloaded SentencePiece model from {model_path}") 30 | 31 | # BOS / EOS token IDs 32 | self.bos_id: int = self.tokenizer.bos_id() 33 | self.eos_id: int = self.tokenizer.eos_id() 34 | assert self.tokenizer.vocab_size() == self.tokenizer.get_piece_size() 35 | else: 36 | self.tokenizer_type = "transformers" 37 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 38 | logger.info(f"load HF transformers tokenizer from {model_path}") 39 | # BOS / EOS token IDs 40 | self.bos_id: int = self.tokenizer.bos_token_id 41 | if self.bos_id is None: 42 | self.bos_id = self.tokenizer.eos_token_id 43 | self.eos_id: int = self.tokenizer.eos_token_id 44 | assert self.eos_id is not None 45 | 46 | self._probe_tokenizer_style() 47 | 48 | logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") 49 | 50 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 51 | assert type(s) is str 52 | if self.tokenizer_type == "transformers": 53 | t = self.tokenizer.encode(s, truncation=False, add_special_tokens=False) 54 | else: 55 | t = self.tokenizer.encode(s) 56 | if bos: 57 | t = [self.bos_id] + t 58 | if eos: 59 | t = t + [self.eos_id] 60 | return t 61 | 62 | def encode_segment(self, s: str): 63 | s = s.lstrip(" ") 64 | if self.need_space_before_segment: 65 | return self.encode(" " + s, bos=False, eos=False) 66 | else: 67 | return self.encode(s, bos=False, eos=False) 68 | 69 | def encode_wo_prefix_space(self, s: str): 70 | if self.need_space_before_segment: 71 | return self.encode(s, bos=False, eos=False) 72 | else: 73 | # prefix chars that, when preceding 
other strings without seperator in between, 74 | # are relatively more likely to be tokenized independently rather than getting 75 | # merged into the following strings. 76 | l_prefix = ["@", "\n", "\\", "=", ">", "`"] 77 | for prefix in l_prefix: 78 | prefix_tokens = self.encode(prefix, bos=False, eos=False) 79 | cat_tokens = self.encode(prefix + s, bos=False, eos=False) 80 | if cat_tokens[: len(prefix_tokens)] == prefix_tokens: 81 | return cat_tokens[len(prefix_tokens) :] 82 | 83 | raise NotImplementedError( 84 | f"All prefixes are merged into {s} during tokenization," 85 | f"This is wierd behavior, please open an issue to report this problem", 86 | ) 87 | 88 | def _probe_tokenizer_style(self): 89 | """ 90 | Given a sentence, e.g. "Hi my darling", some tokenizers (e.g. LLaMA's) will pose the following behavior: 91 | >>> # leading characters will be treated as if there were an " " in the beginning 92 | >>> tokenizer.encode("Hi my darling") == tokenizer.encode("Hi") + tokenizer.encode("my darling") 93 | >>> # leading space " " is redundant and should not be added 94 | >>> tokenizer.encode("Hi my darling") != tokenizer.encode("Hi") + tokenizer.encode(" my darling") 95 | However, some others (e.g. InternLM's) will behave differently: 96 | >>> # leading space " " has to be explicitly added 97 | >>> tokenizer.encode("Hi my darling") == tokenizer.encode("Hi") + tokenizer.encode(" my darling") 98 | Knowing which style the tokenizer takes is necessary when tokenzing a segment cut from the complete 99 | text, so that the result is the same as the corresponding part in the tokenized original text. 100 | """ 101 | sentence1 = self.encode("Hi my darling", bos=False, eos=False) 102 | sentence2 = self.encode("my darling", bos=False, eos=False) 103 | if sentence1[-len(sentence2) :] == sentence2: 104 | self.need_space_before_segment = False 105 | else: 106 | sentence3 = self.encode(" my darling", bos=False, eos=False) 107 | assert sentence1[-len(sentence3) :] == sentence3 108 | self.need_space_before_segment = True 109 | 110 | def decode(self, t: List[int]) -> str: 111 | return self.tokenizer.decode(t) 112 | 113 | def save(self, save_dir: str): 114 | if self.tokenizer_type == "transformers": 115 | self.tokenizer.save_pretrained(save_dir) 116 | else: 117 | with open(Path(save_dir) / "tokenizer.model", "wb") as f: 118 | f.write(self.tokenizer.serialized_model_proto()) 119 | 120 | @property 121 | def n_words(self): 122 | if self.tokenizer_type == "spm": 123 | return self.tokenizer.vocab_size() 124 | elif self.tokenizer_type == "transformers": 125 | return len(self.tokenizer) 126 | else: 127 | raise RuntimeError 128 | 129 | 130 | def probe_tokenizer_path_from_pretrained(pretrained_path: str): 131 | tokenizer_path = None 132 | 133 | # try find spm-style tokenizer 134 | logger.info(f"trying to find sentencepiece-style tokenizer at {Path(pretrained_path) / 'tokenizer.model'}") 135 | if (Path(pretrained_path) / "tokenizer.model").exists(): 136 | logger.info(f"Found {Path(pretrained_path) / 'tokenizer.model'}, use it.") 137 | tokenizer_path = str(Path(pretrained_path) / "tokenizer.model") 138 | else: 139 | logger.info("Not Found") 140 | 141 | # then try huggingface style 142 | if tokenizer_path is None: 143 | logger.info( 144 | f"trying to find huggingface-style tokenizer at " 145 | f"{Path(pretrained_path) / '(tokenizer.json, tokenizer_config.json)'}" 146 | ) 147 | if (Path(pretrained_path) / "tokenizer.json").exists() and ( 148 | Path(pretrained_path) / "tokenizer_config.json" 149 | ).exists(): 150 | 
logger.info(f"Found {Path(pretrained_path) / '(tokenizer.json, tokenizer_config.json)'}, use them.") 151 | tokenizer_path = pretrained_path 152 | else: 153 | logger.info("Not Found") 154 | if tokenizer_path is None: 155 | logger.info("No usable tokenizer found") 156 | return tokenizer_path 157 | -------------------------------------------------------------------------------- /xllmx/solvers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/solvers/__init__.py -------------------------------------------------------------------------------- /xllmx/solvers/finetune/__init__.py: -------------------------------------------------------------------------------- 1 | from .finetune import FinetuneSolverBase 2 | -------------------------------------------------------------------------------- /xllmx/util/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ckpt, dist 2 | -------------------------------------------------------------------------------- /xllmx/util/ckpt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shutil 5 | from typing import Dict, Optional 6 | 7 | import torch 8 | from torch import distributed as dist 9 | from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, StateDictType 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def split_ckpt_str_into_epoch_iter(ckpt_str: str): 15 | # divide ckpt directory names into epoch and iter parts 16 | parts = ckpt_str.split("-") 17 | epoch = int(parts[0].replace("epoch", "")) 18 | if len(parts) == 2: 19 | iter_part = int(parts[1].replace("iter", "")) 20 | else: 21 | iter_part = None 22 | return epoch, iter_part 23 | 24 | 25 | def remove_early_ckpts(out_dir, max_keep=2): 26 | 27 | if max_keep <= 0: 28 | return 29 | 30 | def ckpt_sort_key(s): 31 | # divide ckpt directory names into epoch and iter parts 32 | epoch, iteration = split_ckpt_str_into_epoch_iter(s) 33 | if iteration is None: 34 | iteration = float("inf") 35 | return epoch, iteration 36 | 37 | existing_checkpoints = [_ for _ in os.listdir(out_dir) if "epoch" in _] 38 | existing_checkpoints = sorted(existing_checkpoints, key=ckpt_sort_key, reverse=True) 39 | 40 | for dir_to_remove in existing_checkpoints[max_keep:]: 41 | dir_to_remove = os.path.join(out_dir, dir_to_remove) 42 | shutil.rmtree(dir_to_remove) 43 | logger.info(f"Deleted {dir_to_remove}") 44 | 45 | 46 | def save( 47 | output_dir, 48 | is_main_process, 49 | model: FSDP, 50 | optimizer: Optional[torch.optim.Optimizer] = None, 51 | tokenizer=None, 52 | args=None, 53 | epoch=None, 54 | iteration=None, 55 | additional_rank_common: Optional[Dict] = None, 56 | additional_rank_specific: Optional[Dict] = None, 57 | max_keep=2, 58 | ): 59 | save_name = f"epoch{epoch}" 60 | if iteration is not None: 61 | save_name += f"-iter{iteration}" 62 | save_dir = os.path.join(output_dir, save_name) 63 | 64 | os.makedirs(save_dir, exist_ok=True) 65 | 66 | # save model 67 | with FSDP.state_dict_type( 68 | model, 69 | StateDictType.FULL_STATE_DICT, 70 | FullStateDictConfig(rank0_only=True, offload_to_cpu=True), 71 | ): 72 | # run saving in separate functions to save memory 73 | def _save_model(): 74 | save_dtype = { 75 | "fp16": torch.float16, 76 | "bf16": torch.bfloat16, 77 | "tf32": torch.float, 78 | }[ 79 
| args.precision 80 | ] # todo make saving precision optional 81 | if getattr(args, "only_save_trainable", False): 82 | model_trainable_params = model.get_trainable_params() 83 | model_trainable_params = [ 84 | ".".join([_ for _ in key.split(".") if not _.startswith("_")]) 85 | for key in model_trainable_params.keys() 86 | ] 87 | consolidated_model_state_dict = { 88 | key: val.to(save_dtype) for key, val in model.state_dict().items() if key in model_trainable_params 89 | } 90 | else: 91 | consolidated_model_state_dict = {key: val.to(save_dtype) for key, val in model.state_dict().items()} 92 | 93 | if is_main_process: 94 | model.save_pretrained(save_dir, state_dict=consolidated_model_state_dict) 95 | 96 | _save_model() 97 | logger.info("model saved") 98 | 99 | # save optimizer 100 | if optimizer is not None: 101 | with FSDP.state_dict_type( 102 | model, 103 | StateDictType.LOCAL_STATE_DICT, 104 | ): 105 | opt_path = os.path.join( 106 | save_dir, 107 | f"optimizer.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth", 108 | ) 109 | torch.save(optimizer.state_dict(), opt_path) 110 | logger.info("optimizer saved") 111 | else: 112 | logger.info("optimizer is None, skip saving") 113 | 114 | if additional_rank_specific is not None: 115 | torch.save( 116 | additional_rank_specific, 117 | os.path.join(save_dir, f"additional.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth"), 118 | ) 119 | logger.info(f"additional_rank_specific {list(additional_rank_specific.keys())} saved") 120 | 121 | if not is_main_process: 122 | dist.barrier() 123 | return 124 | 125 | # =========The followings are for main process only========= 126 | if tokenizer is not None: 127 | tokenizer.save(save_dir) 128 | logger.info("tokenizer saved") 129 | else: 130 | logger.info("tokenizer is None, skip saving") 131 | 132 | if args is not None: 133 | with open(os.path.join(save_dir, "args.json"), "w") as f: 134 | json.dump(vars(args), f, indent=2) 135 | logger.info("args saved") 136 | else: 137 | logger.info("args is None, skip saving") 138 | 139 | if additional_rank_common is not None: 140 | torch.save(additional_rank_common, os.path.join(save_dir, "additional_rank_common.pth")) 141 | logger.info(f"additional_resources {list(additional_rank_common.keys())} saved") 142 | 143 | remove_early_ckpts(output_dir, max_keep=max_keep) 144 | 145 | dist.barrier() 146 | return 147 | -------------------------------------------------------------------------------- /xllmx/util/dist.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import socket 5 | import subprocess 6 | import time 7 | from types import SimpleNamespace 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | from xllmx.util.misc import random_seed 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def find_free_port(start_port: int, end_port: int): 18 | """ 19 | Find a free port within the specified range. 
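    Example (an illustrative sketch; it mirrors the single-process fallback used in init_distributed_mode below):
        >>> port = find_free_port(9000, 10000)
        >>> os.environ["MASTER_PORT"] = str(port)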
20 | """ 21 | for port in range(start_port, end_port): 22 | try: 23 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 24 | s.bind(("", port)) # Try to bind to the port 25 | s.close() # Close the socket if successful 26 | return port 27 | except OSError as e: 28 | # print(f"Port {port} is in use, trying next port.") 29 | continue 30 | raise RuntimeError(f"No free ports found in range {start_port}-{end_port}") 31 | 32 | 33 | def init_distributed_mode(args=SimpleNamespace()): 34 | random_seed(getattr(args, "seed", 0)) 35 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ and "LOCAL_RANK" in os.environ: 36 | args.world_size = int(os.environ["WORLD_SIZE"]) 37 | args.rank = int(os.environ["RANK"]) 38 | args.gpu = int(os.environ["LOCAL_RANK"]) 39 | args.local_rank = args.gpu 40 | args.dist_url = "env://" 41 | elif "SLURM_PROCID" in os.environ: 42 | os.environ["MASTER_PORT"] = "8966" 43 | while "MASTER_ADDR" not in os.environ or len(os.environ["MASTER_ADDR"].strip()) == 0: 44 | os.environ["MASTER_ADDR"] = ( 45 | subprocess.check_output( 46 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"], 47 | shell=True, 48 | ) 49 | .decode() 50 | .strip() 51 | ) 52 | time.sleep(1) 53 | print(os.environ["MASTER_ADDR"]) 54 | args.world_size = int(os.environ["SLURM_NPROCS"]) 55 | args.rank = int(os.environ["SLURM_PROCID"]) 56 | args.gpu = args.rank % torch.cuda.device_count() 57 | args.local_rank = args.gpu 58 | args.dist_url = "env://" 59 | os.environ["LOCAL_RANK"] = str(args.gpu) 60 | os.environ["WORLD_SIZE"] = str(args.world_size) 61 | os.environ["RANK"] = str(args.rank) 62 | else: 63 | os.environ["MASTER_ADDR"] = "127.0.0.1" 64 | os.environ["MASTER_PORT"] = str(find_free_port(9000, 10000)) 65 | os.environ["RANK"] = "0" 66 | os.environ["LOCAL_RANK"] = "0" 67 | os.environ["WORLD_SIZE"] = "1" 68 | args.rank = 0 69 | args.gpu = args.local_rank = 0 70 | args.world_size = 1 71 | args.dist_url = "env://" 72 | 73 | args.distributed = True 74 | 75 | torch.cuda.set_device(args.gpu) 76 | args.dist_backend = "nccl" 77 | print("| distributed init (rank {}): {}, gpu {}".format(args.rank, args.dist_url, args.gpu), flush=True) 78 | torch.distributed.init_process_group( 79 | backend=args.dist_backend, 80 | init_method=args.dist_url, 81 | world_size=args.world_size, 82 | rank=args.rank, 83 | timeout=datetime.timedelta(seconds=2 * 60 * 60), 84 | ) 85 | torch.distributed.barrier() 86 | 87 | 88 | def all_reduce_mean(x, group=None): 89 | world_size = dist.get_world_size(group=group) 90 | if world_size > 1: 91 | if isinstance(x, torch.Tensor): 92 | x_reduce = x.clone().cuda() 93 | else: 94 | x_reduce = torch.tensor(x).cuda() 95 | dist.all_reduce(x_reduce, group=group) 96 | x_reduce /= world_size 97 | return x_reduce.item() 98 | else: 99 | return x 100 | -------------------------------------------------------------------------------- /xllmx/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, it, args): 5 | """Decay the learning rate with half-cycle cosine after warmup""" 6 | if it < args.warmup_iters: # 1) linear warmup for warmup_iters steps 7 | lr = args.lr * it / args.warmup_iters 8 | elif it > args.lr_decay_iters: # 2) if it > lr_decay_iters, return min learning rate 9 | lr = args.min_lr 10 | else: # 3) in between, use cosine decay down to min learning rate 11 | decay_ratio = (it - args.warmup_iters) / (args.lr_decay_iters - args.warmup_iters) 12 | assert 0 <= decay_ratio <= 1 
13 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 14 | lr = args.min_lr + (args.lr - args.min_lr) * coeff 15 | 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | 23 | 24 | def adjust_learning_rate_epoch(optimizer, epoch, args): 25 | """Decay the learning rate with half-cycle cosine after warmup""" 26 | if epoch < args.warmup_epochs: 27 | lr = args.lr * epoch / args.warmup_epochs 28 | else: 29 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( 30 | 1.0 + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)) 31 | ) 32 | for param_group in optimizer.param_groups: 33 | if "lr_scale" in param_group: 34 | param_group["lr"] = lr * param_group["lr_scale"] 35 | else: 36 | param_group["lr"] = lr 37 | return lr 38 | -------------------------------------------------------------------------------- /xllmx/util/misc.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, deque 2 | import datetime 3 | import logging 4 | import random 5 | import time 6 | 7 | from fairscale.nn.model_parallel import initialize as fs_init 8 | import numpy as np 9 | import torch 10 | import torch.distributed as dist 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def random_seed(seed=0): 16 | random.seed(seed) 17 | torch.random.manual_seed(seed) 18 | np.random.seed(seed) 19 | 20 | 21 | class SmoothedValue(object): 22 | """Track a series of values and provide access to smoothed values over a 23 | window or the global series average. 24 | """ 25 | 26 | def __init__(self, window_size=1000, fmt=None): 27 | if fmt is None: 28 | fmt = "{avg:.4f} ({global_avg:.4f})" 29 | self.deque = deque(maxlen=window_size) 30 | self.total = 0.0 31 | self.count = 0 32 | self.fmt = fmt 33 | 34 | def update(self, value, n=1): 35 | self.deque.append(value) 36 | self.count += n 37 | self.total += value * n 38 | 39 | def synchronize_between_processes(self): 40 | """ 41 | Warning: does not synchronize the deque! 
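        Only count and total (and therefore global_avg) are all-reduced across processes;
        window-based statistics such as avg, median, max, and value remain local to each rank.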
42 | """ 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value 75 | ) 76 | 77 | 78 | class MetricLogger(object): 79 | def __init__(self, delimiter="\t"): 80 | self.meters = defaultdict(SmoothedValue) 81 | self.delimiter = delimiter 82 | 83 | def update(self, **kwargs): 84 | for k, v in kwargs.items(): 85 | if v is None: 86 | continue 87 | elif isinstance(v, (torch.Tensor, float, int)): 88 | self.meters[k].update(v.item() if isinstance(v, torch.Tensor) else v) 89 | elif isinstance(v, list): 90 | for i, sub_v in enumerate(v): 91 | self.meters[f"{k}_{i}"].update(sub_v.item() if isinstance(sub_v, torch.Tensor) else sub_v) 92 | elif isinstance(v, dict): 93 | for sub_key, sub_v in v.items(): 94 | self.meters[f"{k}_{sub_key}"].update(sub_v.item() if isinstance(sub_v, torch.Tensor) else sub_v) 95 | else: 96 | raise TypeError(f"Unsupported type {type(v)} for metric {k}") 97 | 98 | def __str__(self): 99 | loss_str = [] 100 | for name, meter in self.meters.items(): 101 | loss_str.append("{}: {}".format(name, str(meter))) 102 | return self.delimiter.join(loss_str) 103 | 104 | def synchronize_between_processes(self): 105 | for meter in self.meters.values(): 106 | meter.synchronize_between_processes() 107 | 108 | def add_meter(self, name, meter): 109 | self.meters[name] = meter 110 | 111 | def log_every(self, iterable, print_freq, header=None, start_iter=0, samples_per_iter=None): 112 | i = start_iter 113 | if not header: 114 | header = "" 115 | start_time = time.time() 116 | end = time.time() 117 | iter_time = SmoothedValue(fmt="{avg:.4f}") 118 | data_time = SmoothedValue(fmt="{avg:.4f}") 119 | log_msg = [header, "[{0" + "}/{1}]", "{meters}", "time: {time}", "data: {data}"] 120 | if samples_per_iter is not None: 121 | log_msg.append("samples/sec: {samples_per_sec:.2f}") 122 | if torch.cuda.is_available(): 123 | log_msg.append("max mem: {memory:.0f}") 124 | log_msg = self.delimiter.join(log_msg) 125 | MB = 1024.0 * 1024.0 126 | for obj in iterable: 127 | data_time.update(time.time() - end) 128 | yield obj 129 | iter_time.update(time.time() - end) 130 | if i % print_freq == 0: 131 | try: 132 | total_len = len(iterable) 133 | except: 134 | total_len = "unknown" 135 | 136 | msg_kwargs = { 137 | "meters": str(self), 138 | "time": str(iter_time), 139 | "data": str(data_time), 140 | } 141 | if samples_per_iter is not None: 142 | msg_kwargs["samples_per_sec"] = samples_per_iter / iter_time.avg 143 | if torch.cuda.is_available(): 144 | msg_kwargs["memory"] = torch.cuda.max_memory_allocated() / MB 145 | 146 | logger.info(log_msg.format(i, total_len, **msg_kwargs)) 147 | i += 1 148 | end = time.time() 149 | total_time = time.time() - start_time 150 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 151 | logger.info("{} Total time: 
{}".format(header, total_time_str)) 152 | 153 | 154 | def add_weight_decay(model, lr, weight_decay=1e-5): 155 | decay = [] 156 | no_decay = [] 157 | for name, param in model.named_parameters(): 158 | if not param.requires_grad: 159 | continue # frozen weights 160 | if name.endswith(".bias") or name.endswith("norm.weight"): 161 | no_decay.append(param) 162 | else: 163 | decay.append(param) 164 | return [ 165 | {"params": no_decay, "lr": lr, "weight_decay": weight_decay}, 166 | {"params": decay, "lr": lr, "weight_decay": weight_decay}, 167 | ] 168 | 169 | 170 | def broadcast_nonmp_parameters(model): 171 | if fs_init.get_model_parallel_world_size() == 1: 172 | return 173 | logger.info("starting broadcast non-model-parallel parameters within model parallel group") 174 | memo = set() 175 | modules = model.named_modules(prefix="", remove_duplicate=True) 176 | for module_prefix, module in modules: 177 | members = dict(module._parameters.items()) 178 | for k, v in members.items(): 179 | name = module_prefix + ("." if module_prefix else "") + k 180 | if v is None or v in memo: 181 | continue 182 | if getattr(v, "model_parallel", False): 183 | logger.info(f"ignore: {name}") 184 | continue 185 | memo.add(v) 186 | dist.broadcast(v, src=fs_init.get_model_parallel_src_rank(), group=fs_init.get_model_parallel_group()) 187 | logger.info("braodcast done") 188 | 189 | 190 | def mark_mp_params(model: torch.nn.Module): 191 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear 192 | 193 | for m in model.modules(): 194 | if isinstance(m, ColumnParallelLinear): 195 | m.weight.model_parallel = True 196 | if m.bias is not None: 197 | m.bias.model_parallel = True 198 | 199 | if isinstance(m, RowParallelLinear): 200 | m.weight.model_parallel = True 201 | 202 | if isinstance(m, ParallelEmbedding): 203 | m.weight.model_parallel = True 204 | 205 | 206 | def print_param_status(model: torch.nn.Module) -> None: 207 | require_grad_set = [] 208 | no_grad_set = [] 209 | for name, param in model.named_parameters(): 210 | if param.requires_grad: 211 | require_grad_set.append((name, param)) 212 | else: 213 | no_grad_set.append((name, param)) 214 | 215 | logger.info("Params that require gradient:\n") 216 | for name, param in require_grad_set: 217 | model_parallel = getattr(param, "model_parallel", False) 218 | logger.info( 219 | f"Param {name}: requires_grad {param.requires_grad}, local_size {param.shape}, model_parallel {model_parallel}, dtype {param.dtype}" 220 | ) 221 | 222 | logger.info("\nParams that do not require gradient:\n") 223 | for name, param in no_grad_set: 224 | model_parallel = getattr(param, "model_parallel", False) 225 | logger.info( 226 | f"Param {name}: requires_grad {param.requires_grad}, local_size {param.shape}, model_parallel {model_parallel}, dtype {param.dtype}" 227 | ) 228 | -------------------------------------------------------------------------------- /xllmx/util/tensor_type.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def promote_param_to_fp32(param: nn.Parameter) -> None: 6 | if param.is_floating_point() and torch.finfo(param.dtype).bits < 32: 7 | param.data = param.data.float() 8 | if param.is_complex() and torch.finfo(param.dtype).bits < 32: 9 | param.data = param.data.to(torch.complex64) 10 | --------------------------------------------------------------------------------