├── .gitignore
├── .isort.cfg
├── .pre-commit-config.yaml
├── INSTALL.md
├── README.md
├── assets
│   ├── config1.jpg
│   ├── config2.jpg
│   ├── demos.png
│   └── logo.png
├── lumina_mgpt
│   ├── .isort.cfg
│   ├── TRAIN.md
│   ├── configs
│   │   └── data
│   │       └── sample.yaml
│   ├── data
│   │   ├── __init__.py
│   │   ├── convertsation.py
│   │   └── item_processor.py
│   ├── demos
│   │   ├── demo_freeform.py
│   │   ├── demo_image2image.py
│   │   └── demo_image_generation.py
│   ├── exps
│   │   └── 7B.sh
│   ├── finetune_solver.py
│   ├── generate_examples
│   │   └── generate.py
│   ├── inference_solver.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── chameleon
│   │   │   ├── __init__.py
│   │   │   ├── configuration_chameleon.py
│   │   │   ├── convert_chameleon_weights_to_hf.py
│   │   │   ├── image_processing_chameleon.py
│   │   │   ├── modeling_chameleon.py
│   │   │   └── processing_chameleon.py
│   │   ├── chameleon_vae_ori
│   │   │   ├── __init__.py
│   │   │   ├── image_tokenizer.py
│   │   │   ├── vocab.py
│   │   │   └── vqgan.py
│   │   ├── configuration_xllmx_chameleon.py
│   │   └── modeling_xllmx_chameleon.py
│   └── pre_tokenize
│       ├── concat_record.py
│       └── pre_tokenize.py
├── requirements.txt
├── setup.py
└── xllmx
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── conversation
    │   │   ├── __init__.py
    │   │   └── template.py
    │   ├── data_reader.py
    │   ├── dataset.py
    │   ├── item_processor.py
    │   └── sampler.py
    ├── model
    │   ├── __init__.py
    │   ├── components.py
    │   └── tokenizer.py
    ├── solvers
    │   ├── __init__.py
    │   └── finetune
    │       ├── __init__.py
    │       └── finetune.py
    └── util
        ├── __init__.py
        ├── ckpt.py
        ├── dist.py
        ├── lr_sched.py
        ├── misc.py
        └── tensor_type.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | # Project-specific gitignores
163 | .output
164 | xllmx/output/
165 | xllmx/output_dir/
166 | tokenizer.model
167 | *.swp
168 |
169 | .asset/
170 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | profile = black
3 | line_length = 120
4 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
5 | no_lines_before = STDLIB,LOCALFOLDER
6 | lines_between_types = 1
7 | combine_as_imports = True
8 | force_sort_within_sections = true
9 | order_by_type = True
10 | known_first_party = xllmx
11 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Exclude all third-party libraries and auto-generated files globally
2 | #exclude:
3 | repos:
4 | # Common hooks
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v4.6.0
7 | hooks:
8 | - id: check-merge-conflict
9 | - id: check-symlinks
10 | - id: detect-private-key
11 | - id: end-of-file-fixer
12 | - id: trailing-whitespace
13 | files: (.*\.(py|bzl|md|rst|c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps|cmake|yaml|yml|hook)|BUILD|.*\.BUILD|WORKSPACE|CMakeLists\.txt)$
14 | # For Python files
15 | - repo: https://github.com/PyCQA/isort
16 | rev: 5.13.2
17 | hooks:
18 | - id: isort
19 | - repo: https://github.com/psf/black.git
20 | rev: 24.4.2
21 | hooks:
22 | - id: black
23 | files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
24 | args: [--line-length=120]
25 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 |
2 | ### 1. Basic Setup
3 |
4 | ```bash
5 | # Create a new conda environment named 'lumina_mgpt' with Python 3.10
6 | conda create -n lumina_mgpt python=3.10 -y
7 | # Activate the 'lumina_mgpt' environment
8 | conda activate lumina_mgpt
9 | # Install required packages from 'requirements.txt'
10 | pip install -r requirements.txt
11 | ```
12 |
13 | ### 2. Install Flash-Attention
14 | ```bash
15 | pip install flash-attn --no-build-isolation
16 | ```
17 |
18 | ### 3. Install xllmx as Python Package
19 | The [xllmx](./xllmx) module is a lightweight engine designed to support the training and inference of
20 | LLM-centered Any2Any models. It evolved from [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory) through comprehensive improvements that bring higher efficiency and
21 | wider functionality, including support for flexibly arranging and processing interleaved media and text.
22 |
23 | The Lumina-mGPT implementation heavily relies on xllmx and requires it to be installed as a Python package (**so that `import xllmx` can be used anywhere on your machine, regardless of the working directory**).
24 | The installation process is as follows:
25 | ```bash
26 | # bash
27 | # go to the root path of the project
28 | cd Lumina-mGPT
29 | # install as package
30 | pip install -e .
31 | ```
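
To verify the installation, a minimal check (not part of the official instructions) is to import the package from a directory outside the project:

```python
# xllmx should be importable from anywhere once `pip install -e .` has succeeded.
import xllmx

print(xllmx.__file__)  # should point back into the cloned Lumina-mGPT repository
```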
32 |
33 | ### 4. Optional: Install Apex
34 | > [!Caution]
35 | >
36 | > If you merely run inference, there is no need to install Apex.
37 | >
38 | > For training, Apex can bring some training efficiency improvement, but it is still not a must.
39 | >
40 | > Note that training works smoothly with either:
41 | > 1. Apex not installed at all; OR
42 | > 2. Apex successfully installed with CUDA and C++ extensions.
43 | >
44 | > However, it will fail when:
45 | > 1. A Python-only build of Apex is installed.
46 | >
47 | > If errors like `No module named 'fused_layer_norm_cuda'` are reported, it generally means that you are
48 | using a Python-only Apex build. Please run `pip uninstall apex` to remove the build and try again.
49 |
50 | Lumina-mGPT utilizes [apex](https://github.com/NVIDIA/apex) to accelerate training, which needs to be compiled from source. Please follow the [official instructions](https://github.com/NVIDIA/apex#from-source) for installation.
51 | Here are some tips based on our experiences:
52 |
53 | **Step1**: Check the version of CUDA with which your torch is built:
54 | ```python
55 | # python
56 | import torch
57 | print(torch.version.cuda)
58 | ```
59 |
60 | **Step2**: Check the CUDA toolkit version on your system:
61 | ```bash
62 | # bash
63 | nvcc -V
64 | ```
65 | **Step3**: If the two aforementioned versions do not match, or if you do not have the CUDA toolkit installed on your system,
66 | please download and install the CUDA toolkit from [here](https://developer.nvidia.com/cuda-toolkit-archive) with a version matching the torch CUDA version.
67 |
68 | > [!Note]
69 | >
70 | > Note that multiple versions of the CUDA toolkit can co-exist on the same machine, and the active version can be easily switched by changing the `$PATH` and `$LD_LIBRARY_PATH` environment variables.
71 | There is thus no need to worry about messing up your machine's existing environment.
72 |
73 | **Step4**: You can now start installing apex:
74 | ```bash
75 | git clone https://github.com/NVIDIA/apex
76 | cd apex
77 | # if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
78 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
79 | # otherwise
80 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
81 | ```
82 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Lumina-mGPT
6 |
7 | A family of multimodal autoregressive models capable of various vision and language tasks, particularly excelling in generating flexible photorealistic images from text descriptions. 👋 join our WeChat
8 |
9 | [Paper](https://arxiv.org/abs/2408.02657)
10 |
11 | [Demo Lumina-mGPT](http://106.14.2.150:10020/)
12 | [Demo Lumina-mGPT](http://106.14.2.150:10021/)
13 |
14 |
15 |
16 |
17 |
18 | ## 📰 News
19 |
20 | - **[2024-08-11] 🎉🎉🎉 [Training codes and documents](./lumina_mgpt/TRAIN.md) are released! 🎉🎉🎉**
21 |
22 | - **[2024-07-08] 🎉🎉🎉 Lumina-mGPT is released! 🎉🎉🎉**
23 |
24 | ## ⚙️ Installation
25 |
26 | See [INSTALL.md](./INSTALL.md) for detailed instructions.
27 |
28 | Note that the Lumina-mGPT implementation heavily relies on
29 | the [xllmx](./xllmx) module, which evolved from [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory) to support
30 | LLM-centered multimodal tasks. Make sure it is installed correctly as a Python package before moving on.
31 |
32 | ## ⛽ Training
33 | See [lumina_mgpt/TRAIN.md](lumina_mgpt/TRAIN.md)
34 |
35 | ## 📽️ Inference
36 |
37 | > [!Note]
38 | >
39 | > Before using the Lumina-mGPT model, run
40 | >
41 | > ```bash
42 | > # bash
43 | > cd lumina_mgpt
44 | > ```
45 | >
46 | > to enter the directory of the Lumina-mGPT implementation.
47 |
48 | ### Preparation
49 |
50 | Since the Chameleon implementation in transformers does not currently contain the VQ-VAE decoder, please manually download the original VQ-VAE weights [provided by Meta](https://github.com/facebookresearch/chameleon) and
51 | put them in the following directory:
52 |
53 | ```
54 | Lumina-mGPT
55 | - lumina_mgpt/
56 | - ckpts/
57 | - chameleon/
58 | - tokenizer/
59 | - text_tokenizer.json
60 | - vqgan.yaml
61 | - vqgan.ckpt
62 | - xllmx/
63 | - ...
64 | ```
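
As a quick sanity check (a sketch, not part of the original instructions), you can verify from the repository root that the downloaded files sit in the expected locations:

```python
# Minimal sketch: confirm the manually downloaded VQ-VAE files are in place.
import os

tokenizer_dir = os.path.join("lumina_mgpt", "ckpts", "chameleon", "tokenizer")
for name in ["text_tokenizer.json", "vqgan.yaml", "vqgan.ckpt"]:
    path = os.path.join(tokenizer_dir, name)
    print(f"{path}: {'found' if os.path.isfile(path) else 'MISSING'}")
```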
65 |
66 | ### Local Gradio Demos
67 |
68 | We have prepared three different Gradio demos, each showcasing unique functionalities, to help you quickly become familiar with the capabilities of the Lumina-mGPT models.
69 |
70 | #### 1. [demos/demo_image_generation.py](./lumina_mgpt/demos/demo_image_generation.py)
71 |
72 | This demo is customized for Image Generation tasks, where you can input a text description and generate a corresponding image.
73 | To host this demo, run:
74 |
75 | ```bash
76 | # Note to set the `--target_size` argument consistent with the checkpoint
77 | python -u demos/demo_image_generation.py \
78 | --pretrained_path Alpha-VLLM/Lumina-mGPT-7B-768 \
79 | --target_size 768
80 | ```
81 |
82 | #### 2. [demos/demo_image2image.py](./lumina_mgpt/demos/demo_image2image.py)
83 |
84 | This demo is designed for models trained with Omni-SFT. You can conveniently switch between the multiple downstream tasks using this demo.
85 |
86 | ```bash
87 | # Note to set the `--target_size` argument consistent with the checkpoint
88 | python -u demos/demo_image2image.py \
89 | --pretrained_path Alpha-VLLM/Lumina-mGPT-7B-768-Omni \
90 | --target_size 768
91 | ```
92 |
93 | #### 3. [demos/demo_freeform.py](./lumina_mgpt/demos/demo_freeform.py)
94 |
95 | This is a powerful demo with minimal constraints on the input format. It supports flexible interaction and is suitable for in-depth exploration.
96 |
97 | ```bash
98 | # Note to set the `--target_size` argument consistent with the checkpoint
99 | python -u demos/demo_freeform.py \
100 | --pretrained_path Alpha-VLLM/Lumina-mGPT-7B-768-Omni \
101 | --target_size 768
102 | ```
103 |
104 | ### Simple Inference
105 |
106 | The simplest code for Lumina-mGPT inference:
107 |
108 | ```python
109 | from inference_solver import FlexARInferenceSolver
110 | from PIL import Image
111 |
112 | # ******************** Image Generation ********************
113 | inference_solver = FlexARInferenceSolver(
114 | model_path="Alpha-VLLM/Lumina-mGPT-7B-768",
115 | precision="bf16",
116 | target_size=768,
117 | )
118 |
119 | q1 = ("Generate an image of 768x768 according to the following prompt:\n"
120 |       "Image of a dog playing in water, and a waterfall is in the background.")
121 |
122 | # generated: tuple of (generated response, list of generated images)
123 | generated = inference_solver.generate(
124 | images=[],
125 | qas=[[q1, None]],
126 | max_gen_len=8192,
127 | temperature=1.0,
128 | logits_processor=inference_solver.create_logits_processor(cfg=4.0, image_top_k=2000),
129 | )
130 |
131 | a1, new_image = generated[0], generated[1][0]
132 |
133 |
134 | # ******************* Image Understanding ******************
135 | inference_solver = FlexARInferenceSolver(
136 | model_path="Alpha-VLLM/Lumina-mGPT-7B-512",
137 | precision="bf16",
138 | target_size=512,
139 | )
140 |
141 | # the "<|image|>" symbol will be replaced with a sequence of image tokens before being fed to the LLM
142 | q1 = "Describe the image in detail. <|image|>"
143 |
144 | images = [Image.open("image.png")]
145 | qas = [[q1, None]]
146 |
147 | # `len(images)` should be equal to the number of occurrences of "<|image|>" in qas
148 | generated = inference_solver.generate(
149 | images=images,
150 | qas=qas,
151 | max_gen_len=8192,
152 | temperature=1.0,
153 | logits_processor=inference_solver.create_logits_processor(cfg=4.0, image_top_k=2000),
154 | )
155 |
156 | a1 = generated[0]
157 | # generated[1], namely the list of newly generated images, should typically be empty in this case.
158 |
159 |
160 | # ********************* Omni-Potent *********************
161 | inference_solver = FlexARInferenceSolver(
162 | model_path="Alpha-VLLM/Lumina-mGPT-7B-768-Omni",
163 | precision="bf16",
164 | target_size=768,
165 | )
166 |
167 | # Example: Depth Estimation
168 | # For more instructions, see demos/demo_image2image.py
169 | q1 = "Depth estimation. <|image|>"
170 | images = [Image.open("image.png")]
171 | qas = [[q1, None]]
172 |
173 | generated = inference_solver.generate(
174 | images=images,
175 | qas=qas,
176 | max_gen_len=8192,
177 | temperature=1.0,
178 | logits_processor=inference_solver.create_logits_processor(cfg=1.0, image_top_k=200),
179 | )
180 |
181 | a1 = generated[0]
182 | new_image = generated[1][0]
183 |
184 | ```
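
Multi-round conversations use the same `qas` format: each completed round keeps its answer, and only the last answer is `None`. Below is a minimal sketch (not from the original examples; the follow-up question `q2` is made up) that continues the Image Understanding example above:

```python
# Hypothetical follow-up round, reusing `inference_solver`, `images`, `q1`, and `a1`
# from the Image Understanding example.
q2 = "What is the dominant color in the image?"

generated = inference_solver.generate(
    images=images,               # still one image, matching the single <|image|> in q1
    qas=[[q1, a1], [q2, None]],  # previous round + new question with answer None
    max_gen_len=8192,
    temperature=1.0,
    logits_processor=inference_solver.create_logits_processor(cfg=4.0, image_top_k=2000),
)
a2 = generated[0]
```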
185 |
186 | ## 🤗 Checkpoints
187 |
188 | **Configurations**
189 |
190 | Model configurations are illustrated in [assets/config1.jpg](assets/config1.jpg) and [assets/config2.jpg](assets/config2.jpg).
191 |
192 |
193 | **7B models**
194 |
195 | | Model | Size | Huggingface |
196 | | ------------ | ---- | ---------------------------------------------------------------------------------------- |
197 | | FP-SFT@512 | 7B | [Alpha-VLLM/Lumina-mGPT-7B-512](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-7B-512) |
198 | | FP-SFT@768 | 7B | [Alpha-VLLM/Lumina-mGPT-7B-768](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-7B-768) |
199 | | Omni-SFT@768 | 7B | [Alpha-VLLM/Lumina-mGPT-7B-768-Omni](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-7B-768-Omni) |
200 | | FP-SFT@1024 | 7B | [Alpha-VLLM/Lumina-mGPT-7B-1024](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-7B-1024) |
201 |
202 | **34B models**
203 |
204 | | Model | Size | Huggingface |
205 | | ---------- | ---- | ------------------------------------------------------------------------------------ |
206 | | FP-SFT@512 | 34B | [Alpha-VLLM/Lumina-mGPT-34B-512](https://huggingface.co/Alpha-VLLM/Lumina-mGPT-34B-512) |
207 |
208 | More checkpoints coming soon.
209 |
210 | ## 📑 Open-source Plan
211 |
212 | - [X] Inference code
213 | - [X] Training code
214 |
215 | ## 🔥 Open positions
216 | We are hiring interns, postdocs, and full-time researchers at the General Vision Group, Shanghai AI Lab, with a focus on multi-modality and vision foundation models. If you are interested, please contact gaopengcuhk@gmail.com.
217 |
218 | ## 📄 Citation
219 |
220 | ```
221 | @misc{liu2024lumina-mgpt,
222 | title={Lumina-mGPT: Illuminate Flexible Photorealistic Text-to-Image Generation with Multimodal Generative Pretraining},
223 | author={Dongyang Liu and Shitian Zhao and Le Zhuo and Weifeng Lin and Yu Qiao and Hongsheng Li and Peng Gao},
224 | year={2024},
225 | eprint={2408.02657},
226 | archivePrefix={arXiv},
227 | primaryClass={cs.CV},
228 | url={https://arxiv.org/abs/2408.02657},
229 | }
230 | ```
231 |
--------------------------------------------------------------------------------
/assets/config1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/assets/config1.jpg
--------------------------------------------------------------------------------
/assets/config2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/assets/config2.jpg
--------------------------------------------------------------------------------
/assets/demos.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/assets/demos.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/assets/logo.png
--------------------------------------------------------------------------------
/lumina_mgpt/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | profile = black
3 | line_length = 120
4 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
5 | no_lines_before = STDLIB,LOCALFOLDER
6 | lines_between_types = 1
7 | combine_as_imports = True
8 | force_sort_within_sections = true
9 | order_by_type = True
10 | known_first_party = xllmx
11 |
--------------------------------------------------------------------------------
/lumina_mgpt/TRAIN.md:
--------------------------------------------------------------------------------
1 | # Lumina-mGPT Training
2 |
3 | For efficiency considerations, the multi-modal datasets are pre-tokenized into sequences of token ids, which leads to significantly faster training.
4 |
5 | ## Pre-tokenization
6 |
7 |
8 | ### 1. Run Tokenization
9 |
10 | This stage tokenizes each data point, consisting of interleaved images and text, into a single sequence of integer tokens. After tokenization, the sequence is saved to disk for training-time usage. Together with the saved tokens, a JSON-formatted record file is also generated for indexing all the saved token files. For faster tokenization, you may use multiple GPUs and dispatch different subsets of the data to them.
11 |
12 | #### Command:
13 |
14 | ```bash
15 | for i in {0..7}
16 | do
17 | export CUDA_VISIBLE_DEVICES=${i}
18 | python -u pre_tokenize/pre_tokenize.py \
19 | --splits=8 \
20 | --rank=${i} \
21 | --in_filename /path/to/in_filename.json \
22 | --out_dir /path/to/out_dir \
23 | --target_size 768 &> ${i}.log &
24 | done
25 | ```
26 |
27 | #### Format of Input File:
28 |
29 | `in_filename` is expected to be a json file with the following format:
30 | ```python
31 | [
32 | {...},
33 | {...},
34 | {
35 | "conversations":[
36 | {
37 | "from": "human",
38 | "value": "Hi, please convert this depth image <|image|> to a color image"
39 | },
40 | {
41 | "from": "gpt",
42 | "value": "<|image|>"
43 | },
44 | {
45 | "from": "human",
46 | "value": "Could you change its style to that of this image? <|image|>"
47 | },
48 | {
49 | "from": "gpt",
50 | "value": "Sure, here's the image. <|image|>"
51 | }
52 | ],
53 | "image": ["/path/to/image1.png", "path/to/image2.png", "path/to/image3.png", "path/to/image4.png"]
54 | },
55 | {...},
56 | {...}
57 | ]
58 | ```
59 |
60 | *Rules:*
61 |
62 | 1. The file is a list of dictionaries, and each dictionary represents a data point.
63 | 2. Each dictionary contains the key "conversations".
64 | 3. If the conversation involves image(s), the data point should also contain the `image` key; otherwise, the `image` key can be omitted.
65 | 4. The location of each image should be explicitly specified in the conversation using the `<|image|>` symbol.
66 |    1. Accordingly, the number of occurrences of the `<|image|>` symbol should be equal to the number of images in the `image` key. A minimal sketch of producing such a file is shown below.
67 |
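Below is a minimal sketch (not part of the repo) of writing an `in_filename.json` in this format; all paths and texts are placeholders:

```python
import json

# Two example data points: one with an image, one text-only (no "image" key).
data = [
    {
        "conversations": [
            {"from": "human", "value": "Describe this image <|image|>"},
            {"from": "gpt", "value": "A dog playing in water, with a waterfall in the background."},
        ],
        "image": ["/path/to/image1.png"],
    },
    {
        "conversations": [
            {"from": "human", "value": "Write a one-sentence greeting."},
            {"from": "gpt", "value": "Hello! Nice to meet you."},
        ],
    },
]

with open("in_filename.json", "w", encoding="utf8") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)
```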
68 |
69 | #### How to adapt to your own format:
70 |
71 | If your own data come in a different format, you can easily adapt the code to handle them by modifying the `pre_tokenize.py` file.
72 | We have reserved space in `ItemProcessor.process_item` for adding the logic that converts data points from your own format into the standard format, as sketched below.
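
For illustration, here is a hypothetical sketch of such a conversion; the `caption` and `img_path` field names are invented and should be replaced with those of your own data:

```python
# Inside the ItemProcessor subclass defined in pre_tokenize/pre_tokenize.py (a sketch, not the actual code):
def process_item(self, raw_item, training_mode=False):
    # Convert a custom record like {"caption": ..., "img_path": ...}
    # into the standard conversation format expected by the pipeline.
    item = {
        "conversations": [
            {
                "from": "human",
                "value": "Generate an image according to the following prompt:\n" + raw_item["caption"],
            },
            {"from": "gpt", "value": "<|image|>"},
        ],
        "image": [raw_item["img_path"]],
    }
    return super().process_item(item, training_mode=training_mode)
```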
73 |
74 | ### 2. Concat Records
75 |
76 | After tokenization, you need to concatenate the record files generated by the different processes (GPUs) into a single record file.
77 | **Note that we use the term "record file" to refer to the meta file that contains the information of all the saved token files,
78 | which is different from the token files themselves.**
79 |
80 | ```bash
81 | python -u pre_tokenize/concat_record.py \
82 | --sub_record_dir /path/to/out_dir \
83 | --save_path /path/to/out_dir/record.json
84 | ```
85 |
86 | ## Training
87 |
88 | #### Command:
89 | We provide an example experiment script [exps/7B.sh](exps/7B.sh) for training the 7B model. Supposing you have access to a SLURM cluster, you can run the following command to start training:
90 |
91 | ```bash
92 | srun -n32 --ntasks-per-node=8 --gres=gpu:8 bash exps/7B.sh
93 | ```
94 |
95 | Otherwise, if you want to use `torchrun` for distributed training, you can make the following change to `exps/7B.sh`:
96 | ```bash
97 | # python -u finetune_solver.py \
98 | torchrun --nproc_per_node=8 [other torchrun args...] finetune_solver.py \
99 | ```
100 |
101 | #### About the `--data_config` argument:
102 | The `--data_config` argument should point to a `*.yaml` file, which is a meta file that gathers one or multiple record files.
103 | In other words, you may pre-tokenize multiple datasets independently and list their record files in the same data config file
104 | for joint training. See [configs/data/sample.yaml](configs/data/sample.yaml) for an example.
105 |
--------------------------------------------------------------------------------
/lumina_mgpt/configs/data/sample.yaml:
--------------------------------------------------------------------------------
1 | META:
2 | - path: 'path/to/record1.json'
3 | - path: 'path/to/record2.json'
4 | ratio: 0.4 # In each epoch, sample only 40% of the data from this dataset
5 | - path: 'path/to/record3.json'
6 |
--------------------------------------------------------------------------------
/lumina_mgpt/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/lumina_mgpt/data/__init__.py
--------------------------------------------------------------------------------
/lumina_mgpt/data/convertsation.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | class Conversation:
5 | sep_token = ""
6 | roles = ["Human", "Assistant"]
7 |
8 | def __init__(self, messages=None):
9 | self.messages = messages or []
10 |
11 | def process(self):
12 | ret = ""
13 | pieces = []
14 | for i, (role, message) in enumerate(self.messages):
15 | if message is not None:
16 | turn = message + self.sep_token
17 | ret += turn
18 | if role == self.roles[1]:
19 | pieces.append({"data": turn, "predict": True})
20 | else:
21 | pieces.append({"data": turn, "predict": False})
22 | else:
23 | # generation prompt
24 | assert i == len(self.messages) - 1 and role == self.roles[1], "only last assistant message can be None"
25 |
26 | result = {
27 | "conv": ret, # text involving the complete conversation
28 | "pieces": pieces, # list to help correctly mark the labels
29 | }
30 | return result
31 |
32 | def get_prompt(self):
33 | return self.process()["conv"]
34 |
35 | def append_message(self, role, message):
36 | self.messages.append([role, message])
37 |
38 | def copy(self):
39 | return Conversation(
40 | messages=[[x, y] for x, y in self.messages],
41 | )
42 |
43 | def load_qas(self, qas: List[List[str]]):
44 | """
45 | convert the list of question-answer pairs to a string, which contains the conversation involving all
46 | the questions and answers. When the last answer is None, the returned string is the prompt which
47 | can be used by the model to generate the last answer.
48 | :param qas: [[question1, answer1], [question2, answer2], ..., [questionX, answerX]]
49 | note that the last answer, i.e. answerX, can be None
50 | :return: the prompt
51 | """
52 | self.messages = []
53 | for q, a in qas:
54 | self.append_message(self.roles[0], q)
55 | self.append_message(self.roles[1], a)
56 |
--------------------------------------------------------------------------------
/lumina_mgpt/data/item_processor.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import random
4 | from typing import Dict, List
5 |
6 | from PIL import Image
7 | import torch
8 |
9 | from data.convertsation import Conversation
10 | import model.chameleon_vae_ori as chameleon_vae_ori
11 | from xllmx.data.data_reader import read_general
12 | from xllmx.data.item_processor import MMConvItemProcessor
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def center_crop(pil_image, crop_size):
18 | while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
19 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
20 |
21 | scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
22 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
23 |
24 | crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
25 | crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
26 | crop_right = crop_left + crop_size[0]
27 | crop_lower = crop_upper + crop_size[1]
28 | return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))
29 |
30 |
31 | def var_center_crop(pil_image, crop_size_list, random_top_k=1):
32 | w, h = pil_image.size
33 | rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
34 | crop_size = random.choice(
35 | sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
36 | )[1]
37 | return center_crop(pil_image, crop_size)
38 |
39 |
40 | def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
41 | assert max_ratio >= 1.0
42 | crop_size_list = []
43 | wp, hp = num_patches, 1
44 | while wp > 0:
45 | if max(wp, hp) / min(wp, hp) <= max_ratio:
46 | crop_size_list.append((wp * patch_size, hp * patch_size))
47 | if (hp + 1) * wp <= num_patches:
48 | hp += 1
49 | else:
50 | wp -= 1
51 | return crop_size_list
52 |
53 |
54 | class FlexARItemProcessor(MMConvItemProcessor):
55 | image_start_token = "" # fixed tokens for start and end, so can hardcode
56 | image_end_token = ""
57 | full_sub_sep_token = ""
58 | sub_sub_sep_token = ""
59 | sub_skip_token = ""
60 | new_line_token = ""
61 |
62 | def __init__(
63 | self,
64 | tokenizer="Alpha-VLLM/Lumina-mGPT-7B-768",
65 | conv_template=Conversation,
66 | target_size=512,
67 | ):
68 |
69 | super().__init__(
70 | {
71 | "<|image|>": self.process_image,
72 | },
73 | ["<|image|>"],
74 | tokenizer,
75 | conv_template,
76 | )
77 |
78 | self.patch_size = 32
79 | self.crop_size_list = generate_crop_size_list((target_size // self.patch_size) ** 2, self.patch_size)
80 | logger.info("List of crop sizes:")
81 | for i in range(0, len(self.crop_size_list), 6):
82 | logger.info(" " + "".join([f"{f'{w} x {h}':14s}" for w, h in self.crop_size_list[i : i + 6]]))
83 |
84 | # todo
85 | # currently still use the original image tokenizer provided by Meta rather than transformers
86 | # because the transformers implementation does not contain the vae decoder
87 | self.chameleon_ori_vocab = chameleon_vae_ori.VocabInfo(
88 | json.load(open("./ckpts/chameleon/tokenizer/text_tokenizer.json", encoding="utf8"))["model"]["vocab"]
89 | )
90 | self.chameleon_ori_translation = chameleon_vae_ori.VocabTranslation(self.chameleon_ori_vocab, device="cuda")
91 | self.chameleon_ori_image_tokenizer = chameleon_vae_ori.ImageTokenizer(
92 | cfg_path="./ckpts/chameleon/tokenizer/vqgan.yaml",
93 | ckpt_path="./ckpts/chameleon/tokenizer/vqgan.ckpt",
94 | device="cuda",
95 | )
96 |
97 | @staticmethod
98 | def get_n_grids_token(n_grids):
99 | return f""
100 |
101 | def token2id(self, token: str) -> int:
102 | return self.tokenizer.tokenizer.vocab[token]
103 |
104 | @torch.no_grad()
105 | def process_image(self, image) -> Dict:
106 | if isinstance(image, Image.Image):
107 | pass
108 | else:
109 | image = Image.open(read_general(image))
110 |
111 | image = var_center_crop(image, crop_size_list=self.crop_size_list)
112 |
113 | w_grids, h_grids = image.size[0] // self.patch_size, image.size[1] // self.patch_size
114 |
115 | image_toks = self.chameleon_ori_translation.convert_img2bp2(
116 | self.chameleon_ori_image_tokenizer.img_tokens_from_pil(image)
117 | ).view(-1)
118 |
119 | full_image_toks = image_toks.reshape(image.size[1] // 16, image.size[0] // 16)
120 | new_line_id = self.token2id(self.new_line_token)
121 |
122 | full_image_toks = torch.cat(
123 | (
124 | full_image_toks,
125 | torch.ones(image.size[1] // 16, 1, device=full_image_toks.device, dtype=full_image_toks.dtype)
126 | * new_line_id,
127 | ),
128 | dim=1,
129 | ).flatten()
130 |
131 | result_toks = [
132 | self.token2id(self.image_start_token),
133 | self.token2id(self.get_n_grids_token(h_grids)),
134 | self.token2id(self.get_n_grids_token(w_grids)),
135 | *full_image_toks.tolist(),
136 | self.token2id(self.image_end_token),
137 | ]
138 |
139 | return {"input_ids": result_toks, "labels": result_toks}
140 |
141 | def process_item(self, item, training_mode=False, out_flatten=True):
142 | if not out_flatten:
143 | return super().process_item(item, training_mode=training_mode)
144 |
145 | if training_mode:
146 | tokens, labels = super().process_item(item, training_mode=training_mode)
147 | input_tokens_item = []
148 | modified_labels_item = []
149 | for i, (token_or_media, ori_label) in enumerate(zip(tokens, labels)):
150 | if isinstance(token_or_media, int):
151 | token = token_or_media
152 | input_tokens_item.append(token)
153 | modified_labels_item.append(ori_label)
154 | else:
155 | input_tokens_item += token_or_media["input_ids"]
156 | if ori_label <= 0: # in the prompt part
157 | modified_labels_item += [-100] * len(token_or_media["input_ids"])
158 | else:
159 | modified_labels_item += token_or_media["labels"]
160 |
161 | return input_tokens_item, modified_labels_item
162 | else:
163 | tokens = super().process_item(item, training_mode=training_mode)
164 | input_tokens_item = []
165 | for i, token_or_media in enumerate(tokens):
166 | if isinstance(token_or_media, int):
167 | input_tokens_item.append(token_or_media)
168 | else:
169 | input_tokens_item += token_or_media["input_ids"]
170 |
171 | return input_tokens_item
172 |
173 | def decode_image(self, tokens: List[int]) -> Image.Image:
174 | if tokens[0] == self.token2id(self.image_start_token):
175 | tokens = tokens[1:]
176 | if tokens[-1] == self.token2id(self.image_end_token):
177 | tokens = tokens[:-1]
178 |
179 | h_grids, w_grids = tokens[0] - 8804, tokens[1] - 8804
180 | tokens = tokens[2:]
181 | h, w = h_grids * self.patch_size, w_grids * self.patch_size
182 | h_latent_dim, w_latent_dim = h_grids * 2, w_grids * 2
183 |
184 | for i in range(len(tokens)):
185 | if (i + 1) % (w_latent_dim + 1) != 0:
186 | tokens[i] = self.chameleon_ori_translation.bpe2img[tokens[i]]
187 |
188 | assert len(tokens) == h_latent_dim * (w_latent_dim + 1)
189 | tokens = torch.tensor(tokens, dtype=torch.int64).cuda()
190 |
191 | tokens = tokens.view(h_latent_dim, w_latent_dim + 1)[:, :-1].flatten()
192 |
193 | return self.chameleon_ori_image_tokenizer.pil_from_img_toks(tokens, h_latent_dim, w_latent_dim)
194 |
--------------------------------------------------------------------------------
/lumina_mgpt/demos/demo_freeform.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0])
5 |
6 | import argparse
7 | import builtins
8 | import datetime
9 | import multiprocessing as mp
10 | import traceback
11 | from typing import List, Optional
12 |
13 | import gradio as gr
14 | import torch
15 |
16 | from inference_solver import FlexARInferenceSolver
17 | from xllmx.util.misc import random_seed
18 |
19 |
20 | class Ready:
21 | pass
22 |
23 |
24 | class ModelFailure:
25 | pass
26 |
27 |
28 | def model_worker(
29 | rank: int,
30 | args: argparse.Namespace,
31 | barrier: mp.Barrier,
32 | request_queue: mp.Queue,
33 | response_queue: Optional[mp.Queue] = None,
34 | ) -> None:
35 | """
36 | The worker function that manipulates the GPU to run the inference.
37 | Exact n_gpu workers are started, with each one operating on a separate GPU.
38 |
39 | Args:
40 | rank (int): Distributed rank of the worker.
41 | args (argparse.Namespace): All command line arguments.
42 | barrier (multiprocessing.Barrier): A barrier used to delay the start
43 | of Web UI to be after the start of the model.
44 | """
45 |
46 | builtin_print = builtins.print
47 |
48 | def print(*args, **kwargs):
49 | kwargs["flush"] = True
50 | now = datetime.datetime.now().time()
51 | builtin_print("[{}] ".format(now), end="") # print with time stamp
52 | builtin_print(*args, **kwargs)
53 |
54 | builtins.print = print
55 |
56 | world_size = len(args.gpu_ids)
57 | gpu_id = args.gpu_ids[rank]
58 | # dist.init_process_group(
59 | # backend="nccl", rank=rank, world_size=world_size,
60 | # init_method=f"tcp://{args.master_addr}:{args.master_port}",
61 | # )
62 | # print(f"| distributed init on worker {rank}/{world_size}. "
63 | # f"using gpu: {gpu_id}")
64 | torch.cuda.set_device(gpu_id)
65 |
66 | inference_solver = FlexARInferenceSolver(
67 | model_path=args.pretrained_path, precision=args.precision, target_size=args.target_size
68 | )
69 |
70 | barrier.wait()
71 |
72 | while True:
73 | if response_queue is not None:
74 | response_queue.put(Ready())
75 | try:
76 | existing_images, chatbot, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k = request_queue.get()
77 |
78 | print(chatbot)
79 |
80 | random_seed(seed=seed)
81 |
82 | generated = inference_solver.generate(
83 | existing_images,
84 | chatbot,
85 | max_gen_len,
86 | gen_t,
87 | logits_processor=inference_solver.create_logits_processor(
88 | cfg=cfg, text_top_k=text_top_k, image_top_k=image_top_k
89 | ),
90 | )
91 |
92 | stream_response = {"text": generated[0], "image": generated[1], "end_of_content": True}
93 | print(generated[1])
94 | if response_queue is not None:
95 | response_queue.put(stream_response)
96 |
97 | except Exception:
98 | print(traceback.format_exc())
99 | response_queue.put(ModelFailure())
100 |
101 |
102 | def gradio_worker(
103 | request_queues: List[mp.Queue],
104 | response_queue: mp.Queue,
105 | args: argparse.Namespace,
106 | barrier: mp.Barrier,
107 | ) -> None:
108 | """
109 | The gradio worker is responsible for displaying the WebUI and relay the
110 | requests to model workers. It should be launched only once.
111 |
112 | Args:
113 | request_queues (List[mp.Queue]): A list of request queues (one for
114 | each model worker).
115 | args (argparse.Namespace): All command line arguments.
116 | barrier (multiprocessing.Barrier): A barrier used to delay the start
117 | of Web UI to be after the start of the model.
118 | """
119 |
120 | def check_input_sanity(text_input: str, new_images):
121 | if new_images is None:
122 | new_images = []
123 |
124 | print(new_images)
125 |
126 | if text_input.count("<|image|>") != len(new_images):
127 | raise gr.Error("please make sure that you have the same number of image inputs and <|image|> tokens")
128 |
129 | def show_user_input(text_input, new_images, chatbot, chatbot_display, existing_images):
130 |
131 | existing_images = [] if existing_images is None else existing_images
132 | new_images = [] if new_images is None else new_images
133 |
134 | return (
135 | "",
136 | [],
137 | chatbot + [[text_input, None]],
138 | chatbot_display + [[text_input, None]],
139 | existing_images + new_images,
140 | )
141 |
142 | def stream_model_output(
143 | existing_images, chatbot, chatbot_display, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k
144 | ):
145 |
146 | existing_images = [] if existing_images is None else existing_images
147 |
148 | while True:
149 | content_piece = response_queue.get()
150 | if isinstance(content_piece, Ready):
151 | break
152 | for queue in request_queues:
153 | queue.put(
154 | ([_[0] for _ in existing_images], chatbot, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k)
155 | )
156 | while True:
157 | content_piece = response_queue.get()
158 | if isinstance(content_piece, ModelFailure):
159 | raise RuntimeError
160 | chatbot_display[-1][1] = content_piece["text"].replace("<", "<").replace(">", ">")
161 | if content_piece["end_of_content"]:
162 | chatbot[-1][1] = content_piece["text"]
163 | chatbot_display[-1][1] = content_piece["text"]
164 | yield chatbot, chatbot_display, existing_images + [(_, None) for _ in content_piece["image"]]
165 | break
166 | else:
167 | yield chatbot, chatbot_display, []
168 |
169 | def clear():
170 | chatbot = []
171 | chatbot_display = []
172 | text_input = ""
173 | return chatbot, chatbot_display, text_input
174 |
175 | with gr.Blocks(css="#image_input {height: 100% !important}") as demo:
176 | gr.Markdown("# Lumina-mGPT Demo\n")
177 | with gr.Row() as r:
178 | with gr.Column(scale=1):
179 | existing_images = gr.Gallery(value=[], label="Existing Images", interactive=False)
180 | chatbot = gr.Chatbot(visible=False)
181 | chatbot_display = gr.Chatbot()
182 | with gr.Column(scale=1):
183 | new_images = gr.Gallery(value=[], label="Image Inputs", interactive=True)
184 | text_input = gr.Textbox()
185 | submit_button = gr.Button("Submit", variant="primary")
186 | clear_button = gr.ClearButton([existing_images, chatbot, chatbot_display, text_input, new_images])
187 | with gr.Row():
188 | with gr.Column(scale=1):
189 | max_gen_len = gr.Slider(
190 | minimum=1,
191 | maximum=5000,
192 | value=2048,
193 | interactive=True,
194 | label="max new tokens",
195 | )
196 | with gr.Column(scale=1):
197 | seed = gr.Slider(
198 | minimum=0,
199 | maximum=int(1e5),
200 | value=1,
201 | step=1,
202 | interactive=True,
203 | label="Seed (0 for random)",
204 | )
205 | with gr.Row():
206 | with gr.Column(scale=1):
207 | gen_t = gr.Slider(
208 | minimum=0.0,
209 | maximum=4.0,
210 | value=1.0,
211 | interactive=True,
212 | label="gen_t",
213 | )
214 | with gr.Column(scale=1):
215 | cfg = gr.Slider(
216 | minimum=0.0,
217 | maximum=16.0,
218 | value=1.0,
219 | interactive=True,
220 | label="cfg",
221 | )
222 | with gr.Row():
223 | with gr.Column(scale=1):
224 | image_top_k = gr.Slider(
225 | minimum=0,
226 | maximum=8192,
227 | value=2000,
228 | interactive=True,
229 | label="Image Top-k",
230 | )
231 | with gr.Column(scale=1):
232 | text_top_k = gr.Slider(
233 | minimum=0,
234 | maximum=9999,
235 | value=5,
236 | interactive=True,
237 | label="Text Top-k",
238 | )
239 |
240 | text_input.submit(check_input_sanity, [text_input, new_images], []).success(
241 | show_user_input,
242 | [text_input, new_images, chatbot, chatbot_display, existing_images],
243 | [text_input, new_images, chatbot, chatbot_display, existing_images],
244 | ).success(
245 | stream_model_output,
246 | [existing_images, chatbot, chatbot_display, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k],
247 | [chatbot, chatbot_display, existing_images],
248 | )
249 | submit_button.click(check_input_sanity, [text_input, new_images], []).success(
250 | show_user_input,
251 | [text_input, new_images, chatbot, chatbot_display, existing_images],
252 | [text_input, new_images, chatbot, chatbot_display, existing_images],
253 | ).success(
254 | stream_model_output,
255 | [existing_images, chatbot, chatbot_display, max_gen_len, seed, gen_t, cfg, image_top_k, text_top_k],
256 | [chatbot, chatbot_display, existing_images],
257 | )
258 | barrier.wait()
259 | demo.queue(api_open=True).launch(
260 | share=True,
261 | server_name="0.0.0.0",
262 | )
263 |
264 |
265 | if __name__ == "__main__":
266 | parser = argparse.ArgumentParser("X-LLM-X Chat Demo")
267 | group = parser.add_mutually_exclusive_group()
268 | group.add_argument(
269 | "--gpu_ids",
270 | type=int,
271 | nargs="+",
272 | help="A list of space-separated gpu ids to run the model on. "
273 | "The model will span across GPUs in tensor-parallel mode.",
274 | )
275 | group.add_argument(
276 | "--n_gpus",
277 | type=int,
278 | default=1,
279 | help="Number of GPUs to run the model on. Equivalent to " "--gpu_ids 0 1 2 ... n-1",
280 | )
281 | parser.add_argument("--pretrained_path", type=str, required=True, help="Path to the model checkpoints.")
282 | parser.add_argument(
283 | "--precision",
284 | type=str,
285 | choices=["fp16", "bf16"],
286 | default="bf16",
287 | help="The dtype used for model weights and inference.",
288 | )
289 | parser.add_argument(
290 | "--target_size", type=int, default=768, choices=[512, 768, 1024], help="The target image generation size."
291 | )
292 | args = parser.parse_args()
293 |
294 | # check and setup gpu_ids to use
295 | if args.gpu_ids is None:
296 | if args.n_gpus is None:
297 | args.n_gpus = 1
298 | assert args.n_gpus > 0, "The demo currently must run on a positive number of GPUs."
299 | args.gpu_ids = list(range(args.n_gpus))
300 |
301 | assert len(args.gpu_ids) == 1, "Currently only supports running on a single GPU."
302 |
303 | # using the default "fork" method messes up some imported libs (e.g.,
304 | # pandas)
305 | mp.set_start_method("spawn")
306 |
307 | # setup the queues and start the model workers
308 | request_queues = []
309 | response_queue = mp.Queue()
310 | worker_processes = []
311 | barrier = mp.Barrier(len(args.gpu_ids) + 1)
312 | for rank, gpu_id in enumerate(args.gpu_ids):
313 | request_queue = mp.Queue()
314 | rank_response_queue = response_queue if rank == 0 else None
315 | process = mp.Process(
316 | target=model_worker,
317 | args=(rank, args, barrier, request_queue, rank_response_queue),
318 | )
319 | process.start()
320 | worker_processes.append(process)
321 | request_queues.append(request_queue)
322 |
323 | gradio_worker(request_queues, response_queue, args, barrier)
324 |
--------------------------------------------------------------------------------
/lumina_mgpt/demos/demo_image2image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0])
5 |
6 | import argparse
7 | import builtins
8 | import datetime
9 | import multiprocessing as mp
10 | import traceback
11 | from typing import List, Optional
12 |
13 | import gradio as gr
14 | import torch
15 |
16 | from inference_solver import FlexARInferenceSolver
17 | from xllmx.util.misc import random_seed
18 |
19 |
20 | class Ready:
21 | pass
22 |
23 |
24 | class ModelFailure:
25 | pass
26 |
27 |
28 | def model_worker(
29 | rank: int,
30 | args: argparse.Namespace,
31 | barrier: mp.Barrier,
32 | request_queue: mp.Queue,
33 | response_queue: Optional[mp.Queue] = None,
34 | ) -> None:
35 | """
36 | The worker function that manipulates the GPU to run the inference.
37 | Exact n_gpu workers are started, with each one operating on a separate GPU.
38 |
39 | Args:
40 | rank (int): Distributed rank of the worker.
41 | args (argparse.Namespace): All command line arguments.
42 | barrier (multiprocessing.Barrier): A barrier used to delay the start
43 | of Web UI to be after the start of the model.
44 | """
45 |
46 | builtin_print = builtins.print
47 |
48 | def print(*args, **kwargs):
49 | kwargs["flush"] = True
50 | now = datetime.datetime.now().time()
51 | builtin_print("[{}] ".format(now), end="") # print with time stamp
52 | builtin_print(*args, **kwargs)
53 |
54 | builtins.print = print
55 |
56 | world_size = len(args.gpu_ids)
57 | gpu_id = args.gpu_ids[rank]
58 | # dist.init_process_group(
59 | # backend="nccl", rank=rank, world_size=world_size,
60 | # init_method=f"tcp://{args.master_addr}:{args.master_port}",
61 | # )
62 | # print(f"| distributed init on worker {rank}/{world_size}. "
63 | # f"using gpu: {gpu_id}")
64 | torch.cuda.set_device(gpu_id)
65 |
66 | inference_solver = FlexARInferenceSolver(
67 | model_path=args.pretrained_path, precision=args.precision, target_size=args.target_size
68 | )
69 |
70 | barrier.wait()
71 |
72 | while True:
73 | if response_queue is not None:
74 | response_queue.put(Ready())
75 | try:
76 | prompt, input_image, task, seed, gen_t, cfg, image_top_k = request_queue.get()
77 |
78 | random_seed(seed=seed)
79 |
80 | if task == "Semantic Segmentation":
81 | prompt = "Segmentic segmentation."
82 | elif task == "Depth Estimation":
83 | prompt = "Depth estimation."
84 | elif task == "Surface Norm Estimation":
85 | prompt = "Surface normal estimation."
86 | elif task == "Human Pose Estimation":
87 | prompt = "Human pose estimation."
88 | elif task == "Detection":
89 | prompt = "Detect: " + prompt
90 | elif task == "Editing":
91 | pass
92 | elif task == "Condition to Image":
93 | prompt = f"Generate an image according to the provided image, and according to the following caption:\n{prompt}" # noqa
94 |
95 | prompt = prompt + " <|image|>"
96 | print(prompt)
97 |
98 | generated = inference_solver.generate(
99 | [input_image],
100 | [[prompt, None]],
101 | 5000,
102 | gen_t,
103 | logits_processor=inference_solver.create_logits_processor(
104 | cfg=cfg, text_top_k=5, image_top_k=image_top_k
105 | ),
106 | )
107 |
108 | stream_response = {"text": generated[0], "image": generated[1], "prompt": prompt, "end_of_content": True}
109 |
110 | print(generated[1])
111 |
112 | if response_queue is not None:
113 | print("here")
114 | response_queue.put(stream_response)
115 |
116 | except Exception:
117 | print(traceback.format_exc())
118 | response_queue.put(ModelFailure())
119 |
120 |
121 | def gradio_worker(
122 | request_queues: List[mp.Queue],
123 | response_queue: mp.Queue,
124 | args: argparse.Namespace,
125 | barrier: mp.Barrier,
126 | ) -> None:
127 | """
128 | The gradio worker is responsible for displaying the WebUI and relay the
129 | requests to model workers. It should be launched only once.
130 |
131 | Args:
132 | request_queues (List[mp.Queue]): A list of request queues (one for
133 | each model worker).
134 | args (argparse.Namespace): All command line arguments.
135 | barrier (multiprocessing.Barrier): A barrier used to delay the start
136 | of Web UI to be after the start of the model.
137 | """
138 |
139 | def check_input_sanity(text_input: str):
140 | if len(text_input) > 512:
141 | raise gr.Error("please do not send more than 512 characters to this demo")
142 | if text_input.count("<|image|>") != 0:
143 | raise gr.Error("please do not send <|image|> tokens to this demo")
144 |
145 | def updateUIForTask(task):
146 | if task == "Detection":
147 | return gr.update(label="Object to Detect", visible=True)
148 | elif task == "Editing":
149 | return gr.update(label="Editing Prompt", visible=True)
150 | elif task == "Condition to Image":
151 | return gr.update(label="Image Prompt", visible=True)
152 | else:
153 | return gr.update(visible=False, value="")
154 |
155 | def stream_model_output(prompt, input_image, task, seed, gen_t, cfg, image_top_k):
156 |
157 | while True:
158 | content_piece = response_queue.get()
159 | if isinstance(content_piece, Ready):
160 | break
161 | for queue in request_queues:
162 | queue.put((prompt, input_image, task, seed, gen_t, cfg, image_top_k))
163 | while True:
164 | content_piece = response_queue.get()
165 | if isinstance(content_piece, ModelFailure):
166 | raise RuntimeError
167 | if content_piece["end_of_content"]:
168 | yield content_piece["image"][0], content_piece["prompt"]
169 | break
170 | else:
171 | yield None, None
172 |
173 | with gr.Blocks(css="#image_input {height: 100% !important}") as demo:
174 | gr.Markdown("# Lumina-mGPT Image2Image Demo\n")
175 | with gr.Row() as r:
176 | with gr.Column(scale=1):
177 | with gr.Row():
178 | input_image = gr.Image(type="pil", label="Input Image", interactive=True, elem_id="image_input")
179 | with gr.Row():
180 | l_tasks = [
181 | "Semantic Segmentation",
182 | "Depth Estimation",
183 | "Surface Norm Estimation",
184 | "Detection",
185 | "Human Pose Estimation",
186 | "Editing",
187 | "Condition to Image",
188 | ]
189 | task = gr.Dropdown(value=f"Semantic Segmentation", choices=l_tasks, label="tasks")
190 | with gr.Row():
191 | prompt = gr.Textbox(visible=False)
192 | with gr.Row():
193 | with gr.Column(scale=1):
194 | seed = gr.Slider(
195 | minimum=0,
196 | maximum=int(1e5),
197 | value=1,
198 | step=1,
199 | interactive=True,
200 | label="Seed (0 for random)",
201 | )
202 | with gr.Column(scale=1):
203 | gen_t = gr.Slider(
204 | minimum=0.0,
205 | maximum=4.0,
206 | value=1.0,
207 | interactive=True,
208 | label="gen_t",
209 | )
210 | with gr.Row():
211 | with gr.Column(scale=1):
212 | cfg = gr.Slider(
213 | minimum=0.0,
214 | maximum=16.0,
215 | value=1.0,
216 | interactive=True,
217 | label="cfg",
218 | )
219 | with gr.Column(scale=1):
220 | image_top_k = gr.Slider(
221 | minimum=0,
222 | maximum=8192,
223 | value=200,
224 | interactive=True,
225 | label="Image Top-k",
226 | )
227 | with gr.Row():
228 | submit_button = gr.Button("Submit", variant="primary")
229 |
230 | with gr.Column():
231 | output_img = gr.Image(
232 | label="Generated image",
233 | interactive=False,
234 | )
235 | real_prompt = gr.Textbox(
236 | label="Real Prompt",
237 | interactive=False,
238 | visible=False,
239 | lines=5,
240 | show_label=True,
241 | show_copy_button=True,
242 | )
243 |
244 | task.change(updateUIForTask, task, prompt)
245 |
246 | submit_button.click(check_input_sanity, [prompt], []).success(
247 | stream_model_output, [prompt, input_image, task, seed, gen_t, cfg, image_top_k], [output_img, real_prompt]
248 | ).success(lambda: gr.update(visible=True), [], [real_prompt])
249 | barrier.wait()
250 | demo.queue(api_open=True).launch(
251 | share=True,
252 | server_name="0.0.0.0",
253 | )
254 |
255 |
256 | if __name__ == "__main__":
257 | parser = argparse.ArgumentParser("X-LLM-X Chat Demo")
258 | group = parser.add_mutually_exclusive_group()
259 | group.add_argument(
260 | "--gpu_ids",
261 | type=int,
262 | nargs="+",
263 | help="A list of space-separated gpu ids to run the model on. "
264 | "The model will span across GPUs in tensor-parallel mode.",
265 | )
266 | group.add_argument(
267 | "--n_gpus",
268 | type=int,
269 | default=1,
270 | help="Number of GPUs to run the model on. Equivalent to " "--gpu_ids 0 1 2 ... n-1",
271 | )
272 | parser.add_argument("--pretrained_path", type=str, required=True, help="Path to the model checkpoints.")
273 | parser.add_argument(
274 | "--precision",
275 | type=str,
276 | choices=["fp16", "bf16"],
277 | default="bf16",
278 | help="The dtype used for model weights and inference.",
279 | )
280 | parser.add_argument(
281 | "--target_size", type=int, default=768, choices=[512, 768, 1024], help="The target image generation size."
282 | )
283 | args = parser.parse_args()
284 |
285 | # check and setup gpu_ids to use
286 | if args.gpu_ids is None:
287 | if args.n_gpus is None:
288 | args.n_gpus = 1
289 | assert args.n_gpus > 0, "The demo currently must run on a positive number of GPUs."
290 | args.gpu_ids = list(range(args.n_gpus))
291 |
292 | assert len(args.gpu_ids) == 1, "Currently only supports running on a single GPU."
293 |
294 | # using the default "fork" method messes up some imported libs (e.g.,
295 | # pandas)
296 | mp.set_start_method("spawn")
297 |
298 | # setup the queues and start the model workers
299 | request_queues = []
300 | response_queue = mp.Queue()
301 | worker_processes = []
302 | barrier = mp.Barrier(len(args.gpu_ids) + 1)
303 | for rank, gpu_id in enumerate(args.gpu_ids):
304 | request_queue = mp.Queue()
305 | rank_response_queue = response_queue if rank == 0 else None
306 | process = mp.Process(
307 | target=model_worker,
308 | args=(rank, args, barrier, request_queue, rank_response_queue),
309 | )
310 | process.start()
311 | worker_processes.append(process)
312 | request_queues.append(request_queue)
313 |
314 | gradio_worker(request_queues, response_queue, args, barrier)
315 |
--------------------------------------------------------------------------------
/lumina_mgpt/demos/demo_image_generation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0])
5 |
6 | import argparse
7 | import builtins
8 | import datetime
9 | import multiprocessing as mp
10 | import traceback
11 | from typing import List, Optional
12 |
13 | import gradio as gr
14 | import torch
15 |
16 | from data.item_processor import generate_crop_size_list
17 | from inference_solver import FlexARInferenceSolver
18 | from xllmx.util.misc import random_seed
19 |
20 |
21 | class Ready:
22 | pass
23 |
24 |
25 | class ModelFailure:
26 | pass
27 |
28 |
29 | @torch.no_grad()
30 | def model_worker(
31 | rank: int,
32 | args: argparse.Namespace,
33 | barrier: mp.Barrier,
34 | request_queue: mp.Queue,
35 | response_queue: Optional[mp.Queue] = None,
36 | ) -> None:
37 | """
38 |     The worker function that drives the GPU to run inference.
39 |     Exactly n_gpu workers are started, each operating on a separate GPU.
40 |
41 | Args:
42 | rank (int): Distributed rank of the worker.
43 | args (argparse.Namespace): All command line arguments.
44 | barrier (multiprocessing.Barrier): A barrier used to delay the start
45 | of Web UI to be after the start of the model.
46 | """
47 |
48 | builtin_print = builtins.print
49 |
50 | def print(*args, **kwargs):
51 | kwargs["flush"] = True
52 | now = datetime.datetime.now().time()
53 | builtin_print("[{}] ".format(now), end="") # print with time stamp
54 | builtin_print(*args, **kwargs)
55 |
56 | builtins.print = print
57 |
58 | world_size = len(args.gpu_ids)
59 | gpu_id = args.gpu_ids[rank]
60 | # dist.init_process_group(
61 | # backend="nccl", rank=rank, world_size=world_size,
62 | # init_method=f"tcp://{args.master_addr}:{args.master_port}",
63 | # )
64 | # print(f"| distributed init on worker {rank}/{world_size}. "
65 | # f"using gpu: {gpu_id}")
66 | torch.cuda.set_device(gpu_id)
67 |
68 | inference_solver = FlexARInferenceSolver(
69 | model_path=args.pretrained_path,
70 | precision=args.precision,
71 | )
72 |
73 | barrier.wait()
74 |
75 | while True:
76 | if response_queue is not None:
77 | response_queue.put(Ready())
78 | try:
79 | prompt, resolution, seed, gen_t, cfg, image_top_k = request_queue.get()
80 |
81 | random_seed(seed=seed)
82 |
83 | prompt = f"Generate an image of {resolution} according to the following prompt:\n{prompt}"
84 | print(prompt)
85 |
86 | generated = inference_solver.generate(
87 | [],
88 | [[prompt, None]],
89 | 5000,
90 | gen_t,
91 | logits_processor=inference_solver.create_logits_processor(
92 | cfg=cfg, text_top_k=5, image_top_k=image_top_k
93 | ),
94 | )
95 |
96 | print("*" * 100)
97 | print(generated[1])
98 |
99 | stream_response = {"text": generated[0], "image": generated[1], "prompt": prompt, "end_of_content": True}
100 |
101 | print(generated[1])
102 |
103 | if response_queue is not None:
104 | print("here")
105 | response_queue.put(stream_response)
106 |
107 | except Exception:
108 | print(traceback.format_exc())
109 | response_queue.put(ModelFailure())
110 |
111 |
112 | def gradio_worker(
113 | request_queues: List[mp.Queue],
114 | response_queue: mp.Queue,
115 | args: argparse.Namespace,
116 | barrier: mp.Barrier,
117 | ) -> None:
118 | """
119 |     The gradio worker is responsible for displaying the WebUI and relaying the
120 | requests to model workers. It should be launched only once.
121 |
122 | Args:
123 | request_queues (List[mp.Queue]): A list of request queues (one for
124 | each model worker).
125 | args (argparse.Namespace): All command line arguments.
126 | barrier (multiprocessing.Barrier): A barrier used to delay the start
127 | of Web UI to be after the start of the model.
128 | """
129 |
130 | def check_input_sanity(text_input: str):
131 | if len(text_input) > 1024:
132 | raise gr.Error("please do not send more than 1024 characters to this demo")
133 | if text_input.count("<|image|>") != 0:
134 | raise gr.Error("please do not send <|image|> tokens to this demo")
135 |
136 | def stream_model_output(prompt, resolution, seed, gen_t, cfg, image_top_k):
137 |
138 | while True:
139 | content_piece = response_queue.get()
140 | if isinstance(content_piece, Ready):
141 | break
142 | for queue in request_queues:
143 | queue.put((prompt, resolution, seed, gen_t, cfg, image_top_k))
144 | while True:
145 | content_piece = response_queue.get()
146 | if isinstance(content_piece, ModelFailure):
147 | raise RuntimeError
148 | if content_piece["end_of_content"]:
149 | yield content_piece["image"][0], content_piece["prompt"]
150 | break
151 | else:
152 | yield None, None
153 |
154 | def show_real_prompt():
155 | return gr.update(visible=True)
156 |
157 | with gr.Blocks(css="#image_input {height: 100% !important}") as demo:
158 | gr.Markdown("# Lumina-mGPT Image Generation Demo\n")
159 | with gr.Row() as r:
160 | with gr.Column(scale=1):
161 | prompt = gr.Textbox(lines=3, interactive=True, label="Prompt")
162 | with gr.Row():
163 | patch_size = 32
164 | res_choices = generate_crop_size_list((args.target_size // patch_size) ** 2, patch_size)
165 | res_choices = [f"{w}x{h}" for w, h in res_choices]
166 | assert f"{args.target_size}x{args.target_size}" in res_choices
167 | resolution = gr.Dropdown(
168 | value=f"{args.target_size}x{args.target_size}", choices=res_choices, label="Resolution"
169 | )
170 | with gr.Row():
171 | with gr.Column(scale=1):
172 | seed = gr.Slider(
173 | minimum=0,
174 | maximum=int(1e5),
175 | value=300,
176 | step=1,
177 | interactive=True,
178 | label="Seed (0 for random)",
179 | )
180 | with gr.Column(scale=1):
181 | gen_t = gr.Slider(
182 | minimum=0.0,
183 | maximum=4.0,
184 | value=1.0,
185 | interactive=True,
186 | label="gen_t",
187 | )
188 | with gr.Row():
189 | with gr.Column(scale=1):
190 | cfg = gr.Slider(
191 | minimum=0.0,
192 | maximum=16.0,
193 | value=3.0,
194 | interactive=True,
195 | label="cfg",
196 | )
197 | with gr.Column(scale=1):
198 | image_top_k = gr.Slider(
199 | minimum=0,
200 | maximum=8192,
201 | value=4000,
202 | interactive=True,
203 | label="Image Top-k",
204 | )
205 | submit_button = gr.Button("Submit", variant="primary")
206 |
207 | with gr.Column():
208 | output_img = gr.Image(
209 | label="Generated image",
210 | interactive=False,
211 | )
212 | real_prompt = gr.Textbox(
213 | label="Real Prompt", interactive=False, visible=False, show_label=True, show_copy_button=True
214 | )
215 |
216 | prompt.submit(check_input_sanity, [prompt], []).success(
217 | stream_model_output, [prompt, resolution, seed, gen_t, cfg, image_top_k], [output_img, real_prompt]
218 | )
219 | submit_button.click(check_input_sanity, [prompt], []).success(
220 | stream_model_output, [prompt, resolution, seed, gen_t, cfg, image_top_k], [output_img, real_prompt]
221 | ).success(show_real_prompt, [], [real_prompt])
222 | barrier.wait()
223 | demo.queue(api_open=True).launch(
224 | share=True,
225 | server_name="0.0.0.0",
226 | )
227 |
228 |
229 | if __name__ == "__main__":
230 |     parser = argparse.ArgumentParser("Lumina-mGPT Image Generation Demo")
231 | group = parser.add_mutually_exclusive_group()
232 | group.add_argument(
233 | "--gpu_ids",
234 | type=int,
235 | nargs="+",
236 | help="A list of space-separated gpu ids to run the model on. "
237 | "The model will span across GPUs in tensor-parallel mode.",
238 | )
239 | group.add_argument(
240 | "--n_gpus",
241 | type=int,
242 | default=1,
243 | help="Number of GPUs to run the model on. Equivalent to " "--gpu_ids 0 1 2 ... n-1",
244 | )
245 | parser.add_argument("--pretrained_path", type=str, required=True, help="Path to the model checkpoints.")
246 | parser.add_argument(
247 | "--precision",
248 | type=str,
249 | choices=["fp16", "bf16"],
250 | default="bf16",
251 | help="The dtype used for model weights and inference.",
252 | )
253 | parser.add_argument(
254 | "--target_size", type=int, default=768, choices=[512, 768, 1024], help="The target image generation size."
255 | )
256 | args = parser.parse_args()
257 |
258 | # check and setup gpu_ids to use
259 | if args.gpu_ids is None:
260 | if args.n_gpus is None:
261 | args.n_gpus = 1
262 | assert args.n_gpus > 0, "The demo currently must run on a positive number of GPUs."
263 | args.gpu_ids = list(range(args.n_gpus))
264 |
265 | assert len(args.gpu_ids) == 1, "Currently only supports running on a single GPU."
266 |
267 | # using the default "fork" method messes up some imported libs (e.g.,
268 | # pandas)
269 | mp.set_start_method("spawn")
270 |
271 | # setup the queues and start the model workers
272 | request_queues = []
273 | response_queue = mp.Queue()
274 | worker_processes = []
275 | barrier = mp.Barrier(len(args.gpu_ids) + 1)
276 | for rank, gpu_id in enumerate(args.gpu_ids):
277 | request_queue = mp.Queue()
278 | rank_response_queue = response_queue if rank == 0 else None
279 | process = mp.Process(
280 | target=model_worker,
281 | args=(rank, args, barrier, request_queue, rank_response_queue),
282 | )
283 | process.start()
284 | worker_processes.append(process)
285 | request_queues.append(request_queue)
286 |
287 | gradio_worker(request_queues, response_queue, args, barrier)
288 |
--------------------------------------------------------------------------------
/lumina_mgpt/exps/7B.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lr=2e-5
4 | wd=0.1
5 | dropout=0.05
6 | z_loss_weight=1e-5
7 |
8 | data_config=configs/data/sample.yaml
9 |
10 | exp_name=7B
11 | mkdir -p output/"$exp_name"
12 |
13 |
14 | python -u finetune_solver.py \
15 | --model_size 7B \
16 | --batch_size 8 \
17 | --accum_iter 1 \
18 | --epochs 2 \
19 | --warmup_epochs 0.01 \
20 | --lr ${lr} \
21 | --min_lr ${lr} \
22 | --wd ${wd} \
23 | --clip_grad 4 \
24 | --data_config $data_config \
25 | --cache_ann_on_disk \
26 | --num_workers 8 \
27 | --output_dir output/"$exp_name" \
28 | --save_iteration_interval 1000 \
29 | --checkpointing \
30 | --max_seq_len 4096 \
31 | --unmask_image_logits \
32 | --dropout ${dropout} \
33 | --z_loss_weight ${z_loss_weight} \
34 | 2>&1 | tee -a output/"$exp_name"/output.log
35 |
36 | echo "exp name: $exp_name"
37 |
--------------------------------------------------------------------------------
/lumina_mgpt/finetune_solver.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from typing import List, Tuple
3 |
4 | from accelerate import init_empty_weights
5 | import torch
6 |
7 | from model import ChameleonXLLMXConfig, ChameleonXLLMXForConditionalGeneration
8 | from xllmx.data.item_processor import ItemProcessorBase
9 | from xllmx.solvers.finetune import FinetuneSolverBase
10 |
11 |
12 | class ItemProcessor(ItemProcessorBase):
13 | def process_item(self, data_item: dict, training_mode=False) -> Tuple[List, List]:
14 | assert training_mode
15 |
16 | if "token" in data_item and "label" in data_item:
17 |             pass  # already pre-tokenized: the item carries "token" and "label" inline
18 | else:
19 | assert "file" in data_item
20 | with open(data_item["file"], "rb") as f:
21 | data_item = pickle.load(f)
22 |
23 | tokens = data_item["token"]
24 | labels = data_item["label"]
25 | assert len(tokens) == len(labels)
26 |
27 | return tokens, labels
28 |
29 | def predict_item_token_length(self, data_item: dict) -> int:
30 | if "token" in data_item:
31 | return len(data_item["token"])
32 | elif "len" in data_item:
33 | return data_item["len"]
34 | else:
35 |             raise ValueError("data_item must contain either 'token' or 'len' to determine its length")
36 |
37 |
38 | class Solver(FinetuneSolverBase):
39 | @classmethod
40 | def get_args_parser(cls):
41 | parser = super().get_args_parser()
42 | # task-specific parameters
43 | parser.add_argument("--max_seq_len", default=4096, type=int, help="max token length")
44 | parser.add_argument("--mask_image_logits", default=True)
45 | parser.add_argument("--unmask_image_logits", action="store_false", dest="mask_image_logits")
46 | parser.add_argument("--dropout", type=float, default=0.0)
47 | parser.add_argument("--z_loss_weight", type=float, default=0.0)
48 | parser.add_argument("--model_size", type=str, default="7B", choices=["7B", "34B"])
49 | return parser
50 |
51 | def _model_func(
52 | self,
53 | init_from: str,
54 |     ) -> Tuple[ChameleonXLLMXForConditionalGeneration, None]:
55 |
56 | # Only instantiate the model on rank0
57 | # Other ranks will receive the model weights from rank0 during FSDP wrapping (through `sync_module_states`)
58 | # See https://github.com/pytorch/pytorch/issues/105840
59 | if self.dp_rank == 0:
60 | model = ChameleonXLLMXForConditionalGeneration.from_pretrained(
61 | init_from,
62 | max_position_embeddings=self.args.max_seq_len,
63 | mask_image_logits=self.args.mask_image_logits,
64 | dropout=self.args.dropout,
65 | z_loss_weight=self.args.z_loss_weight,
66 | torch_dtype=torch.bfloat16,
67 | device_map="cpu",
68 | )
69 | else:
70 | with init_empty_weights():
71 | config = ChameleonXLLMXConfig.from_pretrained(
72 | init_from,
73 | max_position_embeddings=self.args.max_seq_len,
74 | mask_image_logits=self.args.mask_image_logits,
75 | dropout=self.args.dropout,
76 | z_loss_weight=self.args.z_loss_weight,
77 | torch_dtype=torch.bfloat16,
78 | )
79 | model = ChameleonXLLMXForConditionalGeneration(config)
80 |
81 | del model.model.vqmodel
82 |
83 | return model, None
84 |
85 | def _item_processor_func(self) -> ItemProcessorBase:
86 | return ItemProcessor()
87 |
88 | def _make_and_save_starting_point(self, save_path: str) -> None:
89 |
90 | pretrained_name = {
91 | "7B": "Alpha-VLLM/Chameleon_7B_mGPT",
92 | "34B": "Alpha-VLLM/Chameleon_34B_mGPT",
93 | }[self.args.model_size]
94 |
95 | model = ChameleonXLLMXForConditionalGeneration.from_pretrained(
96 | pretrained_name,
97 | max_position_embeddings=self.args.max_seq_len,
98 | mask_image_logits=self.args.mask_image_logits,
99 | dropout=self.args.dropout,
100 | z_loss_weight=self.args.z_loss_weight,
101 | torch_dtype=torch.bfloat16,
102 | device_map="cpu",
103 | )
104 |
105 | image_tokens = model.model.vocabulary_mapping.image_tokens
106 | model.lm_head.weight.data[image_tokens] = torch.zeros_like(model.lm_head.weight.data[image_tokens])
107 |
108 | model.save_pretrained(save_path, max_shard_size="10GB")
109 |
110 |
111 | if __name__ == "__main__":
112 | args = Solver.get_args_parser().parse_args()
113 | solver = Solver(args)
114 | solver.run()
115 |
--------------------------------------------------------------------------------
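A note on the data format consumed by `ItemProcessor` in `finetune_solver.py` above: `process_item` accepts either an inline record that already carries `token` and `label` lists, or a record whose `file` field points to a pickle holding them, and `predict_item_token_length` additionally understands a precomputed `len` field (such records are normally produced by the `pre_tokenize` scripts). The sketch below writes one such pickled item; the file name and token ids are illustrative, and the use of -100 as an ignore label is an assumption rather than something shown in this file.

import pickle

# Hypothetical pre-tokenized sample: token ids plus per-token labels.
# Whether -100 masks prompt positions from the loss is an assumption about the
# training code, not something visible in finetune_solver.py.
item = {"token": [8710, 101, 102, 103], "label": [-100, 101, 102, 103]}

with open("sample_item.pkl", "wb") as f:
    pickle.dump(item, f)

# The annotation record then only needs the file path, plus an optional
# precomputed length so predict_item_token_length can avoid loading the pickle.
record = {"file": "sample_item.pkl", "len": len(item["token"])}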
/lumina_mgpt/generate_examples/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0])
5 | import argparse
6 |
7 | from PIL import Image
8 | import torch
9 |
10 | from inference_solver import FlexARInferenceSolver
11 | from xllmx.util.misc import random_seed
12 |
13 | if __name__ == "__main__":
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--model_path", type=str, required=True)
17 | parser.add_argument("--save_path", type=str, required=True)
18 | parser.add_argument("--temperature", type=float)
19 | parser.add_argument("--top_k", type=int)
20 | parser.add_argument("--cfg", type=float)
21 | parser.add_argument("-n", type=int, default=5)
22 | parser.add_argument("--width", type=int, default=512)
23 | parser.add_argument("--height", type=int, default=512)
24 |
25 | args = parser.parse_args()
26 |
27 | print("args:\n", args)
28 |
29 | select_set1 = [
30 | "Image of a dog playing water, and a water fall is in the background.",
31 | "A family of asian people sitting around the dinner table, eating and laughing.",
32 | "A high-resolution photograph of a middle-aged woman with curly hair, wearing traditional Japanese kimono, smiling gently under a cherry blossom tree in full bloom.", # noqa
33 | "Image of a bustling downtown street in Tokyo at night, with neon signs, crowded sidewalks, and tall skyscrapers.", # noqa
34 | "Image of a quiet European village with cobblestone streets and colorful houses, under a clear blue sky.",
35 | ]
36 |
37 | l_prompts = select_set1
38 |
39 | t = args.temperature
40 | top_k = args.top_k
41 | cfg = args.cfg
42 | n = args.n
43 | w, h = args.width, args.height
44 |
45 | inference_solver = FlexARInferenceSolver(
46 | model_path=args.model_path,
47 | precision="bf16",
48 | )
49 |
50 | with torch.no_grad():
51 | l_generated_all = []
52 | for i, prompt in enumerate(l_prompts):
53 | for repeat_idx in range(n):
54 | random_seed(repeat_idx)
55 | generated = inference_solver.generate(
56 | images=[],
57 | qas=[[f"Generate an image of {w}x{h} according to the following prompt:\n{prompt}", None]],
58 | max_gen_len=8192,
59 | temperature=t,
60 | logits_processor=inference_solver.create_logits_processor(cfg=cfg, image_top_k=top_k),
61 | )
62 | try:
63 | l_generated_all.append(generated[1][0])
64 |                 except Exception:
65 | l_generated_all.append(Image.new("RGB", (w, h)))
66 |
67 | result_image = inference_solver.create_image_grid(l_generated_all, len(l_prompts), n)
68 | result_image.save(args.save_path)
69 |
--------------------------------------------------------------------------------
/lumina_mgpt/inference_solver.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import math
4 | from typing import List, Optional, Union
5 |
6 | from PIL import Image
7 | import torch
8 | import transformers
9 | from transformers import GenerationConfig, TextStreamer
10 | from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList, LogitsWarper
11 |
12 | from data.item_processor import FlexARItemProcessor
13 | from model.chameleon import ChameleonForConditionalGeneration
14 |
15 |
16 | class LLMImageStartTriggeredUnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
17 | r"""
18 |     Logits processor for Classifier-Free Guidance (CFG). The processor computes a weighted average across scores
19 | from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`.
20 | The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch.
21 |
22 | See [the paper](https://arxiv.org/abs/2306.17806) for more information.
23 | """
24 |
25 | def __init__(
26 | self,
27 | guidance_scale: float,
28 | model,
29 | image_start_token_id,
30 | image_end_token_id,
31 | image_next_line_token_id,
32 | patch_size,
33 | unconditional_ids: Optional[torch.LongTensor] = None,
34 | unconditional_attention_mask: Optional[torch.LongTensor] = None,
35 | use_cache: Optional[bool] = True,
36 | ):
37 | self.guidance_scale = guidance_scale
38 | self.model = model
39 | self.unconditional_context_backup = {
40 | "input_ids": unconditional_ids,
41 | "attention_mask": unconditional_attention_mask,
42 | "use_cache": use_cache,
43 | "past_key_values": transformers.DynamicCache() if use_cache else None,
44 | "first_pass": True,
45 | }
46 | self.unconditional_context = None
47 |
48 | self.nums_image_start_tokens = None
49 |
50 | self.image_start_token_id = image_start_token_id
51 | self.image_end_token_id = image_end_token_id
52 | self.image_next_line_token_id = image_next_line_token_id
53 | self.image_start_token_id_index = None
54 | self.patch_size = patch_size
55 | self.h_latent_dim = None
56 | self.w_latent_dim = None
57 |
58 | def get_unconditional_logits(self, input_ids, image_start_token_id_index):
59 |
60 | if self.unconditional_context["first_pass"]:
61 | if self.unconditional_context["input_ids"] is None:
62 | self.unconditional_context["input_ids"] = input_ids[:, image_start_token_id_index:]
63 | if self.unconditional_context["attention_mask"] is None:
64 | self.unconditional_context["attention_mask"] = torch.ones_like(
65 | self.unconditional_context["input_ids"], dtype=torch.long
66 | )
67 | input_ids = self.unconditional_context["input_ids"]
68 | attention_mask = self.unconditional_context["attention_mask"]
69 | self.unconditional_context["first_pass"] = False
70 | else:
71 | attention_mask = torch.cat(
72 | [
73 | self.unconditional_context["attention_mask"],
74 | torch.ones_like(input_ids[:, -1:], dtype=torch.long),
75 | ],
76 | dim=1,
77 | )
78 | if not self.unconditional_context["use_cache"]:
79 | input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
80 | else:
81 | input_ids = input_ids[:, -1:]
82 | self.unconditional_context["input_ids"] = input_ids
83 | self.unconditional_context["attention_mask"] = attention_mask
84 |
85 | out = self.model(
86 | input_ids,
87 | attention_mask=attention_mask,
88 | use_cache=self.unconditional_context["use_cache"],
89 | past_key_values=self.unconditional_context["past_key_values"],
90 | )
91 | self.unconditional_context["past_key_values"] = out.get("past_key_values", None)
92 |
93 | return out.logits
94 |
95 | def __call__(self, input_ids, scores):
96 | num_image_start_tokens = (input_ids[0] == self.image_start_token_id).sum()
97 | num_image_end_tokens = (input_ids[0] == self.image_end_token_id).sum()
98 |
99 | if num_image_start_tokens == num_image_end_tokens:
100 | self.h_latent_dim, self.w_latent_dim = None, None
101 | self.image_start_token_id_index = None
102 | self.unconditional_context = None
103 | return scores
104 |
105 | elif num_image_start_tokens == num_image_end_tokens + 1:
106 | if self.image_start_token_id_index is None:
107 | self.image_start_token_id_index = torch.where(input_ids[0] == self.image_start_token_id)[0][-1].item()
108 | new_token_num = len(input_ids[0][self.image_start_token_id_index + 1 :])
109 | if new_token_num >= 2:
110 | if self.h_latent_dim is None or self.w_latent_dim is None:
111 | h_grids, w_grids = (
112 | input_ids[0][self.image_start_token_id_index + 1] - 8804,
113 | input_ids[0][self.image_start_token_id_index + 2] - 8804,
114 | )
115 | self.h_latent_dim, self.w_latent_dim = h_grids * 2, w_grids * 2
116 |
117 | if self.unconditional_context is None:
118 | self.unconditional_context = copy.deepcopy(self.unconditional_context_backup)
119 |
120 | if self.guidance_scale == 1.0:
121 | return scores
122 |
123 | unconditional_logits = self.get_unconditional_logits(input_ids, self.image_start_token_id_index)[:, -1]
124 |
125 | scores_processed = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
126 | return scores_processed
127 |
128 | else:
129 | print("Something wrong in the decoding process.")
130 |
131 | return scores
132 |
133 |
134 | class MultiModalLogitsProcessor(LogitsProcessor):
135 |
136 | def __init__(
137 | self,
138 | image_start_token_id=None,
139 | image_end_token_id=None,
140 | image_next_line_token_id=None,
141 | patch_size=None,
142 | voc_size=None,
143 | ):
144 | self.image_start_token_id = image_start_token_id
145 | self.image_end_token_id = image_end_token_id
146 | self.image_next_line_token_id = image_next_line_token_id
147 | self.image_start_token_id_index = None
148 | self.patch_size = patch_size
149 | self.h_latent_dim = None
150 | self.w_latent_dim = None
151 |
152 | self.vocab_list = [i for i in range(voc_size)]
153 | self.image_token_list = [i for i in range(4, 8195 + 1)]
154 | self.suppress_tokens = torch.tensor(
155 | [x for x in self.vocab_list if x not in self.image_token_list], device="cuda"
156 | )
157 |
158 | self.vocab_tensor = torch.arange(voc_size, device="cuda")
159 | self.suppress_token_mask = torch.isin(self.vocab_tensor, self.suppress_tokens)
160 | self.new_line_force_token_mask = torch.isin(
161 | self.vocab_tensor, torch.tensor([self.image_next_line_token_id], device="cuda")
162 | )
163 | self.eos_image_force_token_mask = torch.isin(
164 | self.vocab_tensor, torch.tensor([self.image_end_token_id], device="cuda")
165 | )
166 |
167 | self.flag = False
168 | self.num_image_start_tokens = None
169 | self.num_image_end_tokens = None
170 |
171 | # @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
172 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
173 |
174 | self.num_image_start_tokens = (input_ids[0] == self.image_start_token_id).sum()
175 | self.num_image_end_tokens = (input_ids[0] == self.image_end_token_id).sum()
176 |
177 | # print(self.num_image_start_tokens, self.num_image_end_tokens)
178 |
179 | if self.num_image_start_tokens == self.num_image_end_tokens:
180 | self.h_latent_dim, self.w_latent_dim = None, None
181 | self.image_start_token_id_index = None
182 | return scores
183 |
184 | elif self.num_image_start_tokens == self.num_image_end_tokens + 1:
185 | if self.image_start_token_id_index is None:
186 | self.image_start_token_id_index = torch.where(input_ids[0] == self.image_start_token_id)[0]
187 | print(self.image_start_token_id_index)
188 | self.image_start_token_id_index = torch.where(input_ids[0] == self.image_start_token_id)[0][-1].item()
189 |
190 | new_token_num = len(input_ids[0][self.image_start_token_id_index + 1 :])
191 | # print(f"num new tokens: {new_token_num}")
192 | if new_token_num >= 2:
193 | if self.h_latent_dim is None or self.w_latent_dim is None:
194 | h_grids, w_grids = (
195 | input_ids[0][self.image_start_token_id_index + 1] - 8804,
196 | input_ids[0][self.image_start_token_id_index + 2] - 8804,
197 | )
198 | # print(f"h_grids: {h_grids}, w_grids: {w_grids}")
199 | self.h_latent_dim, self.w_latent_dim = h_grids * 2, w_grids * 2
200 | print(f"h_latent_dim: {self.h_latent_dim}, w_latent_dim: {self.w_latent_dim}")
201 |
202 | tokens = input_ids[0][self.image_start_token_id_index + 3 :]
203 | if (len(tokens) + 1) % (self.w_latent_dim + 1) == 0:
204 | new_line_constrained_scores = torch.full_like(scores, -math.inf)
205 | new_line_constrained_scores[:, self.image_next_line_token_id] = 0
206 | print(f"new line: {len(tokens)+1}")
207 | return new_line_constrained_scores
208 | elif (len(tokens) + 1) == (self.w_latent_dim + 1) * self.h_latent_dim + 1:
209 | eos_image_constrained_scores = torch.full_like(scores, -math.inf)
210 | eos_image_constrained_scores[:, self.image_end_token_id] = 0
211 | print(f"eos image: {len(tokens)+1}")
212 | return eos_image_constrained_scores
213 | elif (len(tokens) + 1) % (self.w_latent_dim + 1) != 0:
214 | image_constrained_scores = torch.where(self.suppress_token_mask, -float("inf"), scores)
215 | return image_constrained_scores
216 | else:
217 | print("Something wrong in the decoding process.")
218 |
219 | return scores
220 |
221 |
222 | class InterleavedTopKLogitsWarper(LogitsWarper):
223 | r"""
224 | [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together
225 | with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
226 | """
227 |
228 | def __init__(
229 | self,
230 | image_top_k: int,
231 | text_top_k: int,
232 | image_start_token_id=None,
233 | image_end_token_id=None,
234 | filter_value: float = -float("Inf"),
235 | min_tokens_to_keep: int = 1,
236 | ):
237 | if not isinstance(text_top_k, int) or text_top_k <= 0:
238 | raise ValueError(f"`text_top_k` has to be a strictly positive integer, but is {text_top_k}")
239 |         if not isinstance(image_top_k, int) or image_top_k <= 0:
240 | raise ValueError(f"`image_top_k` has to be a strictly positive integer, but is {image_top_k}")
241 |
242 | self.image_top_k = max(image_top_k, min_tokens_to_keep)
243 | self.text_top_k = max(text_top_k, min_tokens_to_keep)
244 | self.filter_value = filter_value
245 |
246 | self.image_start_token_id = image_start_token_id
247 | self.image_end_token_id = image_end_token_id
248 |
249 | self.flag = False
250 | self.num_image_start_tokens = None
251 | self.num_image_end_tokens = None
252 |
253 | # @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
254 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
255 |
256 | self.num_image_start_tokens = (input_ids[0] == self.image_start_token_id).sum()
257 | self.num_image_end_tokens = (input_ids[0] == self.image_end_token_id).sum()
258 |
259 | if self.num_image_start_tokens == self.num_image_end_tokens + 1:
260 | top_k = min(self.image_top_k, scores.size(-1))
261 | else:
262 | top_k = min(self.text_top_k, scores.size(-1)) # Safety check
263 | # Remove all tokens with a probability less than the last token of the top-k
264 | indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
265 | scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
266 | return scores_processed
267 |
268 |
269 | class FlexARInferenceSolver:
270 | @classmethod
271 | def get_args_parser(cls):
272 | parser = argparse.ArgumentParser("xllmx Inference", add_help=False)
273 | parser.add_argument("--model_path", type=str)
274 | parser.add_argument("--precision", type=str, choices=["fp16", "bf16", "tf32"], default="bf16")
275 |
276 | return parser
277 |
278 | def __init__(self, model_path, precision, target_size=512):
279 | self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
280 |
281 | self.model = ChameleonForConditionalGeneration.from_pretrained(
282 | model_path,
283 | torch_dtype=self.dtype,
284 | device_map="cuda",
285 | )
286 | self.item_processor = FlexARItemProcessor(target_size=target_size)
287 |
288 | def get_streamer(self):
289 | return TextStreamer(self.item_processor.tokenizer)
290 |
291 | @torch.no_grad()
292 | def generate(
293 | self,
294 | images: Image.Image | str | List[Union[Image.Image, str]],
295 | qas,
296 | max_gen_len,
297 | temperature,
298 | logits_processor=None,
299 | streamer=None,
300 | ):
301 |
302 | conversations = []
303 | for q, a in qas:
304 | conversations.append(
305 | {
306 | "from": "human",
307 | "value": q,
308 | }
309 | )
310 | conversations.append(
311 | {
312 | "from": "gpt",
313 | "value": a,
314 | }
315 | )
316 | item = {"image": images, "conversations": conversations}
317 |
318 | _prompt = self.item_processor.process_item(item)
319 | prompt = []
320 | for value in _prompt:
321 | if isinstance(value, int):
322 | prompt.append(value)
323 | else:
324 | prompt += value["input_ids"]
325 | prompt_len = len(prompt)
326 | prompt = torch.tensor(prompt, dtype=torch.int64, device=self.model.device).unsqueeze(0)
327 |
328 | generation_config = GenerationConfig(
329 | max_new_tokens=max_gen_len,
330 | max_length=self.model.config.max_position_embeddings,
331 | temperature=temperature,
332 | top_k=None,
333 | do_sample=True,
334 | eos_token_id=[8710],
335 | )
336 |
337 | if logits_processor is None:
338 | logits_processor = self.create_logits_processor()
339 |
340 | with torch.cuda.amp.autocast(dtype=self.dtype):
341 | generation_result = self.model.generate(
342 | prompt, generation_config, logits_processor=logits_processor, streamer=streamer
343 | )[0][prompt_len:].tolist()
344 | if len(generation_result) > 0 and generation_result[-1] == 8710:
345 | generation_result = generation_result[:-1]
346 |
347 | return self.decode_ids(generation_result)
348 |
349 | def decode_ids(self, tokens: List[int]):
350 | generated_images = []
351 | generation_result_processed = []
352 | i = 0
353 | while i < len(tokens):
354 | token_id = tokens[i]
355 | if token_id == self.item_processor.token2id(self.item_processor.image_start_token):
356 | cache = []
357 | for j in range(i + 1, len(tokens)):
358 | if tokens[j] != self.item_processor.token2id(self.item_processor.image_end_token):
359 | cache.append(tokens[j])
360 | i = j + 1
361 | else:
362 | image = self.decode_image(cache)
363 | generated_images.append(image)
364 | generation_result_processed.append(self.item_processor.token2id("<|image|>"))
365 | i = j + 1
366 | break
367 | else:
368 | generation_result_processed.append(token_id)
369 | i += 1
370 |
371 | generated = self.item_processor.tokenizer.decode(generation_result_processed)
372 |
373 | return generated, generated_images
374 |
375 | def decode_image(self, tokens: List[int]):
376 | return self.item_processor.decode_image(tokens)
377 |
378 | @staticmethod
379 | def create_image_grid(images, rows, cols):
380 | width, height = images[0].size
381 |
382 | grid_img = Image.new("RGB", (cols * width, rows * height))
383 |
384 | for i, img in enumerate(images):
385 | row = i // cols
386 | col = i % cols
387 | grid_img.paste(img, (col * width, row * height))
388 |
389 | return grid_img
390 |
391 | def create_logits_processor(self, cfg=3.0, image_top_k=2000, text_top_k=10):
392 | logits_processor = LogitsProcessorList()
393 |
394 | cfg_processor = LLMImageStartTriggeredUnbatchedClassifierFreeGuidanceLogitsProcessor(
395 | guidance_scale=cfg,
396 | model=self.model,
397 | image_start_token_id=self.item_processor.token2id(self.item_processor.image_start_token),
398 | image_end_token_id=self.item_processor.token2id(self.item_processor.image_end_token),
399 | image_next_line_token_id=self.item_processor.token2id(self.item_processor.new_line_token),
400 | patch_size=32,
401 | )
402 |
403 | candidate_processor = MultiModalLogitsProcessor(
404 | image_start_token_id=self.item_processor.token2id(self.item_processor.image_start_token),
405 | image_end_token_id=self.item_processor.token2id(self.item_processor.image_end_token),
406 | image_next_line_token_id=self.item_processor.token2id(self.item_processor.new_line_token),
407 | patch_size=32,
408 | voc_size=self.model.config.vocab_size,
409 | )
410 |
411 | topk_processor = InterleavedTopKLogitsWarper(
412 | image_top_k=image_top_k,
413 | text_top_k=text_top_k,
414 | image_start_token_id=self.item_processor.token2id(self.item_processor.image_start_token),
415 | image_end_token_id=self.item_processor.token2id(self.item_processor.image_end_token),
416 | )
417 |
418 | logits_processor.append(cfg_processor)
419 | logits_processor.append(candidate_processor)
420 | logits_processor.append(topk_processor)
421 |
422 | return logits_processor
423 |
424 |
425 | if __name__ == "__main__":
426 | parser = FlexARInferenceSolver.get_args_parser()
427 | args = parser.parse_args()
428 | solver = FlexARInferenceSolver(**vars(args))
429 |
--------------------------------------------------------------------------------
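For reference, a minimal usage sketch of the `FlexARInferenceSolver` defined above for text-to-image generation, mirroring `generate_examples/generate.py`: the checkpoint path is a placeholder, the prompt template and the `(text, images)` return shape follow the code above, and the `lumina_mgpt` directory is assumed to be on `sys.path` as in the demos.

from inference_solver import FlexARInferenceSolver

solver = FlexARInferenceSolver(model_path="/path/to/Lumina-mGPT-checkpoint", precision="bf16")

# qas is a list of [question, answer] pairs; the final answer is left as None so
# the model generates it. Image token spans in the output are decoded to PIL
# images and replaced by an <|image|> placeholder in the returned text.
text, images = solver.generate(
    images=[],
    qas=[["Generate an image of 768x768 according to the following prompt:\n"
          "A snowy mountain lake at sunrise.", None]],
    max_gen_len=8192,
    temperature=1.0,
    logits_processor=solver.create_logits_processor(cfg=3.0, image_top_k=2000),
)
if images:
    images[0].save("generated.png")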
/lumina_mgpt/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_xllmx_chameleon import ChameleonXLLMXConfig
2 | from .modeling_xllmx_chameleon import ChameleonXLLMXForConditionalGeneration
3 |
--------------------------------------------------------------------------------
/lumina_mgpt/model/chameleon/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
17 |
18 | _import_structure = {
19 | "configuration_chameleon": ["ChameleonConfig", "ChameleonVQVAEConfig"],
20 | "processing_chameleon": ["ChameleonProcessor"],
21 | }
22 |
23 |
24 | try:
25 | if not is_torch_available():
26 | raise OptionalDependencyNotAvailable()
27 | except OptionalDependencyNotAvailable:
28 | pass
29 | else:
30 | _import_structure["modeling_chameleon"] = [
31 | "ChameleonForConditionalGeneration",
32 | "ChameleonModel",
33 | "ChameleonPreTrainedModel",
34 | "ChameleonVQVAE",
35 | ]
36 |
37 | try:
38 | if not is_vision_available():
39 | raise OptionalDependencyNotAvailable()
40 | except OptionalDependencyNotAvailable:
41 | pass
42 | else:
43 | _import_structure["image_processing_chameleon"] = ["ChameleonImageProcessor"]
44 |
45 |
46 | if TYPE_CHECKING:
47 | from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
48 | from .processing_chameleon import ChameleonProcessor
49 |
50 | try:
51 | if not is_torch_available():
52 | raise OptionalDependencyNotAvailable()
53 | except OptionalDependencyNotAvailable:
54 | pass
55 | else:
56 | from .modeling_chameleon import (
57 | ChameleonForConditionalGeneration,
58 | ChameleonModel,
59 | ChameleonPreTrainedModel,
60 | ChameleonVQVAE,
61 | )
62 |
63 | try:
64 | if not is_vision_available():
65 | raise OptionalDependencyNotAvailable()
66 | except OptionalDependencyNotAvailable:
67 | pass
68 | else:
69 | from .image_processing_chameleon import ChameleonImageProcessor
70 |
71 |
72 | else:
73 | import sys
74 |
75 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
76 |
--------------------------------------------------------------------------------
/lumina_mgpt/model/chameleon/configuration_chameleon.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """chameleon model configuration"""
16 |
17 | from typing import List
18 |
19 | from transformers.configuration_utils import PretrainedConfig
20 | from transformers.utils import logging
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 |
25 | class ChameleonVQVAEConfig(PretrainedConfig):
26 | r"""
27 | This is the configuration class to store the configuration of a [`ChameleonVQModel`]. It is used to instantiate a
28 | `ChameleonVQModel` according to the specified arguments, defining the model architecture.
29 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
30 | documentation from [`PretrainedConfig`] for more information. Instantiating a
31 | configuration with the defaults will yield a similar configuration to the VQModel of the
32 | [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B).
33 |
34 | Args:
35 | embed_dim (`int`, *optional*, defaults to 256):
36 | Dimensionality of each embedding vector.
37 | num_embeddings (`int`, *optional*, defaults to 8192):
38 | Number of codebook embeddings.
39 | double_latent (`bool`, *optional*, defaults to `False`):
40 | Whether to use double z channels.
41 | latent_channels (`int`, *optional*, defaults to 256):
42 | Number of channels for the latent space.
43 | resolution (`int`, *optional*, defaults to 512):
44 | Resolution of the input images.
45 | in_channels (`int`, *optional*, defaults to 3):
46 | Number of input channels.
47 | base_channels (`int`, *optional*, defaults to 128):
48 | Base channel count.
49 | channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
50 | Channel multipliers for each resolution.
51 | num_res_blocks (`int`, *optional*, defaults to 2):
52 | Number of residual blocks.
53 | attn_resolutions (`List[int]`, *optional*):
54 | Resolutions to apply attention.
55 | dropout (`float`, *optional*, defaults to 0.0):
56 | Dropout rate.
57 | attn_type (`str`, *optional*, defaults to `"vanilla"`):
58 | Attention type used in VQ-GAN encoder. Can be "vanilla" or None.
59 | initializer_range (`float`, *optional*, defaults to 0.02):
60 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
61 | """
62 |
63 | model_type = "chameleon_vqgan"
64 |
65 | def __init__(
66 | self,
67 | embed_dim: int = 256,
68 | num_embeddings: int = 8192,
69 | double_latent: bool = False,
70 | latent_channels: int = 256,
71 | resolution: int = 512,
72 | in_channels: int = 3,
73 | base_channels: int = 128,
74 | channel_multiplier: List[int] = [1, 1, 2, 2, 4],
75 | num_res_blocks: int = 2,
76 | attn_resolutions: List[int] = None,
77 | dropout: float = 0.0,
78 | attn_type: str = "vanilla",
79 | initializer_range=0.02,
80 | **kwargs,
81 | ):
82 | super().__init__(**kwargs)
83 | self.embed_dim = embed_dim
84 | self.num_embeddings = num_embeddings
85 | self.double_latent = double_latent
86 | self.latent_channels = latent_channels
87 | self.resolution = resolution
88 | self.in_channels = in_channels
89 | self.base_channels = base_channels
90 | self.channel_multiplier = channel_multiplier
91 | self.num_res_blocks = num_res_blocks
92 | self.attn_resolutions = attn_resolutions
93 | self.dropout = dropout
94 | self.attn_type = attn_type
95 | self.initializer_range = initializer_range
96 |
97 |
98 | class ChameleonConfig(PretrainedConfig):
99 | r"""
100 | This is the configuration class to store the configuration of a [`ChameleonModel`]. It is used to instantiate a
101 | chameleon model according to the specified arguments, defining the model architecture. Instantiating a
102 | configuration with the defaults will yield a similar configuration to that of the
103 | [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B).
104 |
105 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
106 | documentation from [`PretrainedConfig`] for more information.
107 |
108 |
109 | Args:
110 | vocab_size (`int`, *optional*, defaults to 65536):
111 | Vocabulary size of the chameleon model. Defines the number of different tokens that can be represented by the
112 | `inputs_ids` passed when calling [`ChameleonModel`]; this includes text and image tokens.
113 | hidden_size (`int`, *optional*, defaults to 4096):
114 | Dimension of the hidden representations.
115 | intermediate_size (`int`, *optional*, defaults to 11008):
116 | Dimension of the MLP representations.
117 | num_hidden_layers (`int`, *optional*, defaults to 32):
118 | Number of hidden layers in the Transformer decoder.
119 | num_attention_heads (`int`, *optional*, defaults to 32):
120 | Number of attention heads for each attention layer in the Transformer decoder.
121 | num_key_value_heads (`int`, *optional*, defaults to 32):
122 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
123 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
124 |             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
125 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
126 | by meanpooling all the original heads within that group. For more details checkout [this
127 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
128 | `num_attention_heads`.
129 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
130 | The non-linear activation function (function or string) in the decoder.
131 | max_position_embeddings (`int`, *optional*, defaults to 4096):
132 | The maximum sequence length that this model might ever be used with. Chameleon supports up to 4096 tokens.
133 | initializer_range (`float`, *optional*, defaults to 0.02):
134 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
135 | rms_norm_eps (`float`, *optional*, defaults to 1e-05):
136 | The epsilon used by the rms normalization layers.
137 | use_cache (`bool`, *optional*, defaults to `True`):
138 | Whether or not the model should return the last key/values attentions (not used by all models). Only
139 | relevant if `config.is_decoder=True`.
140 | pad_token_id (`int`, *optional*):
141 | Padding token id.
142 | bos_token_id (`int`, *optional*, defaults to 1):
143 | Beginning of stream token id.
144 | eos_token_id (`int`, *optional*, defaults to 2):
145 | End of stream token id.
146 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
147 | Whether to tie weight embeddings
148 | rope_theta (`float`, *optional*, defaults to 10000.0):
149 | The base period of the RoPE embeddings.
150 | rope_scaling (`Dict`, *optional*):
151 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
152 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
153 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
154 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
155 | these scaling strategies behave:
156 |             https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
157 | experimental feature, subject to breaking API changes in future versions.
158 |         attention_bias (`bool`, *optional*, defaults to `False`):
159 | Whether to use a bias in the query, key, value and output projection layers during self-attention.
160 | attention_dropout (`float`, *optional*, defaults to 0.0):
161 | The dropout ratio for the attention probabilities.
162 | model_parallel_size (`int`, *optional*, defaults to 1):
163 | Number of shards used when training the model. This will be used in qk layernorm because the original Chameleon inference
164 | doesn't do reduction in those layers and each rank has its own biases.
165 | swin_norm (`bool`, *optional*, defaults to `False`):
166 | Use Swin Transformer normalization.
167 | vq_config (`dict`, *optional*):
168 | ChameleonVQConfig instance containing the configuration for the VQ-VAE model.
169 | vocabulary_map (`dict`, *optional*):
170 | A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs.
171 | mlp_bias (`bool`, *optional*, defaults to `False`):
172 | Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
173 |
174 |
175 | ```python
176 | >>> from transformers import ChameleonModel, ChameleonConfig
177 |
178 | >>> # Initializing a chameleon chameleon-7b style configuration
179 | >>> configuration = ChameleonConfig()
180 |
181 | >>> # Initializing a model from the chameleon-7b style configuration
182 | >>> model = ChameleonModel(configuration)
183 |
184 | >>> # Accessing the model configuration
185 | >>> configuration = model.config
186 | ```"""
187 |
188 | model_type = "chameleon"
189 | keys_to_ignore_at_inference = ["past_key_values"]
190 |
191 | def __init__(
192 | self,
193 | vocab_size=65536,
194 | hidden_size=4096,
195 | intermediate_size=11008,
196 | num_hidden_layers=32,
197 | num_attention_heads=32,
198 | num_key_value_heads=32,
199 | hidden_act="silu",
200 | max_position_embeddings=4096,
201 | initializer_range=0.02,
202 | rms_norm_eps=1e-05,
203 | use_cache=True,
204 | pad_token_id=None,
205 | bos_token_id=1,
206 | eos_token_id=2,
207 | tie_word_embeddings=False,
208 | rope_theta=10000.0,
209 | rope_scaling=None,
210 | attention_bias=False,
211 | attention_dropout=0.0,
212 | model_parallel_size=1,
213 | swin_norm=False,
214 | vq_config=None,
215 | vocabulary_map=None,
216 | mlp_bias=False,
217 | mask_image_logits=True,
218 | dropout=0.0,
219 | **kwargs,
220 | ):
221 | self.vocab_size = vocab_size
222 | self.max_position_embeddings = max_position_embeddings
223 | self.hidden_size = hidden_size
224 | self.intermediate_size = intermediate_size
225 | self.num_hidden_layers = num_hidden_layers
226 | self.num_attention_heads = num_attention_heads
227 | self.mlp_bias = mlp_bias
228 |
229 | self.num_key_value_heads = num_key_value_heads
230 | self.hidden_act = hidden_act
231 | self.initializer_range = initializer_range
232 | self.rms_norm_eps = rms_norm_eps
233 | self.use_cache = use_cache
234 | self.rope_theta = rope_theta
235 | self.rope_scaling = rope_scaling
236 | self._rope_scaling_validation()
237 | self.attention_bias = attention_bias
238 | self.attention_dropout = attention_dropout
239 | self.model_parallel_size = model_parallel_size
240 | self.swin_norm = swin_norm
241 | self.mask_image_logits = mask_image_logits
242 |
243 | if vq_config is None:
244 | vq_config = {}
245 | logger.info("vq_config is None. initializing the ChameleonVQConfig with default values.")
246 |
247 | self.vq_config = ChameleonVQVAEConfig(**vq_config)
248 |
249 | self.vocabulary_map = vocabulary_map
250 |
251 | self.dropout = dropout
252 |
253 | super().__init__(
254 | pad_token_id=pad_token_id,
255 | bos_token_id=bos_token_id,
256 | eos_token_id=eos_token_id,
257 | tie_word_embeddings=tie_word_embeddings,
258 | **kwargs,
259 | )
260 |
261 | def _rope_scaling_validation(self):
262 | """
263 | Validate the `rope_scaling` configuration.
264 | """
265 | if self.rope_scaling is None:
266 | return
267 |
268 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
269 | raise ValueError(
270 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
271 | f"got {self.rope_scaling}"
272 | )
273 | rope_scaling_type = self.rope_scaling.get("type", None)
274 | rope_scaling_factor = self.rope_scaling.get("factor", None)
275 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
276 | raise ValueError(
277 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
278 | )
279 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
280 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
281 |
--------------------------------------------------------------------------------
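As a small usage note for `_rope_scaling_validation` above: `rope_scaling`, when provided, must be a dict with exactly the two fields `type` (either "linear" or "dynamic") and `factor` (a float strictly greater than 1). A minimal sketch, assuming the `lumina_mgpt` directory is on `sys.path` as in the demos:

from model.chameleon import ChameleonConfig

# Accepted: two fields, a supported scaling type, and a float factor > 1.
config = ChameleonConfig(rope_scaling={"type": "linear", "factor": 2.0})

# Rejected by _rope_scaling_validation: an integer factor (must be a float > 1).
try:
    ChameleonConfig(rope_scaling={"type": "dynamic", "factor": 2})
except ValueError as err:
    print(err)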
/lumina_mgpt/model/chameleon/image_processing_chameleon.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Image processor class for Chameleon."""
16 |
17 | from typing import Dict, List, Optional, Union
18 |
19 | import numpy as np
20 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
21 | from transformers.image_transforms import get_resize_output_image_size, resize, to_channel_dimension_format
22 | from transformers.image_utils import (
23 | ChannelDimension,
24 | ImageInput,
25 | PILImageResampling,
26 | infer_channel_dimension_format,
27 | is_scaled_image,
28 | is_valid_image,
29 | to_numpy_array,
30 | valid_images,
31 | validate_kwargs,
32 | validate_preprocess_arguments,
33 | )
34 | from transformers.utils import TensorType, is_vision_available, logging
35 |
36 | logger = logging.get_logger(__name__)
37 |
38 | if is_vision_available():
39 | import PIL
40 |
41 |
42 | def make_batched_images(images) -> List[List[ImageInput]]:
43 | """
44 | Accepts images in list or nested list format, and makes a list of images for preprocessing.
45 |
46 | Args:
47 | images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
48 | The input image.
49 |
50 | Returns:
51 | list: A list of images.
52 | """
53 | if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
54 | return [img for img_list in images for img in img_list]
55 |
56 | elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
57 | return images
58 |
59 | elif is_valid_image(images):
60 | return [images]
61 |
62 | raise ValueError(f"Could not make batched video from {images}")
63 |
64 |
65 | class ChameleonImageProcessor(BaseImageProcessor):
66 | r"""
67 | Constructs a Chameleon image processor.
68 |
69 | Args:
70 | do_resize (`bool`, *optional*, defaults to `True`):
71 | Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
72 | `do_resize` in the `preprocess` method.
73 | size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 512}`):
74 | Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
75 | the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
76 | method.
77 | resample (`PILImageResampling`, *optional*, defaults to 1):
78 | Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
79 | do_center_crop (`bool`, *optional*, defaults to `True`):
80 | Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
81 | `preprocess` method.
82 | crop_size (`Dict[str, int]` *optional*, defaults to {"height": 512, "width": 512}):
83 | Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
84 | method.
85 | do_rescale (`bool`, *optional*, defaults to `True`):
86 | Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
87 | the `preprocess` method.
88 | rescale_factor (`int` or `float`, *optional*, defaults to 0.0078):
89 | Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
90 | method.
91 | do_normalize (`bool`, *optional*, defaults to `True`):
92 | Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
93 | image_mean (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`):
94 | Mean to use if normalizing the image. This is a float or list of floats the length of the number of
95 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
96 | image_std (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`):
97 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
98 |             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess`
99 |             method.
100 | do_convert_rgb (`bool`, *optional*, defaults to `True`):
101 | Whether to convert the image to RGB.
102 | """
103 |
104 | model_input_names = ["pixel_values"]
105 |
106 | def __init__(
107 | self,
108 | do_resize: bool = True,
109 | size: Dict[str, int] = None,
110 | resample: PILImageResampling = PIL.Image.LANCZOS,
111 | do_center_crop: bool = True,
112 | crop_size: Dict[str, int] = None,
113 | do_rescale: bool = True,
114 | rescale_factor: Union[int, float] = 0.0078,
115 | do_normalize: bool = True,
116 | image_mean: Optional[Union[float, List[float]]] = None,
117 | image_std: Optional[Union[float, List[float]]] = None,
118 | do_convert_rgb: bool = True,
119 | **kwargs,
120 | ) -> None:
121 | super().__init__(**kwargs)
122 | size = size if size is not None else {"shortest_edge": 512}
123 | size = get_size_dict(size, default_to_square=False)
124 | crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512}
125 | crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
126 |
127 | self.do_resize = do_resize
128 | self.size = size
129 | self.resample = resample
130 | self.do_center_crop = do_center_crop
131 | self.crop_size = crop_size
132 | self.do_rescale = do_rescale
133 | self.rescale_factor = rescale_factor
134 | self.do_normalize = do_normalize
135 | self.image_mean = image_mean if image_mean is not None else [1.0, 1.0, 1.0]
136 | self.image_std = image_std if image_std is not None else [1.0, 1.0, 1.0]
137 | self.do_convert_rgb = do_convert_rgb
138 | self._valid_processor_keys = [
139 | "images",
140 | "do_resize",
141 | "size",
142 | "resample",
143 | "do_center_crop",
144 | "crop_size",
145 | "do_rescale",
146 | "rescale_factor",
147 | "do_normalize",
148 | "image_mean",
149 | "image_std",
150 | "do_convert_rgb",
151 | "return_tensors",
152 | "data_format",
153 | "input_data_format",
154 | ]
155 |
156 | # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
157 | def resize(
158 | self,
159 | image: np.ndarray,
160 | size: Dict[str, int],
161 | resample: PILImageResampling = PILImageResampling.BICUBIC,
162 | data_format: Optional[Union[str, ChannelDimension]] = None,
163 | input_data_format: Optional[Union[str, ChannelDimension]] = None,
164 | **kwargs,
165 | ) -> np.ndarray:
166 | """
167 | Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
168 | resized to keep the input aspect ratio.
169 |
170 | Args:
171 | image (`np.ndarray`):
172 | Image to resize.
173 | size (`Dict[str, int]`):
174 | Size of the output image.
175 | resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
176 |                 Resampling filter to use when resizing the image.
177 | data_format (`str` or `ChannelDimension`, *optional*):
178 | The channel dimension format of the image. If not provided, it will be the same as the input image.
179 | input_data_format (`ChannelDimension` or `str`, *optional*):
180 | The channel dimension format of the input image. If not provided, it will be inferred.
181 | """
182 | default_to_square = True
183 | if "shortest_edge" in size:
184 | size = size["shortest_edge"]
185 | default_to_square = False
186 | elif "height" in size and "width" in size:
187 | size = (size["height"], size["width"])
188 | else:
189 | raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
190 |
191 | output_size = get_resize_output_image_size(
192 | image,
193 | size=size,
194 | default_to_square=default_to_square,
195 | input_data_format=input_data_format,
196 | )
197 | return resize(
198 | image,
199 | size=output_size,
200 | resample=resample,
201 | data_format=data_format,
202 | input_data_format=input_data_format,
203 | **kwargs,
204 | )
205 |
206 | def preprocess(
207 | self,
208 | images: ImageInput,
209 | do_resize: bool = None,
210 | size: Dict[str, int] = None,
211 | resample: PILImageResampling = None,
212 | do_center_crop: bool = None,
213 | crop_size: int = None,
214 | do_rescale: bool = None,
215 | rescale_factor: float = None,
216 | do_normalize: bool = None,
217 | image_mean: Optional[Union[float, List[float]]] = None,
218 | image_std: Optional[Union[float, List[float]]] = None,
219 | do_convert_rgb: bool = None,
220 | return_tensors: Optional[Union[str, TensorType]] = None,
221 | data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
222 | input_data_format: Optional[Union[str, ChannelDimension]] = None,
223 | **kwargs,
224 | ) -> BatchFeature:
225 | """
226 | Preprocess an image or batch of images.
227 |
228 | Args:
229 | images (`ImageInput`):
230 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
231 | passing in images with pixel values between 0 and 1, set `do_rescale=False`.
232 | do_resize (`bool`, *optional*, defaults to `self.do_resize`):
233 | Whether to resize the image.
234 | size (`Dict[str, int]`, *optional*, defaults to `self.size`):
235 | Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
236 | the longest edge resized to keep the input aspect ratio.
237 | resample (`int`, *optional*, defaults to `self.resample`):
238 | Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
239 | has an effect if `do_resize` is set to `True`.
240 | do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
241 | Whether to center crop the image.
242 | crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
243 | Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
244 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
245 | Whether to rescale the image.
246 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
247 | Rescale factor to rescale the image by if `do_rescale` is set to `True`.
248 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
249 | Whether to normalize the image.
250 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
251 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
252 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
253 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
254 | `True`.
255 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
256 | Whether to convert the image to RGB.
257 | return_tensors (`str` or `TensorType`, *optional*):
258 | The type of tensors to return. Can be one of:
259 | - Unset: Return a list of `np.ndarray`.
260 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
261 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
262 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
263 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
264 | data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
265 | The channel dimension format for the output image. Can be one of:
266 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
267 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
268 | - Unset: Use the channel dimension format of the input image.
269 | input_data_format (`ChannelDimension` or `str`, *optional*):
270 | The channel dimension format for the input image. If unset, the channel dimension format is inferred
271 | from the input image. Can be one of:
272 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
273 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
274 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
275 | """
276 | do_resize = do_resize if do_resize is not None else self.do_resize
277 | size = size if size is not None else self.size
278 | size = get_size_dict(size, param_name="size", default_to_square=False)
279 | resample = resample if resample is not None else self.resample
280 | do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
281 | crop_size = crop_size if crop_size is not None else self.crop_size
282 | crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
283 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale
284 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
285 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize
286 | image_mean = image_mean if image_mean is not None else self.image_mean
287 | image_std = image_std if image_std is not None else self.image_std
288 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
289 |
290 | validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
291 |
292 | images = make_batched_images(images)
293 |
294 | if not valid_images(images):
295 | raise ValueError(
296 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
297 | "torch.Tensor, tf.Tensor or jax.ndarray."
298 | )
299 |
300 | validate_preprocess_arguments(
301 | do_rescale=do_rescale,
302 | rescale_factor=rescale_factor,
303 | do_normalize=do_normalize,
304 | image_mean=image_mean,
305 | image_std=image_std,
306 | do_center_crop=do_center_crop,
307 | crop_size=crop_size,
308 | do_resize=do_resize,
309 | size=size,
310 | resample=resample,
311 | )
312 |
313 | if do_convert_rgb:
314 | images = [self.blend_rgba(image) for image in images]
315 |
316 | # All transformations expect numpy arrays.
317 | images = [to_numpy_array(image) for image in images]
318 |
319 | if is_scaled_image(images[0]) and do_rescale:
320 | logger.warning_once(
321 | "It looks like you are trying to rescale already rescaled images. If the input"
322 | " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
323 | )
324 |
325 | if input_data_format is None:
326 | # We assume that all images have the same channel dimension format.
327 | input_data_format = infer_channel_dimension_format(images[0])
328 |
329 | if do_resize:
330 | images = [
331 | self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
332 | for image in images
333 | ]
334 |
335 | if do_center_crop:
336 | images = [
337 | self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
338 | ]
339 |
340 | if do_rescale:
341 | images = [
342 | self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) for image in images
343 | ]
344 |
345 | if do_normalize:
346 | images = [
347 | self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
348 | for image in images
349 | ]
350 |
351 | images = [
352 | to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
353 | ]
354 |
355 | data = {"pixel_values": images}
356 | return BatchFeature(data=data, tensor_type=return_tensors)
357 |
358 | def blend_rgba(self, image: ImageInput) -> ImageInput:
359 | """
360 | Convert image to RGB by blending the transparency layer if it's in RGBA format.
361 | If the image is not a `PIL.Image`, it is simply returned without modification.
362 |
363 | Args:
364 | image (`ImageInput`):
365 | Image to convert.
366 | """
367 |
368 | if not isinstance(image, PIL.Image.Image):
369 | return image
370 | elif image.mode == "RGB":
371 | return image
372 |
373 | img_rgba = np.array(image.convert("RGBA"))
374 |
375 | # If there is no transparency layer, simply convert and return.
376 | if not (img_rgba[:, :, 3] < 255).any():
377 | return image.convert("RGB")
378 |
379 | # There is a transparency layer, blend it with a white background.
380 | # Calculate the alpha proportion for blending.
381 | alpha = img_rgba[:, :, 3] / 255.0
382 | img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3]
383 | return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB")
384 |
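A note on the defaults above: with `rescale_factor=0.0078` (roughly 1/127.5) and `image_mean`/`image_std` of 1.0, preprocessing maps 8-bit pixel values from [0, 255] to approximately [-1, 1], the same range the VQGAN tokenizer below prepares its inputs in. A minimal standalone sketch of the equivalent arithmetic (not part of the class):

import numpy as np

pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)
rescaled = pixels * 0.0078            # do_rescale step: roughly [0, 2]
normalized = (rescaled - 1.0) / 1.0   # do_normalize with mean = std = 1.0: roughly [-1, 1]
print(normalized)                     # approximately [-1., 0., 1.]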
--------------------------------------------------------------------------------
/lumina_mgpt/model/chameleon/processing_chameleon.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | Processor class for Chameleon.
17 | """
18 |
19 | from typing import List, Optional, Union
20 |
21 | from transformers.feature_extraction_utils import BatchFeature
22 | from transformers.image_utils import ImageInput
23 | from transformers.processing_utils import ProcessorMixin
24 | from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
25 | from transformers.utils import TensorType
26 |
27 |
28 | class ChameleonProcessor(ProcessorMixin):
29 | r"""
30 | Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single
31 | processor.
32 |
33 | [`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`].
34 | See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information.
35 |
36 | Args:
37 | image_processor ([`ChameleonImageProcessor`]):
38 | The image processor is a required input.
39 | tokenizer ([`LlamaTokenizerFast`]):
40 | The tokenizer is a required input.
41 | image_seq_length (`int`, *optional*, defaults to 1024):
42 | Sequence length of one image embedding.
43 | image_token (`str`, *optional*, defaults to `"<image>"`):
44 | The special token used to indicate an image in the text.
45 | """
46 |
47 | attributes = ["image_processor", "tokenizer"]
48 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
49 | image_processor_class = "ChameleonImageProcessor"
50 |
51 | def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
52 | self.image_seq_length = image_seq_length
53 | self.image_token = image_token
54 | self.image_start_token = "<racm3:break>"  # fixed tokens for start and end, so can hardcode
55 | self.image_end_token = "<eoss>"
56 | super().__init__(image_processor, tokenizer)
57 |
58 | def __call__(
59 | self,
60 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
61 | images: ImageInput = None,
62 | padding: Union[bool, str, PaddingStrategy] = False,
63 | truncation: Union[bool, str, TruncationStrategy] = None,
64 | max_length: int = None,
65 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
66 | return_for_text_completion: bool = False,
67 | ) -> BatchFeature:
68 | """
69 | Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
70 | and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
71 | the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
72 | ChameleonImageProcessor's [`~ChameleonImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
73 | of the above two methods for more information.
74 |
75 | Args:
76 | text (`str`, `List[str]`, `List[List[str]]`):
77 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
78 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
79 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
80 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
81 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
82 | tensor. Both channels-first and channels-last formats are supported.
83 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
84 | Select a strategy to pad the returned sequences (according to the model's padding side and padding
85 | index) among:
86 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
87 | sequence is provided).
88 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
89 | acceptable input length for the model if that argument is not provided.
90 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
91 | lengths).
92 | max_length (`int`, *optional*):
93 | Maximum length of the returned list and optionally padding length (see above).
94 | truncation (`bool`, *optional*):
95 | Activates truncation to cut input sequences longer than `max_length` to `max_length`.
96 | return_tensors (`str` or [`~utils.TensorType`], *optional*):
97 | If set, will return tensors of a particular framework. Acceptable values are:
98 |
99 | - `'tf'`: Return TensorFlow `tf.constant` objects.
100 | - `'pt'`: Return PyTorch `torch.Tensor` objects.
101 | - `'np'`: Return NumPy `np.ndarray` objects.
102 | - `'jax'`: Return JAX `jnp.ndarray` objects.
103 |
104 | Returns:
105 | [`BatchFeature`]: A [`BatchFeature`] with the following fields:
106 |
107 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
108 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
109 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
110 | `None`).
111 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
112 | """
113 | if isinstance(text, str):
114 | text = [text]
115 | elif not isinstance(text, list) and not isinstance(text[0], str):
116 | raise TypeError("Invalid input text. Please provide a string, or a list of strings")
117 |
118 | # Replace the image token with the expanded image token sequence
119 | prompt_strings = []
120 | one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token
121 | for sample in text:
122 | sample = sample.replace(self.image_token, one_img_tokens)
123 | if not return_for_text_completion:
124 | sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
125 | prompt_strings.append(sample)
126 |
127 | data = self.tokenizer(
128 | prompt_strings,
129 | return_tensors=return_tensors,
130 | padding=padding,
131 | truncation=truncation,
132 | max_length=max_length,
133 | )
134 |
135 | if images is not None:
136 | pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
137 | data["pixel_values"] = pixel_values
138 |
139 | return BatchFeature(data=data, tensor_type=return_tensors)
140 |
141 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
142 | def batch_decode(self, *args, **kwargs):
143 | """
144 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
145 | refer to the docstring of this method for more information.
146 | """
147 | return self.tokenizer.batch_decode(*args, **kwargs)
148 |
149 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
150 | def decode(self, *args, **kwargs):
151 | """
152 | This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
153 | the docstring of this method for more information.
154 | """
155 | return self.tokenizer.decode(*args, **kwargs)
156 |
157 | @property
158 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
159 | def model_input_names(self):
160 | tokenizer_input_names = self.tokenizer.model_input_names
161 | image_processor_input_names = self.image_processor.model_input_names
162 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
163 |
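A hedged usage sketch for the processor above; the checkpoint id and image path are placeholders. A single image placeholder in the prompt is expanded to image_start_token + image_token * image_seq_length + image_end_token before tokenization:

from PIL import Image

processor = ChameleonProcessor.from_pretrained("path/to/chameleon-checkpoint")  # hypothetical path
image = Image.open("example.jpg")                                               # hypothetical image

prompt = "Describe this image: " + processor.image_token
inputs = processor(text=prompt, images=image, return_tensors="pt")
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)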
--------------------------------------------------------------------------------
/lumina_mgpt/model/chameleon_vae_ori/__init__.py:
--------------------------------------------------------------------------------
1 | from .image_tokenizer import ImageTokenizer
2 | from .vocab import VocabInfo, VocabTranslation
3 |
--------------------------------------------------------------------------------
/lumina_mgpt/model/chameleon_vae_ori/image_tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates
2 | #
3 | # This source code is licensed under the Chameleon License found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import PIL
7 | from PIL import Image
8 | import numpy as np
9 | import torch
10 | import yaml
11 |
12 | from .vqgan import VQModel
13 |
14 |
15 | class ImageTokenizer:
16 | def __init__(
17 | self,
18 | cfg_path: str,
19 | ckpt_path: str,
20 | device: str | torch.device | None = None,
21 | ):
22 | with open(cfg_path) as f:
23 | config = yaml.safe_load(f)
24 |
25 | params = config["model"]["params"]
26 | if "lossconfig" in params:
27 | del params["lossconfig"]
28 | params["ckpt_path"] = ckpt_path
29 |
30 | self._vq_model = VQModel(**params)
31 | self._vq_model.eval()
32 |
33 | if device is None:
34 | devices = {p.device for p in self._vq_model.parameters()}
35 | assert len(devices) == 1
36 | device = devices.pop()
37 | else:
38 | self._vq_model.to(device)
39 | self._device = device
40 |
41 | dtypes = {p.dtype for p in self._vq_model.parameters()}
42 | assert len(dtypes) == 1
43 | self._dtype = dtypes.pop()
44 |
45 | def _whiten_transparency(self, img: PIL.Image) -> PIL.Image:
46 | # Check if it's already in RGB format.
47 | if img.mode == "RGB":
48 | return img
49 |
50 | vals_rgba = np.array(img.convert("RGBA"))
51 |
52 | # If there is no transparency layer, simply convert and return.
53 | if not (vals_rgba[:, :, 3] < 255).any():
54 | return img.convert("RGB")
55 |
56 | # There is a transparency layer, blend it with a white background.
57 |
58 | # Calculate the alpha proportion for blending.
59 | alpha = vals_rgba[:, :, 3] / 255.0
60 | # Blend with white background.
61 | vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * vals_rgba[:, :, :3]
62 | return PIL.Image.fromarray(vals_rgb.astype("uint8"), "RGB")
63 |
64 | # def _vqgan_input_from(self, img: PIL.Image, target_image_size=512) -> torch.Tensor:
65 | # # Resize with aspect ratio preservation.
66 | # s = min(img.size)
67 | # scale = target_image_size / s
68 | # new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
69 | # img = img.resize(new_size, PIL.Image.LANCZOS)
70 | #
71 | # # Center crop.
72 | # x0 = (img.width - target_image_size) // 2
73 | # y0 = (img.height - target_image_size) // 2
74 | # img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))
75 | #
76 | # # Convert to tensor.
77 | # np_img = np.array(img) / 255.0 # Normalize to [0, 1]
78 | # np_img = np_img * 2 - 1 # Scale to [-1, 1]
79 | # tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float() # (Channels, Height, Width) format.
80 | #
81 | # # Add batch dimension.
82 | # return tensor_img.unsqueeze(0)
83 |
84 | def img_tokens_from_pil(self, img: PIL.Image) -> list[int]:
85 | img = self._whiten_transparency(img)
86 | # Convert to tensor.
87 | np_img = np.array(img) / 255.0 # Normalize to [0, 1]
88 | np_img = np_img * 2 - 1 # Scale to [-1, 1]
89 | img = torch.from_numpy(np_img).permute(2, 0, 1).to(self._vq_model.encoder.conv_in.weight)
90 | img = img.unsqueeze(0)
91 |
92 | _, _, [_, _, img_toks] = self._vq_model.encode(img)
93 | return img_toks
94 |
95 | def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image:
96 | # Ensure detachment and move tensor to CPU.
97 | detached_chw_tensor = chw_tensor.detach().cpu()
98 |
99 | # Normalize tensor to [0, 1] range from [-1, 1] range.
100 | normalized_chw_tensor = (torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0) / 2.0
101 |
102 | # Permute CHW tensor to HWC format and convert to NumPy array.
103 | hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()
104 |
105 | # Convert to an 8-bit unsigned integer format.
106 | image_array_uint8 = (hwc_array * 255).astype(np.uint8)
107 |
108 | # Convert NumPy array to PIL Image.
109 | pil_image = Image.fromarray(image_array_uint8)
110 |
111 | # Convert image to RGB if it is not already.
112 | if pil_image.mode != "RGB":
113 | pil_image = pil_image.convert("RGB")
114 |
115 | return pil_image
116 |
117 | def pil_from_img_toks(self, tokens: torch.Tensor, h_latent_dim=32, w_latent_dim=32) -> PIL.Image:
118 | emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
119 | codebook_entry = self._vq_model.quantize.get_codebook_entry(tokens, (1, h_latent_dim, w_latent_dim, emb_dim))
120 | pixels = self._vq_model.decode(codebook_entry)
121 | return self._pil_from_chw_tensor(pixels[0])
122 |
123 | def latent_embedding_from_pil(self, img: PIL.Image):
124 | img = self._whiten_transparency(img)
125 |
126 | # Convert to tensor.
127 | np_img = np.array(img) / 255.0 # Normalize to [0, 1]
128 | np_img = np_img * 2 - 1 # Scale to [-1, 1]
129 | img = torch.from_numpy(np_img).permute(2, 0, 1) # (Channels, Height, Width) format.
130 | img = img.unsqueeze(0).to(self._vq_model.encoder.conv_in.weight)
131 | latent_embedding, _, _ = self._vq_model.encode(img)
132 | return latent_embedding
133 |
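A minimal round-trip sketch for ImageTokenizer, assuming a local VQGAN config/checkpoint pair (paths below are placeholders); a 512x512 input corresponds to the 32x32 latent grid that pil_from_img_toks assumes by default:

from PIL import Image

tokenizer = ImageTokenizer(cfg_path="vqgan.yaml", ckpt_path="vqgan.ckpt", device="cpu")  # hypothetical paths
img = Image.open("example.png").convert("RGB").resize((512, 512))

tokens = tokenizer.img_tokens_from_pil(img)   # codebook indices for the 32x32 latent grid
recon = tokenizer.pil_from_img_toks(tokens, h_latent_dim=32, w_latent_dim=32)
recon.save("reconstruction.png")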
--------------------------------------------------------------------------------
/lumina_mgpt/model/chameleon_vae_ori/vocab.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Chameleon License found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from functools import cached_property
7 |
8 | import torch
9 |
10 |
11 | class VocabInfo:
12 | def __init__(self, vocab_map: dict[str, int]):
13 | self.name2val = vocab_map
14 |
15 | self.bos_id = vocab_map.get("<s>")
16 | self.eos_id = vocab_map.get("</s>")
17 | self.boi_id = vocab_map.get("<racm3:break>")
18 | self.eoi_id = vocab_map.get("<eoss>")
19 | self.pad_id = vocab_map.get("<pad>")
20 | self.eot_id = vocab_map.get("<reserved08706>")
21 |
22 | @property
23 | def begin_sequence(self) -> int:
24 | return self.bos_id
25 |
26 | @property
27 | def end_sequence(self) -> int:
28 | return self.eos_id
29 |
30 | @property
31 | def begin_image(self) -> int:
32 | return self.boi_id
33 |
34 | @property
35 | def end_image(self) -> int:
36 | return self.eoi_id
37 |
38 | @property
39 | def padding(self) -> int:
40 | return self.pad_id
41 |
42 | @property
43 | def end_turn(self) -> int:
44 | return self.eot_id
45 |
46 | @cached_property
47 | def val2name(self) -> dict[int, str]:
48 | return {v: k for k, v in self.name2val.items()}
49 |
50 | @cached_property
51 | def all_tokens(self) -> list[int]:
52 | return sorted(self.name2val.values())
53 |
54 | @cached_property
55 | def image_tokens(self) -> list[int]:
56 | return sorted([val for name, val in self.name2val.items() if name.startswith("IMGIMG")])
57 |
58 | @cached_property
59 | def special_tokens(self) -> list[int]:
60 | return sorted([val for name, val in self.name2val.items() if name.startswith("<") and name != "<"])
61 |
62 | @cached_property
63 | def text_tokens(self) -> list[int]:
64 | return sorted(set(self.all_tokens) - set(self.image_tokens) - set(self.special_tokens))
65 |
66 |
67 | class VocabTranslation:
68 | def __init__(self, vocab_info: VocabInfo, device: str | None = None):
69 | self._vocab = vocab_info
70 | self._device = device
71 |
72 | @cached_property
73 | def bpe2img(self) -> dict[int, int]:
74 | img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
75 |
76 | def remap(old_name: str) -> str:
77 | return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1])
78 |
79 | return {tok: int(remap(self._vocab.val2name[tok])) for tok in self._vocab.image_tokens}
80 |
81 | @cached_property
82 | def img2bpe(self) -> dict[int, int]:
83 | return {v: k for k, v in self.bpe2img.items()}
84 |
85 | @cached_property
86 | def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
87 | sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()), device=self._device)
88 | sorted_img = torch.tensor(sorted(self.bpe2img.values()), device=self._device)
89 | return sorted_bpe, sorted_img
90 |
91 | @cached_property
92 | def img2bpe_mapping_tensor(self) -> torch.LongTensor:
93 | mapping = torch.zeros(
94 | max(self.img2bpe.keys()) + 1,
95 | dtype=torch.int,
96 | device=self._device,
97 | )
98 | for k, v in self.img2bpe.items():
99 | mapping[k] = v
100 | return mapping
101 |
102 | def convert_bpe2img(self, bpe_batch: torch.Tensor) -> torch.Tensor:
103 | bpe_tok, img_tok = self.bpe2img_search_tensors
104 | return img_tok[torch.searchsorted(bpe_tok, bpe_batch)]
105 |
106 | def convert_img2bp2(self, img_batch: torch.Tensor) -> torch.Tensor:
107 | return self.img2bpe_mapping_tensor[img_batch]
108 |
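To make the IMGIMG remapping above concrete: image-token names encode their codebook index with the letters A-J standing for the digits 0-9, followed by a trailing delimiter character that remap() drops. A toy illustration with made-up vocabulary entries:

toy_vocab = {"IMGIMGBAZ": 100, "IMGIMGBBZ": 101}  # hypothetical names: "BA" -> 10, "BB" -> 11
info = VocabInfo(toy_vocab)
trans = VocabTranslation(info)
print(trans.bpe2img)  # {100: 10, 101: 11}
print(trans.img2bpe)  # {10: 100, 11: 101}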
--------------------------------------------------------------------------------
/lumina_mgpt/model/configuration_xllmx_chameleon.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List
3 |
4 | from .chameleon import ChameleonConfig
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class ChameleonXLLMXConfig(ChameleonConfig):
10 |
11 | def __init__(
12 | self,
13 | z_loss_weight: float = 0.0,
14 | **kwargs,
15 | ):
16 | self.z_loss_weight = z_loss_weight
17 | super().__init__(
18 | **kwargs,
19 | )
20 |
--------------------------------------------------------------------------------
/lumina_mgpt/model/modeling_xllmx_chameleon.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import math
4 | from typing import List
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from .chameleon import ChameleonForConditionalGeneration
10 | from .configuration_xllmx_chameleon import ChameleonXLLMXConfig
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5))
15 |
16 |
17 | __all__ = ["ChameleonXLLMXForConditionalGeneration"]
18 |
19 |
20 | class ChameleonXLLMXForConditionalGeneration(ChameleonForConditionalGeneration):
21 | config_class = ChameleonXLLMXConfig
22 |
23 | def __init__(self, config):
24 | super().__init__(config)
25 |
26 | def forward(self, input_ids=None, labels=None, training=True, **kwargs):
27 |
28 | max_tokens = max([len(_) for _ in input_ids])
29 | max_tokens = min(max_tokens, self.config.max_position_embeddings)
30 | input_ids = [_[:max_tokens] for _ in input_ids]
31 | labels = [_[:max_tokens] for _ in labels]
32 |
33 | input_ids = [example + [0] * (max_tokens - len(example)) for example in input_ids]
34 | input_ids = torch.tensor(input_ids, dtype=torch.int64, device=self.device)
35 |
36 | labels = [label + [-100] * (max_tokens - len(label)) for label in labels]
37 | labels = torch.tensor(labels, dtype=torch.int64, device=self.device)
38 |
39 | # explicit use_cache=False for the following
40 | # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
41 | result = ChameleonForConditionalGeneration.forward(
42 | self, input_ids=input_ids, labels=labels, use_cache=False, **kwargs
43 | )
44 |
45 | c_loss = result[0]
46 |
47 | additional_loss_dict = {}
48 | if self.config.z_loss_weight > 0:
49 | logits: torch.Tensor = result[1]
50 | shift_logits = logits[..., :-1, :].contiguous()
51 | shift_labels = labels[..., 1:].contiguous()
52 | valid_mask = shift_labels >= 0
53 | z_loss = torch.logsumexp(shift_logits, dim=-1).pow(2)[valid_mask].mean()
54 | additional_loss_dict["z_loss"] = (z_loss, self.config.z_loss_weight)
55 | return c_loss, additional_loss_dict
56 |
57 | def get_fsdp_wrap_module_list(self) -> List:
58 | modules = [*list(self.model.layers), self.lm_head, self.model.embed_tokens]
59 | if hasattr(self.model, "vqmodel"): # may be deleted
60 | modules.append(self.model.vqmodel)
61 | return modules
62 |
63 | def get_checkpointing_wrap_module_list(self) -> List:
64 | modules = [
65 | *list(self.model.layers),
66 | ]
67 | return modules
68 |
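The auxiliary z-loss above penalizes the squared log-partition function of the shifted logits, averaged over positions whose shifted labels are not masked with -100, and is returned together with its weight rather than pre-multiplied. A standalone sketch of the same computation on dummy tensors:

import torch

logits = torch.randn(2, 8, 16)         # (batch, seq_len, vocab_size), dummy values
labels = torch.randint(0, 16, (2, 8))
labels[:, :2] = -100                   # -100 marks positions excluded from the loss

shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
valid_mask = shift_labels >= 0
z_loss = torch.logsumexp(shift_logits, dim=-1).pow(2)[valid_mask].mean()
print(z_loss)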
--------------------------------------------------------------------------------
/lumina_mgpt/pre_tokenize/concat_record.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import json
3 | import os
4 | import re
5 | import warnings
6 |
7 |
8 | def find_sub_records(directory: str):
9 | pattern = re.compile(r"\d+-of-\d+-record\.json(l)?")
10 |
11 | sub_record_files = [f for f in os.listdir(directory) if pattern.match(f)]
12 | sorted_files = sorted(sub_record_files, key=lambda filename: int(filename.split("-of")[0]))
13 | return sorted_files
14 |
15 |
16 | if __name__ == "__main__":
17 | parser = ArgumentParser()
18 | parser.add_argument(
19 | "--sub_record_dir",
20 | type=str,
21 | default=None,
22 | )
23 | parser.add_argument(
24 | "--save_path",
25 | type=str,
26 | default=None,
27 | )
28 | args = parser.parse_args()
29 |
30 | l_sub_records = find_sub_records(args.sub_record_dir)
31 |
32 | print(f"find {len(l_sub_records)} sub-records in {args.sub_record_dir}")
33 | print(str(l_sub_records) + "\n\n")
34 |
35 | complete_record = []
36 | for sub_record in l_sub_records:
37 | with open(os.path.join(args.sub_record_dir, sub_record)) as f:
38 | lines = f.readlines()
39 | for i, l in enumerate(lines):
40 | try:
41 | l_item = json.loads(l)
42 | complete_record.append(l_item)
43 | except json.JSONDecodeError:
44 | if i == len(lines) - 1:
45 | print(f"{sub_record} appears to be still in progress; skipping the last incomplete record")
46 | else:
47 | warnings.warn(f"read line failed: {l}")
48 |
49 | with open(args.save_path, "w") as f:
50 | json.dump(complete_record, f)
51 |
--------------------------------------------------------------------------------
/lumina_mgpt/pre_tokenize/pre_tokenize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.abspath(__file__).rsplit("/", 2)[0])
5 |
6 | from argparse import ArgumentParser
7 | import json
8 | import math
9 | import pickle
10 |
11 | from data.convertsation import Conversation
12 | from data.item_processor import FlexARItemProcessor
13 |
14 |
15 | class ItemProcessor(FlexARItemProcessor):
16 | def __init__(
17 | self,
18 | tokenizer="Alpha-VLLM/Lumina-mGPT-7B-768",
19 | conv_template=Conversation,
20 | target_size=512,
21 | ):
22 | super().__init__(tokenizer, conv_template, target_size)
23 | print(self.crop_size_list)
24 |
25 | def process_item(self, raw_item, training_mode=False, out_flatten=True):
26 |
27 | # Add custom codes here to convert raw_item to the standard format
28 | # The standard format contains the "conversations" and "image" keys
29 |
30 | # ********* Add your custom codes here *******
31 |
32 | # ********* Add your custom codes here *******
33 |
34 | item = {
35 | "conversations": raw_item["conversations"],
36 | "image": raw_item["image"],
37 | }
38 |
39 | return super(ItemProcessor, self).process_item(item, training_mode, out_flatten)
40 |
41 |
42 | if __name__ == "__main__":
43 |
44 | parser = ArgumentParser()
45 | parser.add_argument(
46 | "--splits",
47 | type=int,
48 | default=8,
49 | )
50 | parser.add_argument(
51 | "--rank",
52 | type=int,
53 | default=0,
54 | )
55 | parser.add_argument(
56 | "--in_filename",
57 | type=str,
58 | )
59 | parser.add_argument(
60 | "--out_dir",
61 | type=str,
62 | )
63 | parser.add_argument("--target_size", type=int, default=512)
64 | args = parser.parse_args()
65 |
66 | item_processor = ItemProcessor(target_size=args.target_size)
67 |
68 | with open(args.in_filename) as f:
69 | ori_contents = json.load(f)
70 |
71 | num = len(ori_contents)
72 |
73 | splits = args.splits
74 | rank = args.rank
75 | output_dir = args.out_dir
76 | save_dir = os.path.join(output_dir, "files")
77 | os.makedirs(save_dir, exist_ok=True)
78 |
79 | num_per_rank = math.ceil(num / splits)
80 |
81 | try:
82 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "r") as f:
83 | start_idx = int(f.read()) + 1
84 | print(f"resume from {start_idx}")
85 | except (FileNotFoundError, ValueError):
86 | start_idx = num_per_rank * rank
87 | print(f"start from {start_idx}")
88 |
89 | end_idx = min(num_per_rank * (rank + 1), len(ori_contents))
90 | for i in range(start_idx, end_idx):
91 | if i % 10 == 0:
92 | print(f"{i}/{end_idx}")
93 |
94 | record = None
95 | pkl_path = os.path.join(save_dir, f"{i}.pkl")
96 | try:
97 | tokens, labels = item_processor.process_item(ori_contents[i], training_mode=True)
98 | new_item = {"token": tokens, "label": labels, "id": i}
99 | with open(pkl_path, "wb") as f:
100 | pickle.dump(new_item, f)
101 |
102 | record = {"file": pkl_path, "len": len(tokens), "id": i}
103 |
104 | except Exception as e:
105 | from traceback import format_exc
106 |
107 | print(f"item {i} error: \n{ori_contents[i]}")
108 | print(format_exc())
109 |
110 | if record is not None:
111 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-record.jsonl"), "a") as f:
112 | record_str = json.dumps(record) + "\n"
113 | f.write(record_str)
114 |
115 | with open(os.path.join(output_dir, f"{rank}-of-{splits}-progress.txt"), "w") as f:
116 | if i == end_idx - 1:
117 | f.write("finished")
118 | else:
119 | f.write(f"{i}")
120 |
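The --splits/--rank arguments above implement static sharding: each rank handles a contiguous slice of the input list, and the per-rank record/progress files are later merged by concat_record.py. A small sketch of the same index arithmetic, assuming 1000 items and 8 splits:

import math

num, splits = 1000, 8
num_per_rank = math.ceil(num / splits)    # 125
for rank in range(splits):
    start = num_per_rank * rank
    end = min(num_per_rank * (rank + 1), num)
    print(rank, start, end)               # rank 0 -> [0, 125), ..., rank 7 -> [875, 1000)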
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.3.0
2 | torchvision==0.18.0
3 | torchaudio==2.3.0
4 | pandas
5 | tensorboard
6 | fairscale
7 | sentencepiece
8 | gradio==4.19.0
9 | packaging
10 | transformers>=4.43.3
11 | pyyaml
12 | pathlib
13 | Ninja
14 | bitsandbytes
15 | httpx[socks]
16 | einops
17 | regex
18 | h5py
19 | accelerate
20 | pre-commit
21 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="xllmx",
8 | version="0.0.1",
9 | author="Alpha-VLLM",
10 | description="An Open-source Toolkit for LLM-centered Any2Any Generation",
11 | long_description=long_description,
12 | long_description_content_type="text/markdown",
13 | url="https://github.com/Alpha-VLLM/Lumina-mGPT",
14 | packages=["xllmx"],
15 | include_package_data=True,
16 | )
17 |
--------------------------------------------------------------------------------
/xllmx/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/__init__.py
--------------------------------------------------------------------------------
/xllmx/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/data/__init__.py
--------------------------------------------------------------------------------
/xllmx/data/conversation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/data/conversation/__init__.py
--------------------------------------------------------------------------------
/xllmx/data/conversation/template.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | class ConversationBase:
5 | roles = ["Human", "Assistant"]
6 |
7 | def __init__(self, messages=None):
8 | self.messages = messages or []
9 |
10 | def process(self):
11 | raise NotImplementedError
12 |
13 | def get_prompt(self):
14 | return self.process()["conv"]
15 |
16 | def append_message(self, role, message):
17 | self.messages.append([role, message])
18 |
19 | def copy(self):
20 | return ConversationBase(
21 | messages=[[x, y] for x, y in self.messages],
22 | )
23 |
24 | def load_qas(self, qas: List[List[str]]):
25 | self.messages = []
26 | for q, a in qas:
27 | self.append_message(self.roles[0], q)
28 | self.append_message(self.roles[1], a)
29 |
--------------------------------------------------------------------------------
/xllmx/data/data_reader.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | import logging
3 | import time
4 | from typing import Union
5 |
6 | from PIL import Image
7 |
8 | Image.MAX_IMAGE_PIXELS = None
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def read_general(path) -> Union[str, BytesIO]:
13 | if "s3://" in path:
14 | init_ceph_client_if_needed()
15 | file_bytes = BytesIO(client.get(path))
16 | return file_bytes
17 | else:
18 | return path
19 |
20 |
21 | def init_ceph_client_if_needed():
22 | global client
23 | if client is None:
24 | logger.info(f"initializing ceph client ...")
25 | st = time.time()
26 | from petrel_client.client import Client # noqa
27 |
28 | client = Client("/path/to/petreloss.conf")
29 | ed = time.time()
30 | logger.info(f"initialize client cost {ed - st:.2f} s")
31 |
32 |
33 | client = None
34 |
--------------------------------------------------------------------------------
/xllmx/data/dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import logging
4 | import os
5 | from pathlib import Path
6 | import pickle
7 | from time import sleep
8 | import traceback
9 | import warnings
10 |
11 | import h5py
12 | import torch
13 | import torch.distributed as dist
14 | from torch.utils.data import Dataset
15 | import yaml
16 |
17 | from .item_processor import ItemProcessorBase
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | class FinetuneConversationDataset(Dataset):
23 | def __init__(self, config_path, item_processor: ItemProcessorBase, cache_on_disk=False):
24 |
25 | self.item_processor = item_processor
26 |
27 | logger.info(f"read dataset config from {config_path}")
28 | with open(config_path, "r") as f:
29 | self.config = yaml.load(f, Loader=yaml.FullLoader)
30 | logger.info("DATASET CONFIG:")
31 | logger.info(self.config)
32 |
33 | self.cache_on_disk = cache_on_disk
34 | if self.cache_on_disk:
35 | cache_dir = self._get_cache_dir(config_path)
36 | if dist.get_rank() == 0:
37 | self._collect_annotations_and_save_to_cache(cache_dir)
38 | dist.barrier()
39 | self.meta_collection, self.annotations_collection = self._load_annotations_from_cache(cache_dir)
40 | else:
41 | cache_dir = None
42 | self.meta_collection, self.annotations_collection = self._collect_annotations()
43 |
44 | def __len__(self):
45 | return sum([_["len"] for _ in self.meta_collection])
46 |
47 | def _collect_annotations(self):
48 | meta_collection = []
49 | annotations_collection = []
50 |
51 | for meta in self.config["META"]:
52 | meta, annotations = self._load_meta(meta)
53 | meta_collection.append(meta)
54 | annotations_collection.append(annotations)
55 |
56 | return meta_collection, annotations_collection
57 |
58 | def _load_meta(self, meta):
59 | if "type" not in meta:
60 | meta["type"] = "default"
61 |
62 | meta_path, meta_type = meta["path"], meta["type"]
63 | meta_ext = os.path.splitext(meta_path)[-1]
64 | if meta_ext == ".json":
65 | with open(meta_path) as f:
66 | annotations = json.load(f)
67 | elif meta_ext == ".jsonl":
68 | annotations = []
69 | with open(meta_path) as f:
70 | for i, line in enumerate(f):
71 | try:
72 | annotations.append(json.loads(line))
73 | except json.decoder.JSONDecodeError as e:
74 | logger.error(f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}")
75 | raise e
76 | elif meta_ext == ".pkl":
77 | with open(meta_path, "rb") as f:
78 | annotations = pickle.load(f)
79 | assert isinstance(annotations, list)
80 | elif meta_ext == ".pth":
81 | annotations = torch.load(meta_path)
82 | assert isinstance(annotations, list)
83 | else:
84 | raise NotImplementedError(
85 | f'Unknown meta file extension: "{meta_ext}". '
86 | f"Currently, .json, .jsonl are supported. "
87 | "If you are using a supported format, please set the file extension so that the proper parsing "
88 | "routine can be called."
89 | )
90 | logger.info(f"{meta_path}, type: {meta_type}, len: {len(annotations)}")
91 |
92 | meta["len"] = len(annotations)
93 |
94 | meta["item_len_list"] = [self.item_processor.predict_item_token_length(_) for _ in annotations]
95 |
96 | return meta, annotations
97 |
98 | def _collect_annotations_and_save_to_cache(self, cache_dir):
99 | if (Path(cache_dir) / "data.h5").exists() and (Path(cache_dir) / "ready").exists():
100 | # off-the-shelf annotation cache exists
101 | warnings.warn(
102 | f"Use existing h5 data cache: {Path(cache_dir)}\n"
103 | f"Note: if the actual data defined by the data config has changed since your last run, "
104 | f"please delete the cache manually and re-run this experiment, or the data actually used "
105 | f"will not be updated"
106 | )
107 | return
108 |
109 | Path(cache_dir).mkdir(parents=True, exist_ok=True)
110 | meta_collection, annotations_collection = self._collect_annotations()
111 |
112 | # when cache on disk, rank0 saves items to an h5 file
113 | logger.info(f"start to build data cache to: {Path(cache_dir)}")
114 | with h5py.File(Path(cache_dir) / "data.h5", "w") as file:
115 | dt = h5py.vlen_dtype(str)
116 | for i, annotations in enumerate(annotations_collection):
117 | serialized_ann = [json.dumps(_) for _ in annotations]
118 | h5_ann = file.create_dataset(f"ann{i}", (len(serialized_ann),), dtype=dt)
119 | h5_ann[:] = serialized_ann
120 |
121 | file.create_dataset("meta_collection", data=json.dumps(meta_collection))
122 | with open(Path(cache_dir) / "ready", "w") as f:
123 | f.write("ready")
124 | logger.info(f"data cache built")
125 |
126 | @staticmethod
127 | def _get_cache_dir(config_path):
128 | config_identifier = config_path
129 | disallowed_chars = ["/", "\\", ".", "?", "!"]
130 | for _ in disallowed_chars:
131 | config_identifier = config_identifier.replace(_, "-")
132 | cache_dir = f"./xllmx_data_cache/{config_identifier}"
133 | return cache_dir
134 |
135 | @staticmethod
136 | def _load_annotations_from_cache(cache_dir):
137 | while not (Path(cache_dir) / "ready").exists():
138 | # cache has not yet been completed by rank 0
139 | assert dist.get_rank() != 0
140 | sleep(1)
141 | cache_file = h5py.File(Path(cache_dir) / "data.h5", "r")
142 | meta_collection = json.loads(cache_file["meta_collection"].asstr()[()])
143 | annotations_collection = [cache_file[f"ann{i}"] for i in range(len(meta_collection))]
144 | return meta_collection, annotations_collection
145 |
146 | def get_item_func(self, meta_idx, idx_in_meta):
147 | data_item = self.annotations_collection[meta_idx][idx_in_meta]
148 | if self.cache_on_disk:
149 | data_item = json.loads(data_item)
150 | else:
151 | data_item = copy.deepcopy(data_item)
152 |
153 | return self.item_processor.process_item(data_item, training_mode=True)
154 |
155 | def tie_index_to_meta(self, idx: int):
156 | # Initialize the starting index
157 | start_idx = 0
158 |
159 | # Iterate through the list of dictionaries
160 | for i, meta in enumerate(self.meta_collection):
161 | # Calculate the ending index for the current collection
162 | end_idx = start_idx + meta["len"]
163 |
164 | # Check if the given index falls within the current collection
165 | if start_idx <= idx < end_idx:
166 | # Calculate the new index within the current collection
167 | new_index = idx - start_idx
168 | return i, new_index
169 |
170 | # Update the starting index for the next collection
171 | start_idx = end_idx
172 |
173 | # If the index is out of range of all collections, raise an error
174 | raise IndexError("Index out of range")
175 |
176 | def __getitem__(self, index):
177 | meta_idx, idx_in_meta = self.tie_index_to_meta(index)
178 |
179 | try:
180 | return self.get_item_func(meta_idx, idx_in_meta)
181 | except Exception as e:
182 | logger.info(
183 | f"Item {index} errored, annotation:\n"
184 | f"{self.annotations_collection[meta_idx][idx_in_meta]}\n"
185 | f"Error:\n"
186 | f"{traceback.format_exc()}"
187 | )
188 | if idx_in_meta != 0:
189 | return self[index - 1]
190 | else:
191 | return self[index + self.meta_collection[meta_idx]["len"] - 1]
192 |
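For reference, the YAML config read above is expected to provide a top-level META list, each entry carrying at least a path plus an optional type and (for the sampler) an optional ratio; after yaml.load it amounts to a structure like the following (paths and type names are placeholders):

config = {
    "META": [
        {"path": "/path/to/annotations_a.json", "type": "default"},
        {"path": "/path/to/annotations_b.jsonl", "type": "image_text", "ratio": 0.5},
    ]
}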
--------------------------------------------------------------------------------
/xllmx/data/item_processor.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import copy
3 | import logging
4 | from typing import Any, Callable, Dict, List, Tuple, Union
5 |
6 | from xllmx.model.tokenizer import Tokenizer
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class LabelAllZeroError(Exception):
12 | def __init__(self, message=None):
13 | self.message = message
14 |
15 | def __str__(self):
16 | return f"LabelAllZeroError: {self.message}"
17 |
18 |
19 | class ItemProcessorBase(ABC):
20 | @abstractmethod
21 | def process_item(self, data_item: dict, training_mode=False) -> Tuple[List, List]:
22 | raise NotImplementedError
23 |
24 | def predict_item_token_length(self, data_item: dict) -> int:
25 | """
26 | estimate the token length of the data item for gathering items of similar lengths into a batch
27 | """
28 | return 1
29 |
30 |
31 | class MMConvItemProcessor(ItemProcessorBase):
32 | def __init__(
33 | self,
34 | transform: Dict[str, Callable[[Any], Dict]],
35 | media_symbols: List[str],
36 | tokenizer: str | Tokenizer,
37 | conv_template,
38 | ):
39 | self.transform = transform
40 | logger.info(f"transform:\n{self.transform}")
41 |
42 | self.media_symbols = media_symbols
43 | logger.info(f"media_symbols:\n{self.media_symbols}")
44 |
45 | if isinstance(tokenizer, str):
46 | self.tokenizer = Tokenizer(model_path=tokenizer)
47 | else:
48 | self.tokenizer = copy.deepcopy(tokenizer)
49 |
50 | # todo: check that the media_symbols do not already exist in the tokenizer vocabulary
51 | self.tokenizer.tokenizer.add_tokens(media_symbols)
52 | self.d_media_symbol2token = {}
53 | self.d_media_token2symbol = {}
54 | for media_symbol in media_symbols:
55 | tokenized_symbol = self.tokenizer.encode(media_symbol, bos=False, eos=False)
56 | assert len(tokenized_symbol) == 1
57 | self.d_media_symbol2token[media_symbol] = tokenized_symbol[0]
58 | self.d_media_token2symbol[tokenized_symbol[0]] = media_symbol
59 |
60 | # implicit_at_beginning means media without explicit location specification are arranged right after the bos token
61 | # if false, then these media are arranged at the beginning of the first question
62 | self.implicit_at_beginning = False
63 | self.conv_template = conv_template
64 |
65 | def collect_and_process_media(self, data_item):
66 | """
67 | this function receives a raw piece of data (e.g. read from `.json` data file),
68 | and returns d_media, containing the prepared media readily usable by model
69 | YOU MAY OVERRIDE THIS FUNCTION TO SUPPORT COMPLEX LOADING OF VARIOUS FORMS OF DATA
70 | """
71 | d_media = {}
72 | for media_symbol in self.media_symbols:
73 | if media_symbol in data_item:
74 | l_media = data_item[media_symbol] # a list of media paths
75 | elif media_symbol.lstrip("<|").rstrip("|>") in data_item:
76 | l_media = data_item[media_symbol.lstrip("<|").rstrip("|>")]
77 | else:
78 | l_media = []
79 | if not isinstance(l_media, list): # data with only one media, in format {"image": image_name, ...}
80 | l_media = [l_media]
81 |
82 | d_media[media_symbol] = []
83 | for media in l_media:
84 | media = self.transform[media_symbol](media)
85 | assert isinstance(media, Dict)
86 | media["type"] = media_symbol
87 | d_media[media_symbol].append(media)
88 |
89 | return d_media
90 |
91 | def replace_media_token_with_media(
92 | self, tokens: List[int], labels: Union[List[int], None], d_media: Dict[str, List]
93 | ):
94 | d_media_counter = {key: 0 for key in d_media}
95 | for i, t in enumerate(tokens):
96 | if t in self.d_media_token2symbol:
97 | media_symbol = self.d_media_token2symbol[t]
98 | media = d_media[media_symbol][d_media_counter[media_symbol]]
99 | d_media_counter[media_symbol] += 1
100 | tokens[i] = media
101 | media["to_predict"] = labels[i] > 0
102 |
103 | assert all([d_media_counter[key] == len(d_media[key]) for key in d_media])
104 |
105 | if labels is not None:
106 | return tokens, labels
107 | else:
108 | return tokens
109 |
110 | @staticmethod
111 | def insert_implicit_media_symbol_in_q1(conv_list: List[Dict], d_media: Dict):
112 | """
113 | Add the media tokens to the beginning of the first instruction from
114 | human. This logic may be more reasonable. However, it is incompatible
115 | with old-version Accessory models, which are trained with image tokens
116 | inserted directly behind the first token (the bos token).
117 | :param conv_list: [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}, ...]
118 | :param d_media: a dict of media for all media types
119 | """
120 | conv_list = copy.deepcopy(conv_list)
121 |
122 | for media_symbol, l_media in d_media.items():
123 | media_symbol_count = "".join([_["value"] for _ in conv_list if _["value"] is not None]).count(media_symbol)
124 | if media_symbol_count > 0:
125 | assert media_symbol_count == len(
126 | l_media
127 | ), f"{media_symbol_count} {media_symbol} exists in text, but {len(l_media)} actual media are given"
128 | else:
129 | conv_list[0]["value"] = (media_symbol + " ") * len(l_media) + conv_list[0]["value"]
130 |
131 | return conv_list
132 |
133 | @staticmethod
134 | def insert_implicit_media_symbol_at_beginning(conv: str, d_media: Dict):
135 | """
136 | Legacy versions of LLaMA2-Accessory handled media in a non-interleaved
137 | manner, where image tokens are inserted directly behind the first token,
138 | namely the bos token. To support interleaved media comprehension and generation,
139 | Accessory now supports the explicit specification of media occurrence,
140 | which is achieved by adding explicit media symbols within the
141 | conversations. On the other hand, for media without explicit
142 | specification, this function realizes the legacy behavior to arrange
143 | them at the beginning of the conversation.
144 | :param conv: conversation
145 | :param d_media: a dict of media for all media types, for determining how
146 | many media tokens need to be inserted
147 | """
148 | conv = copy.deepcopy(conv)
149 |
150 | for media_symbol, l_media in d_media.items():
151 | media_symbol_count = conv.count(media_symbol)
152 | if media_symbol_count > 0:
153 | assert media_symbol_count == len(
154 | l_media
155 | ), f"{media_symbol_count} {media_symbol} exists in text, but {len(l_media)} actual media are given"
156 | else:
157 | conv = (media_symbol + " ") * len(l_media) + conv
158 |
159 | return conv
160 |
161 | def preprocess_item(self, data_item):
162 | return data_item
163 |
164 | def add_speaker_and_signal(self, source: List):
165 | """
166 | Given source instruction and response pieces, return the text containing the complete conversation,
167 | and the list of values that the model should learn to predict during training
168 | :param source: [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}, ...]
169 | :return: `conversation`: string containing the complete conversation;
170 | `to_predict_list`: the list of values that the model should learn to predict during training
171 | """
172 | conv = self.conv_template()
173 |
174 | for i, sentence in enumerate(source):
175 | from_str = sentence["from"]
176 | if i % 2 == 0:
177 | assert from_str.lower() in ["human"]
178 | role = conv.roles[0]
179 | elif i % 2 == 1:
180 | assert from_str.lower() in ["gpt", "assistant"]
181 | role = conv.roles[1]
182 | else:
183 | raise ValueError(f"unknown dialog role: {from_str.lower()}")
184 |
185 | value = sentence["value"]
186 |
187 | conv.append_message(role, value)
188 |
189 | processed = conv.process()
190 | conversation, pieces = processed["conv"], processed["pieces"]
191 |
192 | return conversation, pieces
193 |
194 | def process_item(self, data_item: dict, training_mode=False) -> Tuple[List, List]:
195 | data_item = self.preprocess_item(data_item)
196 |
197 | d_media = self.collect_and_process_media(data_item)
198 |
199 | source = data_item["conversations"]
200 |
201 | # implicit_at_beginning means media without explicit location specification are arranged right after the bos token
202 | # if false, then these media are arranged at the beginning of the first question
203 | if not self.implicit_at_beginning:
204 | source = self.insert_implicit_media_symbol_in_q1(source, d_media)
205 |
206 | conversation, pieces = self.add_speaker_and_signal(source)
207 |
208 | if self.implicit_at_beginning:
209 | conversation = self.insert_implicit_media_symbol_at_beginning(conversation, d_media)
210 |
211 | # dialog does not need eos
212 | tokens = self.tokenizer.encode(conversation, bos=True, eos=False)
213 | labels = [-100 for _ in tokens]
214 |
215 | # check special token num as expected
216 | for media_symbol, l_media in d_media.items():
217 | media_token = self.d_media_symbol2token[media_symbol]
218 | media_token_count = tokens.count(media_token)
219 | assert media_token_count == len(l_media), (
220 | f"{media_token_count} {media_token} (for {media_symbol}) exists in tokenized conversation, "
221 | f"but {len(l_media)} actual media are given"
222 | )
223 |
224 | check_pos = 0
225 | for i, p in enumerate(pieces):
226 | if i == 0:
227 | tokenized_value = self.tokenizer.encode(p["data"], bos=True, eos=False)
228 | else:
229 | tokenized_value = self.tokenizer.encode_wo_prefix_space(p["data"])
230 |
231 | assert (
232 | tokens[check_pos : check_pos + len(tokenized_value)] == tokenized_value
233 | ), "inconsistent complete conversation and corresponding piece after tokenization"
234 |
235 | if p["predict"]:
236 | labels[check_pos : check_pos + len(tokenized_value)] = tokenized_value
237 |
238 | check_pos = check_pos + len(tokenized_value)
239 |
240 | if training_mode and all([_ <= 0 for _ in labels]): # nothing to predict
241 | raise LabelAllZeroError()
242 |
243 | # labels will be processed later by the model
244 | tokens, labels = self.replace_media_token_with_media(tokens, labels, d_media)
245 |
246 | assert len(tokens) == len(labels)
247 |
248 | if training_mode:
249 | return tokens, labels
250 | else:
251 | return tokens
252 |
253 | def predict_item_token_length(self, data_item: dict) -> int:
254 | """
255 | estimate the length of each item
256 | """
257 |
258 | if "conversations" in data_item:
259 | return sum([len(_["value"]) for _ in data_item["conversations"]])
260 | else:
261 | return 1
262 |
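To summarize the contract of process_item above: a raw item supplies a conversations list of alternating human/assistant turns plus the media referenced by each media symbol, and in training mode the method returns (tokens, labels), with labels set to -100 everywhere except on the assistant pieces to be predicted. A hedged example of the expected input shape (the <|image|> symbol and the field values are illustrative, not taken from the repo):

data_item = {
    "conversations": [
        {"from": "human", "value": "What is in this picture? <|image|>"},
        {"from": "gpt", "value": "A cat sitting on a windowsill."},
    ],
    "image": ["/path/to/cat.jpg"],  # hypothetical path; key name matches the media symbol without its delimiters
}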
--------------------------------------------------------------------------------
/xllmx/data/sampler.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import logging
3 | from typing import Iterator, List, Optional
4 |
5 | import numpy as np
6 | from torch.utils.data import Sampler
7 |
8 | from xllmx.data.dataset import FinetuneConversationDataset
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | # todo too slow to be used
14 | def mild_shuffle(items: List, shuffle_factor, engine: np.random.Generator):
15 | """
16 | Perform a mild shuffle on the list of items.
17 |
18 | Args:
19 | engine: random engine
20 | items (list): The list of items to shuffle.
21 |         shuffle_factor (float): the max swap range is computed as len(items) * shuffle_factor.
22 |
23 | Returns:
24 | list: The mildly shuffled list.
25 | """
26 |
27 | n = len(items)
28 | swap_range = int(shuffle_factor * n)
29 | shuffled_items = [None for _ in items]
30 | cache = list(range(swap_range))
31 | for i in range(n):
32 | if i + swap_range < n:
33 | cache.append(i + swap_range)
34 | if len(cache) == 0 or cache[0] != i: # already swapped
35 | assert shuffled_items[i] is not None
36 | continue
37 | else:
38 | cache = cache[1:]
39 | if len(cache) == 0:
40 | shuffled_items[i] = items[i]
41 | else:
42 | cache_idx = engine.integers(low=0, high=len(cache))
43 | j = cache[cache_idx]
44 | del cache[cache_idx]
45 | shuffled_items[i], shuffled_items[j] = items[j], items[i]
46 |
47 | return shuffled_items
48 |
49 |
50 | class FinetuneDistSampler(Sampler):
51 | def __init__(
52 | self,
53 | dataset: FinetuneConversationDataset,
54 | num_replicas: Optional[int] = None,
55 | rank: Optional[int] = None,
56 | shuffle: bool = True,
57 | seed: int = 0,
58 | batch_size=None,
59 | acc_grad=1,
60 | length_clustering=True,
61 | allow_mixed_task_among_acc=False,
62 | ):
63 | """
64 | Distributed Sampler ensuring data in a batch are of the same type (e.g. text, image-text)
65 | :param dataset:
66 | :param num_replicas:
67 | :param rank:
68 | :param shuffle:
69 | :param seed:
70 | :param batch_size:
71 | :param acc_grad:
72 | :param length_clustering:
73 | :param allow_mixed_task_among_acc:
74 | """
75 | # super().__init__()
76 |
77 | if num_replicas is None or rank is None or rank >= num_replicas or rank < 0:
78 | raise ValueError(f"Invalid num_replicas ({num_replicas}) or rank ({rank})")
79 | assert batch_size is not None
80 |
81 | self.dataset = dataset
82 | self.num_replicas = num_replicas
83 | self.rank = rank
84 | self.shuffle = shuffle
85 | self.seed = seed
86 | self.batch_size = batch_size
87 | self.acc_grad = acc_grad
88 | self.length_clustering = length_clustering
89 | self.allow_mixed_task_among_acc = allow_mixed_task_among_acc
90 |
91 | self.epoch = 0
92 | self.start_iter = 0
93 |
94 | global_bsz_acc = batch_size * num_replicas * acc_grad
95 |
96 | group_len = defaultdict(int)
97 | for i, meta in enumerate(dataset.meta_collection):
98 | group_len[meta["type"]] += int(meta["len"] * meta.get("ratio", 1.0))
99 |
100 | group_len = {key: val // global_bsz_acc * global_bsz_acc for key, val in group_len.items()}
101 |
102 | self.total_size = sum(list(group_len.values()))
103 | assert self.total_size % num_replicas == 0
104 | self.num_samples = self.total_size // num_replicas
105 |
106 | def __iter__(self) -> Iterator:
107 | global_batch_size = self.batch_size * self.num_replicas
108 | global_bsz_acc = self.batch_size * self.num_replicas * self.acc_grad
109 | rng = np.random.default_rng(self.seed + self.epoch)
110 |
111 | group_indices_and_len = defaultdict(list)
112 |
113 | # Initialize the starting index
114 | start_idx = 0
115 |
116 | # Iterate through the list of dictionaries
117 | for i, meta in enumerate(self.dataset.meta_collection):
118 | # Calculate the ending index for the current collection
119 | end_idx = start_idx + meta["len"]
120 | indices = list(range(start_idx, end_idx))
121 | assert len(indices) == len(meta["item_len_list"])
122 | indices_and_len = [[idx, length] for idx, length in zip(indices, meta["item_len_list"])]
123 | if meta.get("ratio", 1.0) != 1.0:
124 | indices_and_len = list(rng.choice(indices_and_len, int(meta["len"] * meta["ratio"]), replace=False))
125 | logger.info(f"meta{i}: sample (ratio = {meta['ratio']}) {len(indices_and_len)} items")
126 | group_indices_and_len[meta["type"]].extend(indices_and_len)
127 |
128 | # Update the starting index for the next collection
129 | start_idx = end_idx
130 |
131 | for group_name, indices_and_len in group_indices_and_len.items():
132 | group_indices_and_len[group_name] = indices_and_len[
133 | : len(indices_and_len) // global_bsz_acc * global_bsz_acc
134 | ]
135 |
136 | if self.shuffle:
137 | group_indices = {}
138 | if self.length_clustering:
139 | for group_name, indices_and_len in group_indices_and_len.items():
140 | indices_and_len.sort(key=lambda x: x[1])
141 | group_indices[group_name] = [_[0] for _ in indices_and_len]
142 |
143 | # option1: shuffle among neighboring items
144 | for group_name, indices in group_indices.items():
145 | result = []
146 | for pos in range(0, len(indices), global_batch_size * 500):
147 | sublist = indices[pos : pos + global_batch_size * 500]
148 | rng.shuffle(sublist)
149 | result.extend(sublist)
150 | group_indices[group_name] = result
151 | # option2: mild shuffle
152 | # group_indices[group_name] = mild_shuffle(indices, 0.1, rng)
153 | # option3: do nothing
154 | # pass
155 | else:
156 | for group_name, indices_and_len in group_indices_and_len.items():
157 | rng.shuffle(indices_and_len)
158 | group_indices[group_name] = [_[0] for _ in indices_and_len]
159 |
160 | del group_indices_and_len
161 |
162 | if self.allow_mixed_task_among_acc:
163 | global_batched_indices = [
164 | indices[i : i + global_batch_size]
165 | for group_name, indices in group_indices.items()
166 | for i in range(0, len(indices), global_batch_size)
167 | ]
168 | else:
169 | global_batched_indices = []
170 | for group_name, indices in group_indices.items():
171 | group_batched_indices = [
172 | indices[i : i + global_batch_size] for i in range(0, len(indices), global_batch_size)
173 | ]
174 | rng.shuffle(group_batched_indices)
175 | group_batched_indices = [
176 | sum(group_batched_indices[i : i + self.acc_grad], start=[])
177 | for i in range(0, len(group_batched_indices), self.acc_grad)
178 | ]
179 | global_batched_indices.extend(group_batched_indices)
180 | rng.shuffle(global_batched_indices)
181 | indices = [_ for batch_indices in global_batched_indices for _ in batch_indices]
182 | else:
183 | raise NotImplementedError()
184 |
185 | assert len(indices) == self.total_size
186 |
187 | own_indices = []
188 | for start_pos in range(self.rank * self.batch_size, len(indices), self.num_replicas * self.batch_size):
189 | own_indices += indices[start_pos : start_pos + self.batch_size]
190 | # subsample
191 | assert len(own_indices) == self.num_samples
192 |
193 | if self.start_iter * self.batch_size > len(own_indices):
194 | own_indices = []
195 | else:
196 | own_indices = own_indices[self.start_iter * self.batch_size :]
197 |
198 | return iter(own_indices)
199 |
200 | def __len__(self) -> int:
201 | return self.num_samples
202 |
203 | def set_epoch(self, epoch: int, start_iter: int = 0) -> None:
204 | r"""
205 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
206 | use a different random ordering for each epoch. Otherwise, the next iteration of this
207 | sampler will yield the same ordering.
208 |
209 | Args:
210 | epoch (int): Epoch number.
211 | start_iter (int): start iter number.
212 | """
213 | self.epoch = epoch
214 | self.start_iter = start_iter
215 |
--------------------------------------------------------------------------------
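A note on sampler.py above: after the per-type index lists are batched and shuffled, `FinetuneDistSampler.__iter__` flattens them and hands out contiguous blocks of `batch_size` indices round-robin across ranks, so every global batch of `batch_size * num_replicas` samples is split into same-type per-rank batches. A minimal sketch of just that slicing step (the helper name `per_rank_indices` is illustrative, not part of the project):

def per_rank_indices(indices, rank, num_replicas, batch_size):
    # rank r takes blocks r, r + num_replicas, r + 2 * num_replicas, ... of size batch_size
    own = []
    for start in range(rank * batch_size, len(indices), num_replicas * batch_size):
        own += indices[start : start + batch_size]
    return own

indices = list(range(12))  # e.g. 12 samples, batch_size=2, 3 replicas
assert per_rank_indices(indices, rank=0, num_replicas=3, batch_size=2) == [0, 1, 6, 7]
assert per_rank_indices(indices, rank=1, num_replicas=3, batch_size=2) == [2, 3, 8, 9]
assert per_rank_indices(indices, rank=2, num_replicas=3, batch_size=2) == [4, 5, 10, 11]
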
/xllmx/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/model/__init__.py
--------------------------------------------------------------------------------
/xllmx/model/components.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | try:
7 | from apex.normalization import FusedRMSNorm as RMSNorm
8 | except ImportError:
9 |     warnings.warn("Cannot import apex FusedRMSNorm, falling back to the vanilla implementation")
10 |
11 | class RMSNorm(torch.nn.Module):
12 | def __init__(self, dim: int, eps: float = 1e-6):
13 | """
14 | Initialize the RMSNorm normalization layer.
15 |
16 | Args:
17 | dim (int): The dimension of the input tensor.
18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19 |
20 | Attributes:
21 | eps (float): A small value added to the denominator for numerical stability.
22 | weight (nn.Parameter): Learnable scaling parameter.
23 |
24 | """
25 | super().__init__()
26 | self.eps = eps
27 | self.weight = nn.Parameter(torch.ones(dim))
28 |
29 | def _norm(self, x):
30 | """
31 | Apply the RMSNorm normalization to the input tensor.
32 |
33 | Args:
34 | x (torch.Tensor): The input tensor.
35 |
36 | Returns:
37 | torch.Tensor: The normalized tensor.
38 |
39 | """
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | """
44 | Forward pass through the RMSNorm layer.
45 |
46 | Args:
47 | x (torch.Tensor): The input tensor.
48 |
49 | Returns:
50 | torch.Tensor: The output tensor after applying RMSNorm.
51 |
52 | """
53 | output = self._norm(x.float()).type_as(x)
54 | return output * self.weight
55 |
--------------------------------------------------------------------------------
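A note on components.py above: the vanilla fallback implements RMSNorm, i.e. y = x * rsqrt(mean(x^2) + eps) * weight, which rescales by the root-mean-square of the last dimension without subtracting the mean (unlike LayerNorm). A quick standalone check of the formula, assuming only that torch is installed:

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # same formula as the fallback above, without the learnable weight
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(4, 8)
y = rms_norm(x)
print(y.pow(2).mean(-1))  # ~1.0 per row: the output has (approximately) unit RMS
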
/xllmx/model/tokenizer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 | from typing import List, Optional
5 |
6 | from sentencepiece import SentencePieceProcessor
7 | from transformers import AutoTokenizer
8 |
9 | __all__ = ["Tokenizer", "probe_tokenizer_path_from_pretrained"]
10 |
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class Tokenizer:
16 | def __init__(self, model_path: str):
17 | """
18 |         Create a tokenizer whose inner implementation is either a sentencepiece (spm) or a HuggingFace transformers tokenizer
19 | :param model_path:
20 | - when using spm tokenizer, should be path to a sentencepiece model with suffix `.model`
21 | - when using huggingface transformers tokenizer, should be an HF model repo or a local directory,
22 | containing tokenizer.json and tokenizer_config.json.
23 | """
24 | if model_path.endswith(".model"): # spm tokenizer
25 | self.tokenizer_type = "spm"
26 | # reload tokenizer
27 | assert os.path.isfile(model_path), model_path
28 | self.tokenizer = SentencePieceProcessor(model_file=model_path)
29 | logger.info(f"Reloaded SentencePiece model from {model_path}")
30 |
31 | # BOS / EOS token IDs
32 | self.bos_id: int = self.tokenizer.bos_id()
33 | self.eos_id: int = self.tokenizer.eos_id()
34 | assert self.tokenizer.vocab_size() == self.tokenizer.get_piece_size()
35 | else:
36 | self.tokenizer_type = "transformers"
37 | self.tokenizer = AutoTokenizer.from_pretrained(model_path)
38 | logger.info(f"load HF transformers tokenizer from {model_path}")
39 | # BOS / EOS token IDs
40 | self.bos_id: int = self.tokenizer.bos_token_id
41 | if self.bos_id is None:
42 | self.bos_id = self.tokenizer.eos_token_id
43 | self.eos_id: int = self.tokenizer.eos_token_id
44 | assert self.eos_id is not None
45 |
46 | self._probe_tokenizer_style()
47 |
48 | logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
49 |
50 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
51 | assert type(s) is str
52 | if self.tokenizer_type == "transformers":
53 | t = self.tokenizer.encode(s, truncation=False, add_special_tokens=False)
54 | else:
55 | t = self.tokenizer.encode(s)
56 | if bos:
57 | t = [self.bos_id] + t
58 | if eos:
59 | t = t + [self.eos_id]
60 | return t
61 |
62 | def encode_segment(self, s: str):
63 | s = s.lstrip(" ")
64 | if self.need_space_before_segment:
65 | return self.encode(" " + s, bos=False, eos=False)
66 | else:
67 | return self.encode(s, bos=False, eos=False)
68 |
69 | def encode_wo_prefix_space(self, s: str):
70 | if self.need_space_before_segment:
71 | return self.encode(s, bos=False, eos=False)
72 | else:
73 |             # prefix chars that, when preceding another string without a separator in between,
74 |             # are relatively more likely to be tokenized independently rather than getting
75 |             # merged into the following string.
76 | l_prefix = ["@", "\n", "\\", "=", ">", "`"]
77 | for prefix in l_prefix:
78 | prefix_tokens = self.encode(prefix, bos=False, eos=False)
79 | cat_tokens = self.encode(prefix + s, bos=False, eos=False)
80 | if cat_tokens[: len(prefix_tokens)] == prefix_tokens:
81 | return cat_tokens[len(prefix_tokens) :]
82 |
83 | raise NotImplementedError(
84 |                 f"All prefixes are merged into {s} during tokenization. "
85 |                 f"This is weird behavior; please open an issue to report this problem.",
86 | )
87 |
88 | def _probe_tokenizer_style(self):
89 | """
90 |         Given a sentence, e.g. "Hi my darling", some tokenizers (e.g. LLaMA's) exhibit the following behavior:
91 | >>> # leading characters will be treated as if there were an " " in the beginning
92 | >>> tokenizer.encode("Hi my darling") == tokenizer.encode("Hi") + tokenizer.encode("my darling")
93 | >>> # leading space " " is redundant and should not be added
94 | >>> tokenizer.encode("Hi my darling") != tokenizer.encode("Hi") + tokenizer.encode(" my darling")
95 | However, some others (e.g. InternLM's) will behave differently:
96 | >>> # leading space " " has to be explicitly added
97 | >>> tokenizer.encode("Hi my darling") == tokenizer.encode("Hi") + tokenizer.encode(" my darling")
98 |         Knowing which style the tokenizer follows is necessary when tokenizing a segment cut from the complete
99 | text, so that the result is the same as the corresponding part in the tokenized original text.
100 | """
101 | sentence1 = self.encode("Hi my darling", bos=False, eos=False)
102 | sentence2 = self.encode("my darling", bos=False, eos=False)
103 | if sentence1[-len(sentence2) :] == sentence2:
104 | self.need_space_before_segment = False
105 | else:
106 | sentence3 = self.encode(" my darling", bos=False, eos=False)
107 | assert sentence1[-len(sentence3) :] == sentence3
108 | self.need_space_before_segment = True
109 |
110 | def decode(self, t: List[int]) -> str:
111 | return self.tokenizer.decode(t)
112 |
113 | def save(self, save_dir: str):
114 | if self.tokenizer_type == "transformers":
115 | self.tokenizer.save_pretrained(save_dir)
116 | else:
117 | with open(Path(save_dir) / "tokenizer.model", "wb") as f:
118 | f.write(self.tokenizer.serialized_model_proto())
119 |
120 | @property
121 | def n_words(self):
122 | if self.tokenizer_type == "spm":
123 | return self.tokenizer.vocab_size()
124 | elif self.tokenizer_type == "transformers":
125 | return len(self.tokenizer)
126 | else:
127 | raise RuntimeError
128 |
129 |
130 | def probe_tokenizer_path_from_pretrained(pretrained_path: str):
131 | tokenizer_path = None
132 |
133 | # try find spm-style tokenizer
134 | logger.info(f"trying to find sentencepiece-style tokenizer at {Path(pretrained_path) / 'tokenizer.model'}")
135 | if (Path(pretrained_path) / "tokenizer.model").exists():
136 | logger.info(f"Found {Path(pretrained_path) / 'tokenizer.model'}, use it.")
137 | tokenizer_path = str(Path(pretrained_path) / "tokenizer.model")
138 | else:
139 | logger.info("Not Found")
140 |
141 | # then try huggingface style
142 | if tokenizer_path is None:
143 | logger.info(
144 | f"trying to find huggingface-style tokenizer at "
145 | f"{Path(pretrained_path) / '(tokenizer.json, tokenizer_config.json)'}"
146 | )
147 | if (Path(pretrained_path) / "tokenizer.json").exists() and (
148 | Path(pretrained_path) / "tokenizer_config.json"
149 | ).exists():
150 | logger.info(f"Found {Path(pretrained_path) / '(tokenizer.json, tokenizer_config.json)'}, use them.")
151 | tokenizer_path = pretrained_path
152 | else:
153 | logger.info("Not Found")
154 | if tokenizer_path is None:
155 | logger.info("No usable tokenizer found")
156 | return tokenizer_path
157 |
--------------------------------------------------------------------------------
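A note on tokenizer.py above: `encode_segment` / `encode_wo_prefix_space` exist so that a segment cut out of a longer text re-tokenizes to exactly the tokens it occupies inside the full text, regardless of whether the underlying tokenizer needs an explicit leading space. A hypothetical usage sketch (the tokenizer path is a placeholder; assumes xllmx and its dependencies are installed):

from xllmx.model.tokenizer import Tokenizer

tok = Tokenizer("/path/to/tokenizer_dir_or_spm_model")  # placeholder path

full = tok.encode("Hi my darling", bos=False, eos=False)
tail = tok.encode_segment("my darling")  # adds or omits the leading space based on the probed style
assert full[-len(tail):] == tail         # the segment matches the tail of the full tokenization
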
/xllmx/solvers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-mGPT/104abe453ec1acca5863698629c4db2111b0b3fc/xllmx/solvers/__init__.py
--------------------------------------------------------------------------------
/xllmx/solvers/finetune/__init__.py:
--------------------------------------------------------------------------------
1 | from .finetune import FinetuneSolverBase
2 |
--------------------------------------------------------------------------------
/xllmx/util/__init__.py:
--------------------------------------------------------------------------------
1 | from . import ckpt, dist
2 |
--------------------------------------------------------------------------------
/xllmx/util/ckpt.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import shutil
5 | from typing import Dict, Optional
6 |
7 | import torch
8 | from torch import distributed as dist
9 | from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, StateDictType
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def split_ckpt_str_into_epoch_iter(ckpt_str: str):
15 | # divide ckpt directory names into epoch and iter parts
16 | parts = ckpt_str.split("-")
17 | epoch = int(parts[0].replace("epoch", ""))
18 | if len(parts) == 2:
19 | iter_part = int(parts[1].replace("iter", ""))
20 | else:
21 | iter_part = None
22 | return epoch, iter_part
23 |
24 |
25 | def remove_early_ckpts(out_dir, max_keep=2):
26 |
27 | if max_keep <= 0:
28 | return
29 |
30 | def ckpt_sort_key(s):
31 | # divide ckpt directory names into epoch and iter parts
32 | epoch, iteration = split_ckpt_str_into_epoch_iter(s)
33 | if iteration is None:
34 | iteration = float("inf")
35 | return epoch, iteration
36 |
37 | existing_checkpoints = [_ for _ in os.listdir(out_dir) if "epoch" in _]
38 | existing_checkpoints = sorted(existing_checkpoints, key=ckpt_sort_key, reverse=True)
39 |
40 | for dir_to_remove in existing_checkpoints[max_keep:]:
41 | dir_to_remove = os.path.join(out_dir, dir_to_remove)
42 | shutil.rmtree(dir_to_remove)
43 | logger.info(f"Deleted {dir_to_remove}")
44 |
45 |
46 | def save(
47 | output_dir,
48 | is_main_process,
49 | model: FSDP,
50 | optimizer: Optional[torch.optim.Optimizer] = None,
51 | tokenizer=None,
52 | args=None,
53 | epoch=None,
54 | iteration=None,
55 | additional_rank_common: Optional[Dict] = None,
56 | additional_rank_specific: Optional[Dict] = None,
57 | max_keep=2,
58 | ):
59 | save_name = f"epoch{epoch}"
60 | if iteration is not None:
61 | save_name += f"-iter{iteration}"
62 | save_dir = os.path.join(output_dir, save_name)
63 |
64 | os.makedirs(save_dir, exist_ok=True)
65 |
66 | # save model
67 | with FSDP.state_dict_type(
68 | model,
69 | StateDictType.FULL_STATE_DICT,
70 | FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
71 | ):
72 | # run saving in separate functions to save memory
73 | def _save_model():
74 | save_dtype = {
75 | "fp16": torch.float16,
76 | "bf16": torch.bfloat16,
77 | "tf32": torch.float,
78 | }[
79 | args.precision
80 | ] # todo make saving precision optional
81 | if getattr(args, "only_save_trainable", False):
82 | model_trainable_params = model.get_trainable_params()
83 | model_trainable_params = [
84 | ".".join([_ for _ in key.split(".") if not _.startswith("_")])
85 | for key in model_trainable_params.keys()
86 | ]
87 | consolidated_model_state_dict = {
88 | key: val.to(save_dtype) for key, val in model.state_dict().items() if key in model_trainable_params
89 | }
90 | else:
91 | consolidated_model_state_dict = {key: val.to(save_dtype) for key, val in model.state_dict().items()}
92 |
93 | if is_main_process:
94 | model.save_pretrained(save_dir, state_dict=consolidated_model_state_dict)
95 |
96 | _save_model()
97 | logger.info("model saved")
98 |
99 | # save optimizer
100 | if optimizer is not None:
101 | with FSDP.state_dict_type(
102 | model,
103 | StateDictType.LOCAL_STATE_DICT,
104 | ):
105 | opt_path = os.path.join(
106 | save_dir,
107 | f"optimizer.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth",
108 | )
109 | torch.save(optimizer.state_dict(), opt_path)
110 | logger.info("optimizer saved")
111 | else:
112 | logger.info("optimizer is None, skip saving")
113 |
114 | if additional_rank_specific is not None:
115 | torch.save(
116 | additional_rank_specific,
117 | os.path.join(save_dir, f"additional.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth"),
118 | )
119 | logger.info(f"additional_rank_specific {list(additional_rank_specific.keys())} saved")
120 |
121 | if not is_main_process:
122 | dist.barrier()
123 | return
124 |
125 | # =========The followings are for main process only=========
126 | if tokenizer is not None:
127 | tokenizer.save(save_dir)
128 | logger.info("tokenizer saved")
129 | else:
130 | logger.info("tokenizer is None, skip saving")
131 |
132 | if args is not None:
133 | with open(os.path.join(save_dir, "args.json"), "w") as f:
134 | json.dump(vars(args), f, indent=2)
135 | logger.info("args saved")
136 | else:
137 | logger.info("args is None, skip saving")
138 |
139 | if additional_rank_common is not None:
140 | torch.save(additional_rank_common, os.path.join(save_dir, "additional_rank_common.pth"))
141 | logger.info(f"additional_resources {list(additional_rank_common.keys())} saved")
142 |
143 | remove_early_ckpts(output_dir, max_keep=max_keep)
144 |
145 | dist.barrier()
146 | return
147 |
--------------------------------------------------------------------------------
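A note on ckpt.py above: checkpoint directories are named `epoch{E}` or `epoch{E}-iter{I}`; `split_ckpt_str_into_epoch_iter` parses that name, and `remove_early_ckpts` sorts by (epoch, iter) with a missing iter treated as end-of-epoch, so only the newest `max_keep` checkpoints survive. A small illustration (assumes xllmx and torch are installed):

from xllmx.util.ckpt import split_ckpt_str_into_epoch_iter

print(split_ckpt_str_into_epoch_iter("epoch2"))           # (2, None) -> sorted as (2, inf), i.e. end of epoch 2
print(split_ckpt_str_into_epoch_iter("epoch2-iter1500"))  # (2, 1500)
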
/xllmx/util/dist.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import os
4 | import socket
5 | import subprocess
6 | import time
7 | from types import SimpleNamespace
8 |
9 | import torch
10 | import torch.distributed as dist
11 |
12 | from xllmx.util.misc import random_seed
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def find_free_port(start_port: int, end_port: int):
18 | """
19 | Find a free port within the specified range.
20 | """
21 | for port in range(start_port, end_port):
22 | try:
23 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
24 | s.bind(("", port)) # Try to bind to the port
25 | s.close() # Close the socket if successful
26 | return port
27 |         except OSError:
28 |             # port is in use, try the next one
29 |             continue
30 | raise RuntimeError(f"No free ports found in range {start_port}-{end_port}")
31 |
32 |
33 | def init_distributed_mode(args=SimpleNamespace()):
34 | random_seed(getattr(args, "seed", 0))
35 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ and "LOCAL_RANK" in os.environ:
36 | args.world_size = int(os.environ["WORLD_SIZE"])
37 | args.rank = int(os.environ["RANK"])
38 | args.gpu = int(os.environ["LOCAL_RANK"])
39 | args.local_rank = args.gpu
40 | args.dist_url = "env://"
41 | elif "SLURM_PROCID" in os.environ:
42 | os.environ["MASTER_PORT"] = "8966"
43 | while "MASTER_ADDR" not in os.environ or len(os.environ["MASTER_ADDR"].strip()) == 0:
44 | os.environ["MASTER_ADDR"] = (
45 | subprocess.check_output(
46 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"],
47 | shell=True,
48 | )
49 | .decode()
50 | .strip()
51 | )
52 | time.sleep(1)
53 | print(os.environ["MASTER_ADDR"])
54 | args.world_size = int(os.environ["SLURM_NPROCS"])
55 | args.rank = int(os.environ["SLURM_PROCID"])
56 | args.gpu = args.rank % torch.cuda.device_count()
57 | args.local_rank = args.gpu
58 | args.dist_url = "env://"
59 | os.environ["LOCAL_RANK"] = str(args.gpu)
60 | os.environ["WORLD_SIZE"] = str(args.world_size)
61 | os.environ["RANK"] = str(args.rank)
62 | else:
63 | os.environ["MASTER_ADDR"] = "127.0.0.1"
64 | os.environ["MASTER_PORT"] = str(find_free_port(9000, 10000))
65 | os.environ["RANK"] = "0"
66 | os.environ["LOCAL_RANK"] = "0"
67 | os.environ["WORLD_SIZE"] = "1"
68 | args.rank = 0
69 | args.gpu = args.local_rank = 0
70 | args.world_size = 1
71 | args.dist_url = "env://"
72 |
73 | args.distributed = True
74 |
75 | torch.cuda.set_device(args.gpu)
76 | args.dist_backend = "nccl"
77 | print("| distributed init (rank {}): {}, gpu {}".format(args.rank, args.dist_url, args.gpu), flush=True)
78 | torch.distributed.init_process_group(
79 | backend=args.dist_backend,
80 | init_method=args.dist_url,
81 | world_size=args.world_size,
82 | rank=args.rank,
83 | timeout=datetime.timedelta(seconds=2 * 60 * 60),
84 | )
85 | torch.distributed.barrier()
86 |
87 |
88 | def all_reduce_mean(x, group=None):
89 | world_size = dist.get_world_size(group=group)
90 | if world_size > 1:
91 | if isinstance(x, torch.Tensor):
92 | x_reduce = x.clone().cuda()
93 | else:
94 | x_reduce = torch.tensor(x).cuda()
95 | dist.all_reduce(x_reduce, group=group)
96 | x_reduce /= world_size
97 | return x_reduce.item()
98 | else:
99 | return x
100 |
--------------------------------------------------------------------------------
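A note on dist.py above: `init_distributed_mode` covers three launch modes: torchrun-style environment variables (RANK / WORLD_SIZE / LOCAL_RANK), SLURM (SLURM_PROCID etc.), and a single-process fallback that binds a free local port. A hypothetical single-process sketch (assumes xllmx is installed and at least one CUDA GPU is available, since the NCCL backend is used):

from types import SimpleNamespace
from xllmx.util.dist import all_reduce_mean, init_distributed_mode

args = SimpleNamespace(seed=0)
init_distributed_mode(args)        # no RANK/SLURM env vars -> 1-process group on a free local port
print(args.rank, args.world_size)  # 0 1
print(all_reduce_mean(3.0))        # 3.0 (returned unchanged when world_size == 1)
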
/xllmx/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def adjust_learning_rate(optimizer, it, args):
5 | """Decay the learning rate with half-cycle cosine after warmup"""
6 | if it < args.warmup_iters: # 1) linear warmup for warmup_iters steps
7 | lr = args.lr * it / args.warmup_iters
8 | elif it > args.lr_decay_iters: # 2) if it > lr_decay_iters, return min learning rate
9 | lr = args.min_lr
10 | else: # 3) in between, use cosine decay down to min learning rate
11 | decay_ratio = (it - args.warmup_iters) / (args.lr_decay_iters - args.warmup_iters)
12 | assert 0 <= decay_ratio <= 1
13 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
14 | lr = args.min_lr + (args.lr - args.min_lr) * coeff
15 |
16 | for param_group in optimizer.param_groups:
17 | if "lr_scale" in param_group:
18 | param_group["lr"] = lr * param_group["lr_scale"]
19 | else:
20 | param_group["lr"] = lr
21 | return lr
22 |
23 |
24 | def adjust_learning_rate_epoch(optimizer, epoch, args):
25 | """Decay the learning rate with half-cycle cosine after warmup"""
26 | if epoch < args.warmup_epochs:
27 | lr = args.lr * epoch / args.warmup_epochs
28 | else:
29 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
30 | 1.0 + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))
31 | )
32 | for param_group in optimizer.param_groups:
33 | if "lr_scale" in param_group:
34 | param_group["lr"] = lr * param_group["lr_scale"]
35 | else:
36 | param_group["lr"] = lr
37 | return lr
38 |
--------------------------------------------------------------------------------
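A note on lr_sched.py above: `adjust_learning_rate` implements linear warmup to `args.lr` over `warmup_iters`, then half-cycle cosine decay down to `args.min_lr` at `lr_decay_iters`, and clamps to `min_lr` afterwards. A small sketch driving it through a dummy optimizer (assumes xllmx and torch are installed; the args values are illustrative):

from types import SimpleNamespace

import torch

from xllmx.util.lr_sched import adjust_learning_rate

args = SimpleNamespace(lr=1e-4, min_lr=1e-5, warmup_iters=100, lr_decay_iters=1000)
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for it in [0, 50, 100, 550, 1000, 2000]:
    print(it, adjust_learning_rate(opt, it, args))
# 0 -> 0.0 (warmup start), 50 -> 5e-05, 100 -> 1e-4 (peak),
# 550 -> 5.5e-05 (cosine midpoint), 1000 -> 1e-5 (min), 2000 -> 1e-5 (clamped at min)
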
/xllmx/util/misc.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, deque
2 | import datetime
3 | import logging
4 | import random
5 | import time
6 |
7 | from fairscale.nn.model_parallel import initialize as fs_init
8 | import numpy as np
9 | import torch
10 | import torch.distributed as dist
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | def random_seed(seed=0):
16 | random.seed(seed)
17 | torch.random.manual_seed(seed)
18 | np.random.seed(seed)
19 |
20 |
21 | class SmoothedValue(object):
22 | """Track a series of values and provide access to smoothed values over a
23 | window or the global series average.
24 | """
25 |
26 | def __init__(self, window_size=1000, fmt=None):
27 | if fmt is None:
28 | fmt = "{avg:.4f} ({global_avg:.4f})"
29 | self.deque = deque(maxlen=window_size)
30 | self.total = 0.0
31 | self.count = 0
32 | self.fmt = fmt
33 |
34 | def update(self, value, n=1):
35 | self.deque.append(value)
36 | self.count += n
37 | self.total += value * n
38 |
39 | def synchronize_between_processes(self):
40 | """
41 | Warning: does not synchronize the deque!
42 | """
43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44 | dist.barrier()
45 | dist.all_reduce(t)
46 | t = t.tolist()
47 | self.count = int(t[0])
48 | self.total = t[1]
49 |
50 | @property
51 | def median(self):
52 | d = torch.tensor(list(self.deque))
53 | return d.median().item()
54 |
55 | @property
56 | def avg(self):
57 | d = torch.tensor(list(self.deque), dtype=torch.float32)
58 | return d.mean().item()
59 |
60 | @property
61 | def global_avg(self):
62 | return self.total / self.count
63 |
64 | @property
65 | def max(self):
66 | return max(self.deque)
67 |
68 | @property
69 | def value(self):
70 | return self.deque[-1]
71 |
72 | def __str__(self):
73 | return self.fmt.format(
74 | median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
75 | )
76 |
77 |
78 | class MetricLogger(object):
79 | def __init__(self, delimiter="\t"):
80 | self.meters = defaultdict(SmoothedValue)
81 | self.delimiter = delimiter
82 |
83 | def update(self, **kwargs):
84 | for k, v in kwargs.items():
85 | if v is None:
86 | continue
87 | elif isinstance(v, (torch.Tensor, float, int)):
88 | self.meters[k].update(v.item() if isinstance(v, torch.Tensor) else v)
89 | elif isinstance(v, list):
90 | for i, sub_v in enumerate(v):
91 | self.meters[f"{k}_{i}"].update(sub_v.item() if isinstance(sub_v, torch.Tensor) else sub_v)
92 | elif isinstance(v, dict):
93 | for sub_key, sub_v in v.items():
94 | self.meters[f"{k}_{sub_key}"].update(sub_v.item() if isinstance(sub_v, torch.Tensor) else sub_v)
95 | else:
96 | raise TypeError(f"Unsupported type {type(v)} for metric {k}")
97 |
98 | def __str__(self):
99 | loss_str = []
100 | for name, meter in self.meters.items():
101 | loss_str.append("{}: {}".format(name, str(meter)))
102 | return self.delimiter.join(loss_str)
103 |
104 | def synchronize_between_processes(self):
105 | for meter in self.meters.values():
106 | meter.synchronize_between_processes()
107 |
108 | def add_meter(self, name, meter):
109 | self.meters[name] = meter
110 |
111 | def log_every(self, iterable, print_freq, header=None, start_iter=0, samples_per_iter=None):
112 | i = start_iter
113 | if not header:
114 | header = ""
115 | start_time = time.time()
116 | end = time.time()
117 | iter_time = SmoothedValue(fmt="{avg:.4f}")
118 | data_time = SmoothedValue(fmt="{avg:.4f}")
119 | log_msg = [header, "[{0" + "}/{1}]", "{meters}", "time: {time}", "data: {data}"]
120 | if samples_per_iter is not None:
121 | log_msg.append("samples/sec: {samples_per_sec:.2f}")
122 | if torch.cuda.is_available():
123 | log_msg.append("max mem: {memory:.0f}")
124 | log_msg = self.delimiter.join(log_msg)
125 | MB = 1024.0 * 1024.0
126 | for obj in iterable:
127 | data_time.update(time.time() - end)
128 | yield obj
129 | iter_time.update(time.time() - end)
130 | if i % print_freq == 0:
131 | try:
132 | total_len = len(iterable)
133 |                 except TypeError:  # iterable may not define __len__
134 | total_len = "unknown"
135 |
136 | msg_kwargs = {
137 | "meters": str(self),
138 | "time": str(iter_time),
139 | "data": str(data_time),
140 | }
141 | if samples_per_iter is not None:
142 | msg_kwargs["samples_per_sec"] = samples_per_iter / iter_time.avg
143 | if torch.cuda.is_available():
144 | msg_kwargs["memory"] = torch.cuda.max_memory_allocated() / MB
145 |
146 | logger.info(log_msg.format(i, total_len, **msg_kwargs))
147 | i += 1
148 | end = time.time()
149 | total_time = time.time() - start_time
150 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
151 | logger.info("{} Total time: {}".format(header, total_time_str))
152 |
153 |
154 | def add_weight_decay(model, lr, weight_decay=1e-5):
155 | decay = []
156 | no_decay = []
157 | for name, param in model.named_parameters():
158 | if not param.requires_grad:
159 | continue # frozen weights
160 | if name.endswith(".bias") or name.endswith("norm.weight"):
161 | no_decay.append(param)
162 | else:
163 | decay.append(param)
164 | return [
165 |         {"params": no_decay, "lr": lr, "weight_decay": 0.0},  # biases / norm weights: no weight decay
166 | {"params": decay, "lr": lr, "weight_decay": weight_decay},
167 | ]
168 |
169 |
170 | def broadcast_nonmp_parameters(model):
171 | if fs_init.get_model_parallel_world_size() == 1:
172 | return
173 |     logger.info("start broadcasting non-model-parallel parameters within the model parallel group")
174 | memo = set()
175 | modules = model.named_modules(prefix="", remove_duplicate=True)
176 | for module_prefix, module in modules:
177 | members = dict(module._parameters.items())
178 | for k, v in members.items():
179 | name = module_prefix + ("." if module_prefix else "") + k
180 | if v is None or v in memo:
181 | continue
182 | if getattr(v, "model_parallel", False):
183 | logger.info(f"ignore: {name}")
184 | continue
185 | memo.add(v)
186 | dist.broadcast(v, src=fs_init.get_model_parallel_src_rank(), group=fs_init.get_model_parallel_group())
187 |     logger.info("broadcast done")
188 |
189 |
190 | def mark_mp_params(model: torch.nn.Module):
191 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear
192 |
193 | for m in model.modules():
194 | if isinstance(m, ColumnParallelLinear):
195 | m.weight.model_parallel = True
196 | if m.bias is not None:
197 | m.bias.model_parallel = True
198 |
199 | if isinstance(m, RowParallelLinear):
200 | m.weight.model_parallel = True
201 |
202 | if isinstance(m, ParallelEmbedding):
203 | m.weight.model_parallel = True
204 |
205 |
206 | def print_param_status(model: torch.nn.Module) -> None:
207 | require_grad_set = []
208 | no_grad_set = []
209 | for name, param in model.named_parameters():
210 | if param.requires_grad:
211 | require_grad_set.append((name, param))
212 | else:
213 | no_grad_set.append((name, param))
214 |
215 | logger.info("Params that require gradient:\n")
216 | for name, param in require_grad_set:
217 | model_parallel = getattr(param, "model_parallel", False)
218 | logger.info(
219 | f"Param {name}: requires_grad {param.requires_grad}, local_size {param.shape}, model_parallel {model_parallel}, dtype {param.dtype}"
220 | )
221 |
222 | logger.info("\nParams that do not require gradient:\n")
223 | for name, param in no_grad_set:
224 | model_parallel = getattr(param, "model_parallel", False)
225 | logger.info(
226 | f"Param {name}: requires_grad {param.requires_grad}, local_size {param.shape}, model_parallel {model_parallel}, dtype {param.dtype}"
227 | )
228 |
--------------------------------------------------------------------------------
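A note on misc.py above: `SmoothedValue` keeps a bounded deque for windowed statistics plus running totals for the global average, and `MetricLogger` aggregates named meters and formats them for logging. A brief local usage sketch (assumes xllmx and its dependencies, e.g. fairscale, are installed; no distributed setup is needed for local updates):

from xllmx.util.misc import MetricLogger, SmoothedValue

v = SmoothedValue(window_size=3)
for x in [1.0, 2.0, 3.0, 4.0]:
    v.update(x)
print(v.avg, v.global_avg)  # 3.0 (mean of the last 3 values) and 2.5 (mean of all 4)

m = MetricLogger(delimiter="  ")
m.update(loss=0.7, grad_norm=1.2)
print(str(m))               # "loss: 0.7000 (0.7000)  grad_norm: 1.2000 (1.2000)"
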
/xllmx/util/tensor_type.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def promote_param_to_fp32(param: nn.Parameter) -> None:
6 | if param.is_floating_point() and torch.finfo(param.dtype).bits < 32:
7 | param.data = param.data.float()
8 | if param.is_complex() and torch.finfo(param.dtype).bits < 32:
9 | param.data = param.data.to(torch.complex64)
10 |
--------------------------------------------------------------------------------
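A note on tensor_type.py above: `promote_param_to_fp32` upcasts low-precision floating-point (or complex) parameters to 32-bit in place, typically before building the optimizer so its states are kept in full precision. A tiny usage sketch (assumes xllmx and torch are installed):

import torch
import torch.nn as nn

from xllmx.util.tensor_type import promote_param_to_fp32

p = nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
promote_param_to_fp32(p)
print(p.dtype)  # torch.float32; fp32/complex64 and integer params are left untouched
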