├── .gitignore ├── LICENSE ├── README.md ├── configs ├── clm_models │ ├── agent_got.yaml │ └── llm_qwen25_vl_3b_lora.yaml └── tokenizer │ └── qwen25_vl_tokenizer_token64.yaml ├── examples ├── hat.jpg ├── strawberry.jpg └── tool.png ├── figures ├── architecture.jpg ├── interactive.jpg └── teaser.jpg ├── got ├── __init__.py ├── models │ ├── __init__.py │ ├── got_model.py │ ├── peft_models.py │ ├── projector.py │ └── utils.py └── processer │ └── qwen25_vl_processor.py ├── inference.ipynb └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | pretrained/ 2 | .DS_Store 3 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Rongyao Fang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GoT: Unleashing Reasoning Capability of Multimodal Large Language Model for Visual Generation and Editing 2 |
3 | Home 4 | 5 | 6 | [Rongyao Fang](https://scholar.google.com/citations?user=FtH3CW4AAAAJ&hl=en)1\*, [Chengqi Duan](https://scholar.google.com/citations?user=r9qb4ZwAAAAJ&hl=zh-CN)2\*, [Kun Wang]()3, [Linjiang Huang](https://leonhlj.github.io/)6, [Hao Li](https://scholar.google.com/citations?user=qHqQsY4AAAAJ&hl=zh-CN)1,4, [Shilin Yan](https://scholar.google.com/citations?user=2VhjOykAAAAJ&hl=zh-CN), [Hao Tian]()3, [Xingyu Zeng]()3, [Rui Zhao]()3, [Jifeng Dai](https://jifengdai.org/)4,5, [Xihui Liu](https://xh-liu.github.io/)2 :envelope:, [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/)1 :envelope: 7 | 8 | 1CUHK MMLab, 2HKU MMLab, 3SenseTime, 4Shanghai AI Laboratory, 5Tsinghua University, 6Beihang University 9 | 10 | *Equal contribution, :envelope:Corresponding authors 11 |
12 | 13 |
14 | [Figure: GoT Framework] 15 |
16 |
17 |
18 | Paper • 19 | Introduction • 20 | Datasets • 21 | Model • 22 | Results • 23 | 🤗 Hugging Face • 24 | License 25 |
26 | 27 | ## Introduction 28 | 29 | We present **Generation Chain-of-Thought (GoT)**, a novel paradigm that enables generation and editing through an explicit language reasoning process before outputting images. This approach transforms conventional text-to-image generation and editing into a reasoning-guided framework that analyzes semantic relationships and spatial arrangements. 30 | 31 | GoT pioneers a new direction for reasoning-driven visual generation and editing, producing images that better align with human intent through: 32 | 33 | - **Semantic-Spatial Reasoning**: Integrates both semantic understanding and explicit spatial coordinates 34 | - **Unified Framework**: Handles both image generation and editing with the same architecture 35 | 36 | ## Released Datasets 37 | 38 | | Dataset | Link | Amount | 39 | |---------|------|--------| 40 | | **Laion-Aesthetics-High-Resolution-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/Laion-Aesthetics-High-Resolution-GoT) | 3.77M | 41 | | **JourneyDB-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/JourneyDB-GoT) | 4.09M | 42 | | **OmniEdit-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/OmniEdit-GoT) | 736K | 43 | 44 | ## Dataset Features 45 | 46 | ### Laion-Aesthetics-High-Resolution-GoT 47 | - 3.77 million High-quality images filtered for sizes larger than 512 pixels from Laion-Aesthetics 48 | - Prompts and GoT descriptions from Qwen2-VL 49 | - Prompts averaging 110.81 characters 50 | - GoT descriptions averaging 811.56 characters 51 | - 3.78 bounding boxes per image on average 52 | 53 | ### JourneyDB-GoT 54 | - 4.09 million high-quality AI-generated images 55 | - Prompts and GoT descriptions from Qwen2-VL 56 | - Prompts averaging 149.78 characters 57 | - GoT descriptions averaging 906.01 characters 58 | - 4.09 bounding boxes per image on average 59 | - Please download the images from [JourneyDB 
dataset](https://opendatalab.com/OpenDataLab/JourneyDB/tree/main/raw/JourneyDB/train/imgs) 60 | 61 | ### OmniEdit-GoT 62 | - 736K high-quality image editing samples from OmniEdit 63 | - Diverse editing operations (addition, removal, swap, attribute changes, style transfer) 64 | - Detailed reasoning chains with step-by-step editing processes 65 | - Precise spatial coordinate annotations for editing regions 66 | - Please download the images from [OmniEdit dataset](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M) 67 | 68 | ## Released Model: GoT Framework 69 | 70 | | Model | Link | Architecture | 71 | |------------|------|----------------------| 72 | | **GoT-6B** | [🤗 HuggingFace](https://huggingface.co/LucasFang/GoT-6B) | Qwen2.5-VL-3B + SDXL | 73 | 74 | ## Model Features 75 | 76 |
77 | [Figure: GoT Architecture] 78 |
79 | 80 | Our GoT framework consists of two key components: 81 | 82 | 1. **Semantic-Spatial MLLM**: Generates detailed reasoning chains with spatial information using Qwen2.5-VL as the backbone 83 | 2. **SSGM Diffusion Module**: Leverages the semantic guidance, spatial layouts, and reference images to create high-quality visual outputs 84 | 85 | The Semantic-Spatial Guidance Module (SSGM) combines three guidance pathways: 86 | - **Semantic Guidance**: Captures relationships and attributes 87 | - **Spatial Guidance**: Controls precise object placement 88 | - **Reference Guidance**: Provides context for editing tasks 89 | 90 | ## Results 91 | 92 | ### Text-to-Image Generation 93 | 94 | GoT achieves the highest overall score on the GenEval benchmark among the compared methods, leading or tying on the single-object, counting, and color categories: 95 | 96 |
97 |
98 | | Method | Architecture | Overall | Single Obj. | Two Obj. | Counting | Colors | Position | Attr. Binding |
99 | |--------|--------------|---------|-------------|----------|----------|--------|----------|---------------|
100 | | SD-XL | UNet+CLIP | 0.55 | 0.98 | 0.74 | 0.39 | 0.85 | 0.15 | 0.23 |
101 | | SD3 | MMDiT+CLIP+T5 | 0.62 | 0.98 | 0.74 | 0.63 | 0.67 | 0.34 | 0.36 |
102 | | Emu3-Gen | Autoregressive | 0.54 | 0.98 | 0.71 | 0.34 | 0.81 | 0.17 | 0.21 |
103 | | Janus | Autoregressive | 0.61 | 0.97 | 0.68 | 0.30 | 0.84 | 0.46 | 0.42 |
104 | | JanusFlow | Autoregressive | 0.63 | 0.97 | 0.59 | 0.45 | 0.83 | 0.53 | 0.42 |
105 | | **GoT Framework** | UNet+Qwen2.5-VL | **0.64** | **0.99** | 0.69 | **0.67** | **0.85** | 0.34 | 0.27 |
106 |
107 |
108 | 109 | ### Image Editing 110 | 111 | Our approach also demonstrates superior performance on image editing benchmarks: 112 | 113 |
114 |
115 | | Method | Emu-Edit | | ImagenHub | Reason-Edit |
116 | |--------|----------|--------|-----------|------------|
117 | | | CLIP-I | CLIP-T | GPT-4o Eval. | GPT-4o Eval. |
118 | | IP2P | 0.834 | 0.219 | 0.308 | 0.286 |
119 | | MagicBrush | 0.838 | 0.222 | 0.513 | 0.334 |
120 | | SEED-X | 0.825 | 0.272 | 0.166 | 0.239 |
121 | | CosXL-Edit | 0.860 | 0.274 | 0.464 | 0.325 |
122 | | **GoT Framework** | **0.864** | **0.276** | **0.533** | **0.561** |
123 |
124 |
125 | 126 | ## Usage 127 | 128 | ### Dependencies 129 | - Python >= 3.8 ([Anaconda](https://www.anaconda.com/download/#linux) is recommended) 130 | - [PyTorch >= 2.0.1](https://pytorch.org/) 131 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 132 | 133 | ### Installation 134 | Clone the repo and install the dependencies: 135 | 136 | ```bash 137 | git clone git@github.com:rongyaofang/GoT.git 138 | cd GoT 139 | pip install -r requirements.txt 140 | ``` 141 | 142 | ### Model Weights 143 | Place the required model weights in the `./pretrained` directory as follows: 144 | 145 | 1. [GoT-6B](https://huggingface.co/LucasFang/GoT-6B) model weights 146 | 2. [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) 147 | 3. [Stable Diffusion XL Base 1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 148 | 149 | Your directory structure should match the following: 150 | 151 | ``` 152 | GoT 153 | ├── pretrained 154 | │ ├── GoT-6B 155 | │ ├── Qwen2.5-VL-3B-Instruct 156 | │ └── stable-diffusion-xl-base-1.0 157 | ├── ... 158 | ``` 159 | 160 | ### Inference 161 | Follow the instructions in the [inference notebook](https://github.com/rongyaofang/GoT/blob/main/inference.ipynb). 162 | 163 | ## License 164 | 165 | This code is released under the MIT License. 166 | 167 | ## Citation 168 | 169 | If you find this work helpful, please consider citing: 170 | 171 | ``` 172 | @article{fang2025got, 173 | title={GoT: Unleashing Reasoning Capability of Multimodal Large Language Model for Visual Generation and Editing}, 174 | author={Fang, Rongyao and Duan, Chengqi and Wang, Kun and Huang, Linjiang and Li, Hao and Yan, Shilin and Tian, Hao and Zeng, Xingyu and Zhao, Rui and Dai, Jifeng and Liu, Xihui and Li, Hongsheng}, 175 | journal={arXiv preprint arXiv:2503.10639}, 176 | year={2025} 177 | } 178 | ``` 179 | 180 | ## Contact 181 | 182 | If you have any questions, please open an issue or contact us at [rongyaofang@gmail.com](mailto:rongyaofang@gmail.com). 
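Before opening the notebook, it can help to confirm that `./pretrained` matches the layout shown under Model Weights. The snippet below is a minimal sketch and not part of the repository: the helper name `missing_weight_dirs` is hypothetical, and it only checks that the three directories exist, not that their contents are complete.

```python
from pathlib import Path

# Directory names taken from the layout shown under "Model Weights".
EXPECTED_DIRS = ("GoT-6B", "Qwen2.5-VL-3B-Instruct", "stable-diffusion-xl-base-1.0")


def missing_weight_dirs(root="pretrained"):
    """Return the expected sub-directories that are absent under `root`.

    Hypothetical helper for illustration; it does not validate file contents.
    """
    root_path = Path(root)
    return [name for name in EXPECTED_DIRS if not (root_path / name).is_dir()]


if __name__ == "__main__":
    missing = missing_weight_dirs()
    if missing:
        print("Missing under ./pretrained:", ", ".join(missing))
    else:
        print("All expected weight directories are present.")
```

Run it from the repository root; an empty result from `missing_weight_dirs()` means the three directories referenced by the configs are in place.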
183 | -------------------------------------------------------------------------------- /configs/clm_models/agent_got.yaml: -------------------------------------------------------------------------------- 1 | _target_: got.models.got_model.GenCot.from_pretrained 2 | output_projector: 3 | _target_: got.models.projector.LinearProjector 4 | in_hidden_size: 2048 5 | out_hidden_size: 2048 6 | 7 | output_projector_add: 8 | _target_: got.models.projector.LinearProjector 9 | in_hidden_size: 2048 10 | out_hidden_size: 1280 11 | 12 | scheduler: 13 | _target_: diffusers.DDPMScheduler.from_pretrained 14 | pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0 15 | subfolder: scheduler 16 | 17 | vae: 18 | _target_: diffusers.AutoencoderKL.from_pretrained 19 | pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0 20 | subfolder: vae 21 | 22 | unet: 23 | _target_: diffusers.UNet2DConditionModel.from_pretrained 24 | pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0 25 | subfolder: unet 26 | 27 | processor: 28 | _target_: got.processer.qwen25_vl_processor.get_processor 29 | model_name: pretrained/Qwen2.5-VL-3B-Instruct 30 | add_gen_token_num: 64 31 | 32 | num_img_out_tokens: 64 33 | img_gen_start_id: 151667 34 | -------------------------------------------------------------------------------- /configs/clm_models/llm_qwen25_vl_3b_lora.yaml: -------------------------------------------------------------------------------- 1 | _target_: got.models.peft_models.get_peft_model_without_resize_embedding 2 | model: 3 | _target_: transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained 4 | pretrained_model_name_or_path: pretrained/Qwen2.5-VL-3B-Instruct 5 | peft_config: 6 | _target_: peft.LoraConfig 7 | _convert_: object 8 | r: 32 9 | lora_alpha: 32 10 | lora_dropout: 0.05 11 | target_modules: 12 | - q_proj 13 | - v_proj 14 | - k_proj 15 | - o_proj 16 | - gate_proj 17 | - down_proj 18 | - up_proj 19 | modules_to_save: 20 | - 
embed_tokens 21 | - lm_head 22 | - input_layernorm 23 | - post_attention_layernorm 24 | task_type: CAUSAL_LM 25 | -------------------------------------------------------------------------------- /configs/tokenizer/qwen25_vl_tokenizer_token64.yaml: -------------------------------------------------------------------------------- 1 | _target_: got.processer.qwen25_vl_processor.get_processor 2 | model_name: pretrained/Qwen2.5-VL-3B-Instruct 3 | add_gen_token_num: 64 -------------------------------------------------------------------------------- /examples/hat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/examples/hat.jpg -------------------------------------------------------------------------------- /examples/strawberry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/examples/strawberry.jpg -------------------------------------------------------------------------------- /examples/tool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/examples/tool.png -------------------------------------------------------------------------------- /figures/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/figures/architecture.jpg -------------------------------------------------------------------------------- /figures/interactive.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/figures/interactive.jpg 
-------------------------------------------------------------------------------- /figures/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/figures/teaser.jpg -------------------------------------------------------------------------------- /got/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/got/__init__.py -------------------------------------------------------------------------------- /got/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/got/models/__init__.py -------------------------------------------------------------------------------- /got/models/got_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from PIL import Image, ImageDraw 5 | from torchvision import transforms 6 | from transformers import StoppingCriteriaList 7 | from diffusers.utils.import_utils import is_xformers_available 8 | from diffusers.utils.torch_utils import randn_tensor 9 | from tqdm import tqdm 10 | from .utils import ( 11 | IMG_TOKEN, BOI_TOKEN, EOI_TOKEN, EOS_TOKEN, BOV_TOKEN, EOV_TOKEN, IMG_PAD_TOKEN, 12 | parse_coordinates_colors, StopOnToken 13 | ) 14 | 15 | 16 | class GenCot(nn.Module): 17 | def __init__(self, mllm, output_projector, output_projector_add, scheduler, vae, unet, processor, 18 | num_img_out_tokens=64, img_gen_start_id=151667, box_start_id=151648, box_end_id=151649) -> None: 19 | super().__init__() 20 | self.mllm = mllm # qwen25-vl model 21 | self.output_projector = output_projector 22 | self.vae = vae 23 | self.unet = unet 24 | self.scheduler = scheduler 25 | 
self.output_projector_add = output_projector_add 26 | 27 | # uses an additional image for conditioning. 28 | # it uses 12 channels (instead of 4) in the first (conv) layer of the UNet. 29 | in_channels = 12 30 | self.unet.register_to_config(in_channels=in_channels) 31 | 32 | with torch.no_grad(): 33 | conv = torch.nn.Conv2d(in_channels, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, 34 | self.unet.conv_in.stride, self.unet.conv_in.padding) 35 | conv.weight.zero_() 36 | conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight) 37 | self.unet.conv_in = conv 38 | self.vae.requires_grad_(False) 39 | self.vae_batch = 1 40 | 41 | if is_xformers_available(): 42 | import xformers 43 | unet.enable_xformers_memory_efficient_attention() 44 | 45 | self.img_gen_start_id = img_gen_start_id 46 | self.num_img_out_tokens = num_img_out_tokens 47 | self.box_start_id = box_start_id 48 | self.box_end_id = box_end_id 49 | self.diffusion_transform = None 50 | self.source_transform = None 51 | self.processor = processor 52 | 53 | def _get_add_time_ids( 54 | self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None 55 | ): 56 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 57 | 58 | passed_add_embed_dim = ( 59 | self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim 60 | ) 61 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features 62 | 63 | if expected_add_embed_dim != passed_add_embed_dim: 64 | raise ValueError( 65 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
66 | ) 67 | 68 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 69 | return add_time_ids 70 | 71 | @torch.no_grad() 72 | def generate(self, 73 | text_input, 74 | image=None, 75 | max_new_tokens=1024, 76 | num_inference_steps=50, 77 | guidance_scale=7.5, 78 | image_guidance_scale=1.0, 79 | cond_image_guidance_scale=4.0, 80 | height=1024, 81 | width=1024, 82 | input_token_num=256, 83 | do_classifier_free_guidance=True, 84 | crops_coords_top_left=(0, 0), 85 | prompt_type='t2i', 86 | random_seed=42, 87 | got_input=None, 88 | only_return_got=False, 89 | **generate_kwargs 90 | ): 91 | """ 92 | Generate text and optional images from the model. 93 | 94 | Args: 95 | text_input (str): The input text prompt. 96 | image (PIL.Image.Image, optional): A single image for Qwen2.5-VL context or editing. 97 | max_new_tokens (int): Maximum number of tokens to generate. 98 | num_inference_steps (int): Diffusion steps for stable diffusion. 99 | guidance_scale (float): CFG scale for stable diffusion. 100 | image_guidance_scale (float): Image guidance scale for stable diffusion. 101 | cond_image_guidance_scale (float): Guidance scale for the box-layout condition image in stable diffusion. 102 | height (int): Height of the output image. 103 | width (int): Width of the output image. 104 | input_token_num (int): Number of image tokens in the input. 105 | do_classifier_free_guidance (bool): Whether to use classifier-free guidance during inference. 106 | crops_coords_top_left (Tuple[int, int]): The top-left coordinates of the crops. 107 | prompt_type (str): The prompt type to use, either 't2i' or 'edit'. 108 | random_seed (int): Seed for the torch.Generator used to sample the initial latents. 109 | got_input (str, optional): Custom GoT text used in place of the model-generated one. For interactive generation only. 110 | only_return_got (bool): Whether to return only the GoT text, without generating images. For interactive generation. 
111 | generate_kwargs: Additional kwargs for self.mllm.generate(). 112 | 113 | Returns: 114 | A dict with: 115 | 'got_text': str, the generated text. 
116 | 'images': List[PIL.Image.Image], the generated images if any (absent when only_return_got is True). 117 | """ 118 | device = next(self.parameters()).device 119 | vae_dtype = next(self.vae.parameters()).dtype 120 | 121 | if self.diffusion_transform is None: 122 | self.diffusion_transform = transforms.Compose([ 123 | transforms.Resize((height, width), interpolation=transforms.InterpolationMode.BICUBIC), 124 | transforms.ToTensor(), 125 | transforms.Normalize([0.5], [0.5]) 126 | ]) 127 | if self.source_transform is None: 128 | self.source_transform = transforms.Resize((448, 448), interpolation=transforms.InterpolationMode.BICUBIC) 129 | 130 | # Output image-generation token IDs 131 | img_token_ids = [self.processor.tokenizer.convert_tokens_to_ids(IMG_TOKEN.format(i)) for i in 132 | range(self.num_img_out_tokens)] 133 | img_token_ids = torch.tensor(img_token_ids, device=device).unsqueeze(0) # [1, num_img_out_tokens] 134 | 135 | # input image tokens 136 | input_token_ids = [self.processor.tokenizer.convert_tokens_to_ids(IMG_PAD_TOKEN) for _ in 137 | range(input_token_num)] 138 | input_token_ids = torch.tensor(input_token_ids, device=device).unsqueeze(0) # [1, input_token_num] 139 | 140 | # Convert BOI, EOS, and BOV tokens to IDs 141 | boi_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOI_TOKEN) 142 | eos_token_id = self.processor.tokenizer.convert_tokens_to_ids(EOS_TOKEN) 143 | bov_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOV_TOKEN) 144 | 145 | # Define stopping criteria to stop at BOI_TOKEN, BOV_TOKEN, or EOS_TOKEN 146 | stopping_criteria = StoppingCriteriaList([ 147 | StopOnToken(boi_token_id), StopOnToken(bov_token_id), StopOnToken(eos_token_id) 148 | ]) 149 | ori_w, ori_h = image.size if image is not None else (width, height) 150 | input_images = [self.source_transform(image)] if image is not None else [] 151 | original_images = [image] if image is not None else [] 152 | generated_images = [] 153 | output_text = '' 154 | 155 | if 
prompt_type == 't2i': 156 | prompt = f"Follow the caption to generate an image through a 
chain of thought process: {text_input}" 157 | elif prompt_type == 'edit': 158 | prompt = f"Follow the instruction to edit the given image through a chain of thought process: {text_input}" 159 | else: 160 | raise ValueError(f"Unknown prompt type {prompt_type}") 161 | 162 | # Prepare the conversation structure for Qwen2.5-VL 163 | messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] 164 | 165 | # If image is provided, add it to messages 166 | if image is not None: 167 | # Insert the image into the content 168 | messages[0]["content"].insert(0, {"type": "image"}) 169 | 170 | # Apply chat template to form the prompt as Qwen2.5-VL expects 171 | text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 172 | inputs = self.processor( 173 | text=[text], 174 | images=None if not input_images else input_images, 175 | padding=False, 176 | return_tensors="pt" 177 | ).to(device) 178 | input_ids = inputs.input_ids # shape: [1, seq_len] 179 | 180 | # if the last token is not EOS_TOKEN, continue generating 181 | while input_ids[0, -1] != eos_token_id: 182 | input_length = input_ids.shape[1] 183 | image_inputs = None if not input_images \ 184 | else self.processor.image_processor(images=input_images, return_tensors="pt").to(device) 185 | 186 | if got_input is None: 187 | partial_generation = self.mllm.generate( 188 | input_ids=input_ids, 189 | attention_mask=torch.ones_like(input_ids), 190 | pixel_values=image_inputs.pixel_values if hasattr(image_inputs, 'pixel_values') else None, 191 | image_grid_thw=image_inputs.image_grid_thw if hasattr(image_inputs, 'image_grid_thw') else None, 192 | max_new_tokens=max_new_tokens, 193 | return_dict_in_generate=True, 194 | output_hidden_states=False, # No need yet, we will do a second pass 195 | stopping_criteria=stopping_criteria, 196 | **generate_kwargs 197 | ) 198 | 199 | input_ids = partial_generation['sequences'] # shape: [1, seq_len] 200 | else: 201 | input_ids = 
self.processor.tokenizer.encode(got_input) 202 | input_ids = torch.tensor(input_ids).unsqueeze(0).to(device) 203 | got_input = None 204 | 205 | if only_return_got: 206 | return {"got_text": self.processor.tokenizer.decode(input_ids[0])} 207 | 208 | # Decode the newly generated text 209 | cur_decoded_text = self.processor.tokenizer.decode(input_ids[0, input_length:], skip_special_tokens=False) 210 | output_text += cur_decoded_text\ 211 | .replace(EOS_TOKEN, '').replace(EOI_TOKEN, '').replace(BOV_TOKEN, '').replace(EOV_TOKEN, '') 212 | 213 | # generate an image 214 | if input_ids[0, -1] == boi_token_id: 215 | input_ids = torch.cat([input_ids, img_token_ids], dim=1) # now includes BOI_TOKEN + image tokens 216 | 217 | second_out = self.mllm( 218 | input_ids=input_ids, 219 | attention_mask=torch.ones_like(input_ids), 220 | pixel_values=image_inputs.pixel_values if hasattr(image_inputs, 'pixel_values') else None, 221 | image_grid_thw=image_inputs.image_grid_thw if hasattr(image_inputs, 'image_grid_thw') else None, 222 | output_hidden_states=True, 223 | return_dict=True 224 | ) 225 | last_hidden_states = second_out['hidden_states'][-1] # [batch_size, seq_len, hidden_size] 226 | 227 | img_gen_mask = torch.logical_and( 228 | self.img_gen_start_id <= input_ids, input_ids < self.img_gen_start_id + self.num_img_out_tokens) 229 | 230 | gen_hidden_states = last_hidden_states[img_gen_mask].view(-1, self.num_img_out_tokens, 231 | last_hidden_states.shape[-1]) 232 | gen_hidden_states = gen_hidden_states[-1:] # keep only the last group of num_img_out_tokens image tokens 233 | gen_hidden_states = gen_hidden_states.to(self.output_projector.projector.weight.dtype) 234 | 235 | gen_conditioning = self.output_projector(gen_hidden_states) 236 | gen_conditioning_add = self.output_projector_add(gen_hidden_states) # [bz, gen_num, dim] 237 | null_conditioning = self.output_projector(torch.zeros_like(gen_hidden_states)) 238 | gen_conditioning_pooled = torch.mean(gen_conditioning_add, dim=1) 
239 | 240 | 
self.scheduler.set_timesteps(num_inference_steps, device=device) 241 | timesteps = self.scheduler.timesteps 242 | 243 | # Prepare stable diffusion latents 244 | generator = torch.Generator(device=device).manual_seed(random_seed) 245 | 246 | latents = randn_tensor( 247 | shape=(1, self.vae.config.latent_channels, height // 8, width // 8), 248 | generator=generator, 249 | device=device, 250 | dtype=vae_dtype 251 | ) 252 | latents = latents * self.scheduler.init_noise_sigma 253 | 254 | # Of the 12 UNet input channels, the first 4 are the noisy latents, the next 4 are the original image latents (for editing), and the last 4 are the box-layout latents. 255 | # In the text-to-image scenario, an all-black image is used as original_image. 256 | original_image = original_images[-1] if original_images \ 257 | else Image.new('RGB', (width, height), (0, 0, 0)) 258 | 259 | original_image_tensor = self.diffusion_transform(original_image).unsqueeze(0).to(device).to(vae_dtype) 260 | image_latents = self.vae.encode(original_image_tensor).latent_dist.mode() 261 | 262 | positions_colors = parse_coordinates_colors(cur_decoded_text) 263 | mask_num = max(len(positions_colors), 1) 264 | 265 | cond_images = [Image.new('RGB', (width, height), (0, 0, 0)) for _ in range(mask_num)] 266 | 267 | for i in range(len(positions_colors)): 268 | p_c = positions_colors[i] 269 | draw = ImageDraw.Draw(cond_images[i]) 270 | position = p_c['position'] 271 | color = p_c['color'] 272 | draw.rectangle(((position[0][0] / 1000 * width, position[0][1] / 1000 * height), 273 | (position[1][0] / 1000 * width, position[1][1] / 1000 * height)), fill=color) 274 | del draw 275 | 276 | cond_images_tensor = [] 277 | for c_image in cond_images: 278 | c_image_tensor = self.diffusion_transform(c_image) 279 | cond_images_tensor.append(c_image_tensor) 280 | 281 | # (1, mask_num, 3, height, width) 282 | cond_mask = torch.stack(cond_images_tensor, dim=0).unsqueeze(0) 283 | B, N, C, H, W = cond_mask.shape 284 | cond_mask = 
cond_mask.view(B * N, C, H, W) 285 | 286 | unet_cond_embeds = [] 287 
| for i in range(0, cond_mask.shape[0], self.vae_batch): 288 | sub_batch = cond_mask[i: i + self.vae_batch] 289 | embeds = self.vae.encode(sub_batch.to(device, dtype=vae_dtype)).latent_dist.mode() 290 | embeds = embeds.to(device) 291 | unet_cond_embeds.append(embeds) 292 | unet_cond_embeds = torch.cat(unet_cond_embeds, dim=0) 293 | unet_cond_embed = unet_cond_embeds.mean(dim=0, keepdim=True) 294 | 295 | if do_classifier_free_guidance: 296 | uncond_image_latents = torch.zeros_like(image_latents) 297 | image_latents = torch.cat([image_latents, image_latents, image_latents, uncond_image_latents], 298 | dim=0) 299 | 300 | uncond_cond_image_latents = torch.zeros_like(unet_cond_embed) 301 | unet_cond_embed = torch.cat([unet_cond_embed, uncond_cond_image_latents, 302 | uncond_cond_image_latents, uncond_cond_image_latents], dim=0) 303 | 304 | combined_prompt_embeds = torch.cat( 305 | [gen_conditioning, gen_conditioning, null_conditioning, null_conditioning], 306 | dim=0) if do_classifier_free_guidance else gen_conditioning 307 | 308 | text_encoder_projection_dim = int(gen_conditioning_pooled.shape[-1]) 309 | 310 | original_size = (height, width) 311 | target_size = (height, width) 312 | 313 | add_time_ids = self._get_add_time_ids( 314 | original_size, 315 | crops_coords_top_left, 316 | target_size, 317 | dtype=combined_prompt_embeds.dtype, 318 | text_encoder_projection_dim=text_encoder_projection_dim, 319 | ) 320 | 321 | added_cond_kwargs = {"text_embeds": gen_conditioning_pooled.to(device), 322 | "time_ids": add_time_ids.to(device)} 323 | 324 | for i, t in enumerate(tqdm(timesteps)): 325 | latent_model_input = torch.cat([latents] * 4) if do_classifier_free_guidance else latents 326 | scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 327 | scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents, unet_cond_embed], 328 | dim=1) 329 | 330 | noise_pred = self.unet( 331 | scaled_latent_model_input, 332 | t, 333 | 
encoder_hidden_states=combined_prompt_embeds, 334 | added_cond_kwargs=added_cond_kwargs, 335 | return_dict=False 336 | )[0] 337 | 338 | if do_classifier_free_guidance: 339 | noise_pred_cond, noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(4, 340 | dim=0) 341 | noise_pred = ( 342 | noise_pred_uncond 343 | + guidance_scale * (noise_pred_text - noise_pred_image) 344 | + cond_image_guidance_scale * (noise_pred_cond - noise_pred_text) 345 | + image_guidance_scale * (noise_pred_image - noise_pred_uncond) 346 | ) 347 | 348 | # step through scheduler 349 | latents = self.scheduler.step(noise_pred, t, latents, generator=generator, return_dict=False)[0] 350 | 351 | final_latents = latents / self.vae.config.scaling_factor 352 | image_tensor = self.vae.decode(final_latents, generator=generator).sample 353 | image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1) 354 | pil_image = Image.fromarray( 355 | (image_tensor[0].permute(1, 2, 0).cpu().float().numpy() * 255).astype("uint8")) 356 | 357 | generated_images.append(pil_image) 358 | original_images.append(pil_image) 359 | elif input_ids[0, -1] == bov_token_id: 360 | input_images.append(self.source_transform(generated_images[-1])) 361 | input_ids = torch.cat([input_ids, input_token_ids], dim=1) 362 | 363 | # resize generated images to the aspect ratio of the input (ori_w, ori_h), keeping the shorter side at the generation size (1024 by default) 364 | if ori_w < ori_h: 365 | target_size = (width, int(height * ori_h / ori_w)) 366 | else: 367 | target_size = (int(width * ori_w / ori_h), height) 368 | generated_images = [img.resize(target_size) for img in generated_images] 369 | 370 | return {"got_text": output_text, "images": generated_images} 371 | 372 | @classmethod 373 | def from_pretrained(cls, mllm, output_projector, scheduler, vae, unet, pretrained_model_path=None, **kwargs): 374 | model = cls(mllm=mllm, output_projector=output_projector, scheduler=scheduler, vae=vae, unet=unet, **kwargs) 375 | if 
os.environ.get('DEBUG_FLAG', 'False') == 'True': 376 | return 
model 377 | 378 | if pretrained_model_path is not None: 379 | ckpt = torch.load(pretrained_model_path, map_location='cpu') 380 | logs = model.load_state_dict(ckpt, strict=False) 381 | print(logs) 382 | return model 383 | -------------------------------------------------------------------------------- /got/models/peft_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import DictConfig 3 | import hydra 4 | from peft import ( 5 | LoraConfig, 6 | PeftModel, 7 | LoraModel, 8 | PeftModelForCausalLM, 9 | get_peft_model, 10 | ) 11 | 12 | 13 | def get_peft_model_without_resize_embedding(model, peft_config=None, torch_dtype='bf16'): 14 | if torch_dtype == 'bf16' or torch_dtype == 'bfloat16': 15 | torch_dtype = torch.bfloat16 16 | elif torch_dtype == 'fp16' or torch_dtype == 'float16': 17 | torch_dtype = torch.float16 18 | else: 19 | torch_dtype = torch.float32 20 | 21 | if isinstance(model, DictConfig): 22 | model = hydra.utils.instantiate(model, torch_dtype=torch_dtype) 23 | 24 | print('peft config: ', peft_config) 25 | if isinstance(peft_config, DictConfig): 26 | peft_config = hydra.utils.instantiate(peft_config) 27 | peft_model = get_peft_model(model=model, peft_config=peft_config) 28 | 29 | # peft_model.print_trainable_parameters() 30 | 31 | return peft_model 32 | -------------------------------------------------------------------------------- /got/models/projector.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class LinearProjector(nn.Module): 5 | def __init__(self, in_hidden_size, out_hidden_size, bias=True): 6 | super().__init__() 7 | self.projector = nn.Linear(in_hidden_size, out_hidden_size, bias=bias) 8 | 9 | def forward(self, feature): 10 | return self.projector(feature) 11 | -------------------------------------------------------------------------------- /got/models/utils.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from transformers import StoppingCriteria 4 | 5 | 6 | BOI_TOKEN = '<|im_gen_start|>' 7 | EOI_TOKEN = '<|im_gen_end|>' 8 | IMG_TOKEN = '<|im_gen_{:04d}|>' 9 | EOS_TOKEN = '<|endoftext|>' 10 | BOV_TOKEN = '<|vision_start|>' 11 | EOV_TOKEN = '<|vision_end|>' 12 | IMG_PAD_TOKEN = '<|image_pad|>' 13 | 14 | 15 | def remove_mismatched_weights(model, pretrained_state_dict): 16 | own_state = model.state_dict() 17 | mismatch_keys = [] 18 | 19 | for name in list(pretrained_state_dict.keys()): 20 | if name not in own_state or own_state[name].shape != pretrained_state_dict[name].shape: 21 | mismatch_keys.append(name) 22 | pretrained_state_dict.pop(name) 23 | 24 | return pretrained_state_dict, mismatch_keys 25 | 26 | 27 | def parse_coordinates_colors(cot_text): 28 | """ 29 | Parse bounding box coordinates and their colors from the CoT text. 30 | 31 | Args: 32 | cot_text (str): Chain of Thought text containing bounding box information. 33 | 34 | Returns: 35 | list: A list of dictionaries with keys 'position' ([[x1, y1], [x2, y2]]) and 'color'. 
36 | """ 37 | # Regular expression to match bounding box and color patterns 38 | pattern = r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|> \((\w+)\)" 39 | 40 | # Parse all matches 41 | matches = re.findall(pattern, cot_text) 42 | 43 | # Extract bounding box coordinates and colors 44 | parsed_data = [] 45 | for match in matches: 46 | x1, y1, x2, y2, color = match 47 | parsed_data.append({ 48 | 'position': [[int(x1), int(y1)], [int(x2), int(y2)]], 49 | 'color': color 50 | }) 51 | 52 | return parsed_data 53 | 54 | 55 | class StopOnToken(StoppingCriteria): 56 | def __init__(self, token_id): 57 | self.token_id = token_id 58 | 59 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 60 | # Stop when the last token of the first sequence equals the configured token_id 61 | return input_ids[0, -1] == self.token_id 62 | -------------------------------------------------------------------------------- /got/processer/qwen25_vl_processor.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoProcessor 2 | 3 | 4 | BOI_TOKEN = '<|im_gen_start|>' 5 | EOI_TOKEN = '<|im_gen_end|>' 6 | IMG_TOKEN = '<|im_gen_{:04d}|>' 7 | 8 | 9 | def get_processor(model_name, add_gen_token_num=64): 10 | processor = AutoProcessor.from_pretrained(model_name) 11 | add_token_list = [BOI_TOKEN, EOI_TOKEN] 12 | for i in range(add_gen_token_num): 13 | add_token_list.append(IMG_TOKEN.format(i)) 14 | processor.tokenizer.add_tokens(add_token_list, special_tokens=True) 15 | return processor 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchvision==0.15.2 3 | hydra-core 4 | omegaconf 5 | transformers==4.49.0 6 | diffusers==0.29.0 7 | sentencepiece 8 | opencv-python 9 | peft==0.13.2 10 | pyrootutils 11 | xformers==0.0.22 12 | accelerate==1.3.0 13 | 
transformers_stream_generator 14 | tqdm 15 | notebook 16 | numpy==1.21.2 17 | huggingface_hub==0.29.3 --------------------------------------------------------------------------------
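The following is a minimal sketch of how `parse_coordinates_colors` from `got/models/utils.py` behaves on a small input. The function is re-declared here (with the same regular expression) only so the snippet runs without the repository installed; the sample text is hypothetical and was written for illustration, not taken from model output.

```python
import re

# Same pattern as parse_coordinates_colors in got/models/utils.py.
PATTERN = r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|> \((\w+)\)"


def parse_coordinates_colors(cot_text):
    """Return a list of {'position': [[x1, y1], [x2, y2]], 'color': str} dicts."""
    return [
        {"position": [[int(x1), int(y1)], [int(x2), int(y2)]], "color": color}
        for x1, y1, x2, y2, color in re.findall(PATTERN, cot_text)
    ]


# Hypothetical GoT-style text; the wording is invented for illustration.
sample = (
    "a red hat <|box_start|>(100,200),(300,400)<|box_end|> (red) next to "
    "a strawberry <|box_start|>(10,20),(30,40)<|box_end|> (green)"
)
boxes = parse_coordinates_colors(sample)
print(boxes)
```

Each match yields one dictionary, so text with no box annotations produces an empty list. Note that the pattern requires exactly one space between `<|box_end|>` and the parenthesised color.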