├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── clm_models
│   │   ├── agent_got.yaml
│   │   └── llm_qwen25_vl_3b_lora.yaml
│   └── tokenizer
│       └── qwen25_vl_tokenizer_token64.yaml
├── examples
│   ├── hat.jpg
│   ├── strawberry.jpg
│   └── tool.png
├── figures
│   ├── architecture.jpg
│   ├── interactive.jpg
│   └── teaser.jpg
├── got
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── got_model.py
│   │   ├── peft_models.py
│   │   ├── projector.py
│   │   └── utils.py
│   └── processer
│       └── qwen25_vl_processor.py
├── inference.ipynb
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | pretrained/
2 | .DS_Store
3 | .idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Rongyao Fang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GoT: Unleashing Reasoning Capability of Multimodal Large Language Model for Visual Generation and Editing
2 |
3 | [Rongyao Fang](https://scholar.google.com/citations?user=FtH3CW4AAAAJ&hl=en)<sup>1\*</sup>, [Chengqi Duan](https://scholar.google.com/citations?user=r9qb4ZwAAAAJ&hl=zh-CN)<sup>2\*</sup>, [Kun Wang]()<sup>3</sup>, [Linjiang Huang](https://leonhlj.github.io/)<sup>6</sup>, [Hao Li](https://scholar.google.com/citations?user=qHqQsY4AAAAJ&hl=zh-CN)<sup>1,4</sup>, [Shilin Yan](https://scholar.google.com/citations?user=2VhjOykAAAAJ&hl=zh-CN), [Hao Tian]()<sup>3</sup>, [Xingyu Zeng]()<sup>3</sup>, [Rui Zhao]()<sup>3</sup>, [Jifeng Dai](https://jifengdai.org/)<sup>4,5</sup>, [Xihui Liu](https://xh-liu.github.io/)<sup>2</sup> :envelope:, [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/)<sup>1</sup> :envelope:
4 |
5 | <sup>1</sup>CUHK MMLab, <sup>2</sup>HKU MMLab, <sup>3</sup>SenseTime, <sup>4</sup>Shanghai AI Laboratory, <sup>5</sup>Tsinghua University, <sup>6</sup>Beihang University
6 |
7 | *Equal contribution, :envelope:Corresponding authors
8 |
9 | ![teaser](figures/teaser.jpg)
10 |
27 | ## Introduction
28 |
29 | We present **Generation Chain-of-Thought (GoT)**, a novel paradigm that enables generation and editing through an explicit language reasoning process before outputting images. This approach transforms conventional text-to-image generation and editing into a reasoning-guided framework that analyzes semantic relationships and spatial arrangements.
30 |
31 | GoT pioneers a new direction for reasoning-driven visual generation and editing, producing images that better align with human intent through:
32 |
33 | - **Semantic-Spatial Reasoning**: Integrates both semantic understanding and explicit spatial coordinates
34 | - **Unified Framework**: Handles both image generation and editing with the same architecture
35 |
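In practice, a GoT chain is plain text in which each grounded object is annotated with a bounding box on a 0 to 1000 coordinate scale and a color, using Qwen2.5-VL-style `<|box_start|>`/`<|box_end|>` tokens. Below is a minimal sketch with a made-up chain, parsed by the repository's own utility:

```python
from got.models.utils import parse_coordinates_colors

# Made-up GoT chain; the <|box_start|>(x1,y1),(x2,y2)<|box_end|> (color) notation and the
# 0-1000 coordinate scale follow got/models/utils.py and got/models/got_model.py.
got_chain = (
    "A strawberry wearing a straw hat on a wooden table. "
    "The strawberry <|box_start|>(312,420),(688,905)<|box_end|> (red) sits at the center, "
    "with the straw hat <|box_start|>(330,180),(670,470)<|box_end|> (yellow) resting on top."
)

print(parse_coordinates_colors(got_chain))
# [{'position': [[312, 420], [688, 905]], 'color': 'red'},
#  {'position': [[330, 180], [670, 470]], 'color': 'yellow'}]
```

During generation, these boxes and colors are rendered into layout masks that guide the diffusion module (see `got/models/got_model.py`).
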
36 | ## Released Datasets
37 |
38 | | Dataset | Link | Amount |
39 | |---------|------|--------|
40 | | **Laion-Aesthetics-High-Resolution-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/Laion-Aesthetics-High-Resolution-GoT) | 3.77M |
41 | | **JourneyDB-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/JourneyDB-GoT) | 4.09M |
42 | | **OmniEdit-GoT** | [🤗 HuggingFace](https://huggingface.co/datasets/LucasFang/OmniEdit-GoT) | 736K |
43 |
44 | ## Dataset Features
45 |
46 | ### Laion-Aesthetics-High-Resolution-GoT
47 | - 3.77 million high-quality images from Laion-Aesthetics, filtered to keep only images larger than 512 pixels
48 | - Prompts and GoT descriptions from Qwen2-VL
49 | - Prompts averaging 110.81 characters
50 | - GoT descriptions averaging 811.56 characters
51 | - 3.78 bounding boxes per image on average
52 |
53 | ### JourneyDB-GoT
54 | - 4.09 million high-quality AI-generated images
55 | - Prompts and GoT descriptions from Qwen2-VL
56 | - Prompts averaging 149.78 characters
57 | - GoT descriptions averaging 906.01 characters
58 | - 4.09 bounding boxes per image on average
59 | - Please download the images from [JourneyDB dataset](https://opendatalab.com/OpenDataLab/JourneyDB/tree/main/raw/JourneyDB/train/imgs)
60 |
61 | ### OmniEdit-GoT
62 | - 736K high-quality image editing samples from OmniEdit
63 | - Diverse editing operations (addition, removal, swap, attribute changes, style transfer)
64 | - Detailed reasoning chains with step-by-step editing processes
65 | - Precise spatial coordinate annotations for editing regions
66 | - Please download the images from [OmniEdit dataset](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M)
67 |
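All three releases can be pulled directly from the Hugging Face Hub with the `datasets` library. A minimal sketch (split and field names may vary by dataset; check each dataset card for the exact schema):

```python
from datasets import load_dataset

# Stream JourneyDB-GoT so that the 4.09M records are not downloaded up front.
ds = load_dataset("LucasFang/JourneyDB-GoT", split="train", streaming=True)

sample = next(iter(ds))
print(sample.keys())  # inspect the schema: prompt, GoT description, bounding boxes, ...
```
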
68 | ## Released Model: GoT Framework
69 |
70 | | Model | Link | Architecture |
71 | |------------|------|----------------------|
72 | | **GoT-6B** | [🤗 HuggingFace](https://huggingface.co/LucasFang/GoT-6B) | Qwen2.5-VL-3B + SDXL |
73 |
74 | ## Model Features
75 |
76 |
77 | ![architecture](figures/architecture.jpg)
78 |
80 | Our GoT framework consists of two key components:
81 |
82 | 1. **Semantic-Spatial MLLM**: Generates detailed reasoning chains with spatial information using Qwen2.5-VL as the backbone
83 | 2. **SSGM Diffusion Module**: Leverages the semantic guidance, spatial layouts, and reference images to create high-quality visual outputs
84 |
85 | The Semantic-Spatial Guidance Module (SSGM) combines three guidance pathways, which are fused during denoising as sketched below:
86 | - **Semantic Guidance**: Captures relationships and attributes
87 | - **Spatial Guidance**: Controls precise object placement
88 | - **Reference Guidance**: Provides context for editing tasks
89 |
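In `GenCot.generate` (`got/models/got_model.py`), the UNet is evaluated on four conditioning combinations (full conditioning, semantic plus reference image, reference image only, and unconditional), and the predictions are fused with a four-way classifier-free guidance:

```python
def combine_guidance(noise_pred_cond, noise_pred_text, noise_pred_image, noise_pred_uncond,
                     guidance_scale=7.5, image_guidance_scale=1.0, cond_image_guidance_scale=4.0):
    """Fuse the four UNet noise predictions the same way GenCot.generate does."""
    return (
        noise_pred_uncond
        + guidance_scale * (noise_pred_text - noise_pred_image)            # semantic (GoT text) guidance
        + cond_image_guidance_scale * (noise_pred_cond - noise_pred_text)  # spatial-layout guidance
        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)    # reference-image guidance
    )
```

The three scales therefore control the strength of the semantic, spatial, and reference signals independently.
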
90 | ## Results
91 |
92 | ### Text-to-Image Generation
93 |
94 | GoT achieves state-of-the-art performance on the GenEval benchmark, particularly excelling in composition tasks:
95 |
96 |
97 |
98 | | Method | Architecture | Overall | Single Obj. | Two Obj. | Counting | Colors | Position | Attr. Binding |
99 | |--------|--------------|---------|-------------|----------|----------|--------|----------|---------------|
100 | | SD-XL | Unet+CLIP | 0.55 | 0.98 | 0.74 | 0.39 | 0.85 | 0.15 | 0.23 |
101 | | SD3 | MMDIT+CLIP+T5 | 0.62 | 0.98 | 0.74 | 0.63 | 0.67 | 0.34 | 0.36 |
102 | | Emu3-Gen | Autoregressive | 0.54 | 0.98 | 0.71 | 0.34 | 0.81 | 0.17 | 0.21 |
103 | | Janus | Autoregressive | 0.61 | 0.97 | 0.68 | 0.30 | 0.84 | 0.46 | 0.42 |
104 | | JanusFlow | Autoregressive | 0.63 | 0.97 | 0.59 | 0.45 | 0.83 | 0.53 | 0.42 |
105 | | **GoT Framework** | Unet+Qwen2.5-VL | **0.64** | **0.99** | 0.69 | **0.67** | **0.85** | 0.34 | 0.27 |
106 |
107 |
108 |
109 | ### Image Editing
110 |
111 | Our approach also demonstrates superior performance on image editing benchmarks:
112 |
113 |
114 |
115 | | Method | Emu-Edit CLIP-I | Emu-Edit CLIP-T | ImagenHub GPT-4o Eval. | Reason-Edit GPT-4o Eval. |
116 | |--------|-----------------|-----------------|------------------------|--------------------------|
118 | | IP2P | 0.834 | 0.219 | 0.308 | 0.286 |
119 | | MagicBrush | 0.838 | 0.222 | 0.513 | 0.334 |
120 | | SEED-X | 0.825 | 0.272 | 0.166 | 0.239 |
121 | | CosXL-Edit | 0.860 | 0.274 | 0.464 | 0.325 |
122 | | **GoT Framework** | **0.864** | **0.276** | **0.533** | 0.561 |
123 |
124 |
125 |
126 | ## Usage
127 |
128 | ### Dependencies
129 | - Python >= 3.8 (we recommend [Anaconda](https://www.anaconda.com/download/#linux))
130 | - [PyTorch >=2.0.1](https://pytorch.org/)
131 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
132 |
133 | ### Installation
134 | Clone the repo and install the required packages:
135 |
136 | ```bash
137 | git clone git@github.com:rongyaofang/GoT.git
138 | cd GoT
139 | pip install -r requirements.txt
140 | ```
141 |
142 | ### Model Weights
143 | Place the required model weights in the `./pretrained` directory as follows:
144 |
145 | 1. GoT-6B model weights
146 | 2. [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
147 | 3. [Stable Diffusion XL Base 1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
148 |
149 | Your directory structure should match the following:
150 |
151 | ```
152 | GoT
153 | ├── pretrained
154 | │ ├── GoT-6B
155 | │ ├── Qwen2.5-VL-3B-Instruct
156 | │ └── stable-diffusion-xl-base-1.0
157 | ├── ...
158 | ```
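
One way to fetch the released checkpoints is with `huggingface_hub`, which is already listed in `requirements.txt`:

```python
from huggingface_hub import snapshot_download

# Download the three checkpoints into ./pretrained, matching the layout above.
snapshot_download("LucasFang/GoT-6B", local_dir="pretrained/GoT-6B")
snapshot_download("Qwen/Qwen2.5-VL-3B-Instruct", local_dir="pretrained/Qwen2.5-VL-3B-Instruct")
snapshot_download("stabilityai/stable-diffusion-xl-base-1.0", local_dir="pretrained/stable-diffusion-xl-base-1.0")
```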
159 |
160 | ### Inference
161 | Follow the instructions in the [inference notebook](https://github.com/rongyaofang/GoT/blob/main/inference.ipynb).
162 |
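For orientation, here is a rough sketch of how the pieces fit together outside the notebook; the exact loading steps and the GoT-6B checkpoint filename may differ, so treat the paths below as placeholders:

```python
import torch
import hydra
from omegaconf import OmegaConf

# Instantiate the LoRA-wrapped Qwen2.5-VL backbone, then the GoT agent around it.
# The GoT-6B checkpoint filename below is a placeholder; use the file shipped in pretrained/GoT-6B.
llm = hydra.utils.instantiate(OmegaConf.load("configs/clm_models/llm_qwen25_vl_3b_lora.yaml"))
agent = hydra.utils.instantiate(
    OmegaConf.load("configs/clm_models/agent_got.yaml"),
    mllm=llm,
    pretrained_model_path="pretrained/GoT-6B/<checkpoint-file>",
)
agent = agent.to("cuda", dtype=torch.bfloat16).eval()

# Text-to-image generation: the model first writes the GoT reasoning chain, then renders the image.
out = agent.generate(text_input="A strawberry wearing a tiny straw hat", prompt_type="t2i")
print(out["got_text"])
out["images"][0].save("strawberry_hat.png")
```
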
163 | ## License
164 |
165 | This code is released under the MIT License.
166 |
167 | ## Citation
168 |
169 | If you find this work helpful, please consider citing:
170 |
171 | ```
172 | @article{fang2025got,
173 | title={GoT: Unleashing Reasoning Capability of Multimodal Large Language Model for Visual Generation and Editing},
174 | author={Fang, Rongyao and Duan, Chengqi and Wang, Kun and Huang, Linjiang and Li, Hao and Yan, Shilin and Tian, Hao and Zeng, Xingyu and Zhao, Rui and Dai, Jifeng and Liu, Xihui and Li, Hongsheng},
175 | journal={arXiv preprint arXiv:2503.10639},
176 | year={2025}
177 | }
178 | ```
179 |
180 | ## Contact
181 |
182 | If you have any questions, please raise an issue or contact us at [rongyaofang@gmail.com](mailto:rongyaofang@gmail.com).
183 |
--------------------------------------------------------------------------------
/configs/clm_models/agent_got.yaml:
--------------------------------------------------------------------------------
1 | _target_: got.models.got_model.GenCot.from_pretrained
2 | output_projector:
3 | _target_: got.models.projector.LinearProjector
4 | in_hidden_size: 2048
5 | out_hidden_size: 2048
6 |
7 | output_projector_add:
8 | _target_: got.models.projector.LinearProjector
9 | in_hidden_size: 2048
10 | out_hidden_size: 1280
11 |
12 | scheduler:
13 | _target_: diffusers.DDPMScheduler.from_pretrained
14 | pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0
15 | subfolder: scheduler
16 |
17 | vae:
18 | _target_: diffusers.AutoencoderKL.from_pretrained
19 | pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0
20 | subfolder: vae
21 |
22 | unet:
23 | _target_: diffusers.UNet2DConditionModel.from_pretrained
24 | pretrained_model_name_or_path: pretrained/stable-diffusion-xl-base-1.0
25 | subfolder: unet
26 |
27 | processor:
28 | _target_: got.processer.qwen25_vl_processor.get_processor
29 | model_name: pretrained/Qwen2.5-VL-3B-Instruct
30 | add_gen_token_num: 64
31 |
32 | num_img_out_tokens: 64
33 | img_gen_start_id: 151667  # token id of <|im_gen_0000|>, the first image-generation token added by the processor
34 |
--------------------------------------------------------------------------------
/configs/clm_models/llm_qwen25_vl_3b_lora.yaml:
--------------------------------------------------------------------------------
1 | _target_: got.models.peft_models.get_peft_model_without_resize_embedding
2 | model:
3 | _target_: transformers.Qwen2_5_VLForConditionalGeneration.from_pretrained
4 | pretrained_model_name_or_path: pretrained/Qwen2.5-VL-3B-Instruct
5 | peft_config:
6 | _target_: peft.LoraConfig
7 | _convert_: object
8 | r: 32
9 | lora_alpha: 32
10 | lora_dropout: 0.05
11 | target_modules:
12 | - q_proj
13 | - v_proj
14 | - k_proj
15 | - o_proj
16 | - gate_proj
17 | - down_proj
18 | - up_proj
19 | modules_to_save:
20 | - embed_tokens
21 | - lm_head
22 | - input_layernorm
23 | - post_attention_layernorm
24 | task_type: CAUSAL_LM
25 |
--------------------------------------------------------------------------------
/configs/tokenizer/qwen25_vl_tokenizer_token64.yaml:
--------------------------------------------------------------------------------
1 | _target_: got.processer.qwen25_vl_processor.get_processor
2 | model_name: pretrained/Qwen2.5-VL-3B-Instruct
3 | add_gen_token_num: 64
--------------------------------------------------------------------------------
/examples/hat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/examples/hat.jpg
--------------------------------------------------------------------------------
/examples/strawberry.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/examples/strawberry.jpg
--------------------------------------------------------------------------------
/examples/tool.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/examples/tool.png
--------------------------------------------------------------------------------
/figures/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/figures/architecture.jpg
--------------------------------------------------------------------------------
/figures/interactive.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/figures/interactive.jpg
--------------------------------------------------------------------------------
/figures/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/figures/teaser.jpg
--------------------------------------------------------------------------------
/got/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/got/__init__.py
--------------------------------------------------------------------------------
/got/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rongyaofang/GoT/7ded721eebf393c8855025c70fb253c282f17d00/got/models/__init__.py
--------------------------------------------------------------------------------
/got/models/got_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from PIL import Image, ImageDraw
5 | from torchvision import transforms
6 | from transformers import StoppingCriteriaList
7 | from diffusers.utils.import_utils import is_xformers_available
8 | from diffusers.utils.torch_utils import randn_tensor
9 | from tqdm import tqdm
10 | from .utils import (
11 | IMG_TOKEN, BOI_TOKEN, EOI_TOKEN, EOS_TOKEN, BOV_TOKEN, EOV_TOKEN, IMG_PAD_TOKEN,
12 | parse_coordinates_colors, StopOnToken
13 | )
14 |
15 |
16 | class GenCot(nn.Module):
17 | def __init__(self, mllm, output_projector, output_projector_add, scheduler, vae, unet, processor,
18 | num_img_out_tokens=64, img_gen_start_id=151667, box_start_id=151648, box_end_id=151649) -> None:
19 | super().__init__()
20 | self.mllm = mllm # qwen25-vl model
21 | self.output_projector = output_projector
22 | self.vae = vae
23 | self.unet = unet
24 | self.scheduler = scheduler
25 | self.output_projector_add = output_projector_add
26 |
27 | # The UNet is conditioned on additional image latents:
28 | # its first conv layer takes 12 channels (4 noisy + 4 reference-image + 4 spatial-mask latents) instead of 4.
29 | in_channels = 12
30 | self.unet.register_to_config(in_channels=in_channels)
31 |
32 | with torch.no_grad():
33 | conv = torch.nn.Conv2d(in_channels, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size,
34 | self.unet.conv_in.stride, self.unet.conv_in.padding)
35 | conv.weight.zero_()
36 | conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
37 | self.unet.conv_in = conv
38 | self.vae.requires_grad_(False)
39 | self.vae_batch = 1
40 |
41 | if is_xformers_available():
42 | import xformers
43 | unet.enable_xformers_memory_efficient_attention()
44 |
45 | self.img_gen_start_id = img_gen_start_id
46 | self.num_img_out_tokens = num_img_out_tokens
47 | self.box_start_id = box_start_id
48 | self.box_end_id = box_end_id
49 | self.diffusion_transform = None
50 | self.source_transform = None
51 | self.processor = processor
52 |
53 | def _get_add_time_ids(
54 | self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
55 | ):
56 | add_time_ids = list(original_size + crops_coords_top_left + target_size)
57 |
58 | passed_add_embed_dim = (
59 | self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
60 | )
61 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
62 |
63 | if expected_add_embed_dim != passed_add_embed_dim:
64 | raise ValueError(
65 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
66 | )
67 |
68 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
69 | return add_time_ids
70 |
71 | @torch.no_grad()
72 | def generate(self,
73 | text_input,
74 | image=None,
75 | max_new_tokens=1024,
76 | num_inference_steps=50,
77 | guidance_scale=7.5,
78 | image_guidance_scale=1.0,
79 | cond_image_guidance_scale=4.0,
80 | height=1024,
81 | width=1024,
82 | input_token_num=256,
83 | do_classifier_free_guidance=True,
84 | crops_coords_top_left=(0, 0),
85 | prompt_type='t2i',
86 | random_seed=42,
87 | got_input=None,
88 | only_return_got=False,
89 | **generate_kwargs
90 | ):
91 | """
92 | Generate text and optional images from the model.
93 |
94 | Args:
95 | text_input (str): The input text prompt.
96 | image (PIL.Image.Image, optional): A single image for Qwen2.5-VL context or editing.
97 | max_new_tokens (int): Maximum number of tokens to generate.
98 | num_inference_steps (int): Diffusion steps for stable diffusion.
99 | guidance_scale (float): CFG scale for stable diffusion.
100 | image_guidance_scale (float): Image guidance scale for stable diffusion.
101 | cond_image_guidance_scale (float): Conditional image guidance scale for stable diffusion.
102 | height (int): Height of the output image.
103 | width (int): Width of the output image.
104 | input_token_num (int): Number of image tokens in the input.
105 | do_classifier_free_guidance (bool): Whether to use classifier-free guidance during inference.
106 | crops_coords_top_left (Tuple[int, int]): The top-left coordinates of the crops.
107 | prompt_type (str): The prompt type, 't2i' for text-to-image generation or 'edit' for image editing.
108 | random_seed (int): Random seed for the diffusion latents.
109 | got_input (str): Custom GoT text to use instead of generating it. For interactive generation only.
110 | only_return_got (bool): Whether to return only the GoT text (for interactive generation).
111 | generate_kwargs: Additional kwargs for self.mllm.generate().
112 |
113 | Returns:
114 | A dict with:
115 | 'got_text': str, the generated GoT reasoning text.
116 | 'images': List[PIL.Image.Image], the generated images if any.
117 | """
118 | device = next(self.parameters()).device
119 | vae_dtype = next(self.vae.parameters()).dtype
120 |
121 | if self.diffusion_transform is None:
122 | self.diffusion_transform = transforms.Compose([
123 | transforms.Resize((height, width), interpolation=transforms.InterpolationMode.BICUBIC),
124 | transforms.ToTensor(),
125 | transforms.Normalize([0.5], [0.5])
126 | ])
127 | if self.source_transform is None:
128 | self.source_transform = transforms.Resize((448, 448), interpolation=transforms.InterpolationMode.BICUBIC)
129 |
130 | # Generate image tokens
131 | img_token_ids = [self.processor.tokenizer.convert_tokens_to_ids(IMG_TOKEN.format(i)) for i in
132 | range(self.num_img_out_tokens)]
133 | img_token_ids = torch.tensor(img_token_ids, device=device).unsqueeze(0) # [1, num_img_out_tokens]
134 |
135 | # input image tokens
136 | input_token_ids = [self.processor.tokenizer.convert_tokens_to_ids(IMG_PAD_TOKEN) for _ in
137 | range(input_token_num)]
138 | input_token_ids = torch.tensor(input_token_ids, device=device).unsqueeze(0) # [1, input_token_num]
139 |
140 | # Convert BOI_TOKEN to ID
141 | boi_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOI_TOKEN)
142 | eos_token_id = self.processor.tokenizer.convert_tokens_to_ids(EOS_TOKEN)
143 | bov_token_id = self.processor.tokenizer.convert_tokens_to_ids(BOV_TOKEN)
144 |
145 | # Define stopping criteria to stop at BOI_TOKEN
146 | stopping_criteria = StoppingCriteriaList([
147 | StopOnToken(boi_token_id), StopOnToken(bov_token_id), StopOnToken(eos_token_id)
148 | ])
149 | ori_w, ori_h = image.size if image is not None else (width, height)
150 | input_images = [self.source_transform(image)] if image is not None else []
151 | original_images = [image] if image is not None else []
152 | generated_images = []
153 | output_text = ''
154 |
155 | if prompt_type == 't2i':
156 | prompt = f"Follow the caption to generate an image through a chain of thought process: {text_input}"
157 | elif prompt_type == 'edit':
158 | prompt = f"Follow the instruction to edit the given image through a chain of thought process: {text_input}"
159 | else:
160 | raise ValueError(f"Unknown prompt type {prompt_type}")
161 |
162 | # Prepare the conversation structure for Qwen2.5-VL
163 | messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
164 |
165 | # If image is provided, add it to messages
166 | if image is not None:
167 | # Insert the image into the content
168 | messages[0]["content"].insert(0, {"type": "image"})
169 |
170 | # Apply chat template to form the prompt as Qwen2.5-VL expects
171 | text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
172 | inputs = self.processor(
173 | text=[text],
174 | images=None if not input_images else input_images,
175 | padding=False,
176 | return_tensors="pt"
177 | ).to(device)
178 | input_ids = inputs.input_ids # shape: [1, seq_len]
179 |
180 | # if the last token is not EOS_TOKEN, continue generating
181 | while input_ids[0, -1] != eos_token_id:
182 | input_length = input_ids.shape[1]
183 | image_inputs = None if not input_images \
184 | else self.processor.image_processor(images=input_images, return_tensors="pt").to(device)
185 |
186 | if got_input is None:
187 | partial_generation = self.mllm.generate(
188 | input_ids=input_ids,
189 | attention_mask=torch.ones_like(input_ids),
190 | pixel_values=image_inputs.pixel_values if hasattr(image_inputs, 'pixel_values') else None,
191 | image_grid_thw=image_inputs.image_grid_thw if hasattr(image_inputs, 'image_grid_thw') else None,
192 | max_new_tokens=max_new_tokens,
193 | return_dict_in_generate=True,
194 | output_hidden_states=False, # No need yet, we will do a second pass
195 | stopping_criteria=stopping_criteria,
196 | **generate_kwargs
197 | )
198 |
199 | input_ids = partial_generation['sequences'] # shape: [1, seq_len]
200 | else:
201 | input_ids = self.processor.tokenizer.encode(got_input)
202 | input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
203 | got_input = None
204 |
205 | if only_return_got:
206 | return {"got_text": self.processor.tokenizer.decode(input_ids[0])}
207 |
208 | # Decode the newly generated text
209 | cur_decoded_text = self.processor.tokenizer.decode(input_ids[0, input_length:], skip_special_tokens=False)
210 | output_text += cur_decoded_text\
211 | .replace(EOS_TOKEN, '').replace(EOI_TOKEN, '').replace(BOV_TOKEN, '').replace(EOV_TOKEN, '')
212 |
213 | # generate an image
214 | if input_ids[0, -1] == boi_token_id:
215 | input_ids = torch.cat([input_ids, img_token_ids], dim=1) # now includes BOI_TOKEN + image tokens
216 |
217 | second_out = self.mllm(
218 | input_ids=input_ids,
219 | attention_mask=torch.ones_like(input_ids),
220 | pixel_values=image_inputs.pixel_values if hasattr(image_inputs, 'pixel_values') else None,
221 | image_grid_thw=image_inputs.image_grid_thw if hasattr(image_inputs, 'image_grid_thw') else None,
222 | output_hidden_states=True,
223 | return_dict=True
224 | )
225 | last_hidden_states = second_out['hidden_states'][-1] # [batch_size, seq_len, hidden_size]
226 |
227 | img_gen_mask = torch.logical_and(
228 | self.img_gen_start_id <= input_ids, input_ids < self.img_gen_start_id + self.num_img_out_tokens)
229 |
230 | gen_hidden_states = last_hidden_states[img_gen_mask].view(-1, self.num_img_out_tokens,
231 | last_hidden_states.shape[-1])
232 | gen_hidden_states = gen_hidden_states[-1:] # keep only the latest group of num_img_out_tokens image tokens
233 | gen_hidden_states = gen_hidden_states.to(self.output_projector.projector.weight.dtype)
234 |
235 | gen_conditioning = self.output_projector(gen_hidden_states)
236 | gen_conditioning_add = self.output_projector_add(gen_hidden_states) # [bz, gen_num, dim]
237 | null_conditioning = self.output_projector(torch.zeros_like(gen_hidden_states))
238 | gen_conditioning_pooled = torch.mean(gen_conditioning_add, dim=1)
239 |
240 | self.scheduler.set_timesteps(num_inference_steps, device=device)
241 | timesteps = self.scheduler.timesteps
242 |
243 | # Prepare stable diffusion latents
244 | generator = torch.Generator(device=device).manual_seed(random_seed)
245 |
246 | latents = randn_tensor(
247 | shape=(1, self.vae.config.latent_channels, height // 8, width // 8),
248 | generator=generator,
249 | device=device,
250 | dtype=vae_dtype
251 | )
252 | latents = latents * self.scheduler.init_noise_sigma
253 |
254 | # Channel layout: 4 noisy latents + 4 original-image latents (for editing) + 4 spatial condition-mask latents.
255 | # In the text-to-image scenario, a black (all-zero) image is provided as original_image.
256 | original_image = original_images[-1] if original_images \
257 | else Image.new('RGB', (width, height), (0, 0, 0))
258 |
259 | original_image_tensor = self.diffusion_transform(original_image).unsqueeze(0).to(device).to(vae_dtype)
260 | image_latents = self.vae.encode(original_image_tensor).latent_dist.mode()
261 |
262 | positions_colors = parse_coordinates_colors(cur_decoded_text)
263 | mask_num = max(len(positions_colors), 1)
264 |
265 | cond_images = [Image.new('RGB', (width, height), (0, 0, 0)) for _ in range(mask_num)]
266 |
267 | for i in range(len(positions_colors)):
268 | p_c = positions_colors[i]
269 | draw = ImageDraw.Draw(cond_images[i])
270 | position = p_c['position']
271 | color = p_c['color']
272 | draw.rectangle(((position[0][0] / 1000 * width, position[0][1] / 1000 * height),
273 | (position[1][0] / 1000 * width, position[1][1] / 1000 * height)), fill=color)
274 | del draw
275 |
276 | cond_images_tensor = []
277 | for c_image in cond_images:
278 | c_image_tensor = self.diffusion_transform(c_image)
279 | cond_images_tensor.append(c_image_tensor)
280 |
281 | # (1, mask_num, 3, target_size, target_size)
282 | cond_mask = torch.stack(cond_images_tensor, dim=0).unsqueeze(0)
283 | B, N, C, H, W = cond_mask.shape
284 | cond_mask = cond_mask.view(B * N, C, H, W)
285 |
286 | unet_cond_embeds = []
287 | for i in range(0, cond_mask.shape[0], self.vae_batch):
288 | sub_batch = cond_mask[i: i + self.vae_batch]
289 | embeds = self.vae.encode(sub_batch.to(device, dtype=vae_dtype)).latent_dist.mode()
290 | embeds = embeds.to(device)
291 | unet_cond_embeds.append(embeds)
292 | unet_cond_embeds = torch.cat(unet_cond_embeds, dim=0)
293 | unet_cond_embed = unet_cond_embeds.mean(dim=0, keepdim=True)
294 |
295 | if do_classifier_free_guidance:
296 | uncond_image_latents = torch.zeros_like(image_latents)
297 | image_latents = torch.cat([image_latents, image_latents, image_latents, uncond_image_latents],
298 | dim=0)
299 |
300 | uncond_cond_image_latents = torch.zeros_like(unet_cond_embed)
301 | unet_cond_embed = torch.cat([unet_cond_embed, uncond_cond_image_latents,
302 | uncond_cond_image_latents, uncond_cond_image_latents], dim=0)
303 |
304 | combined_prompt_embeds = torch.cat(
305 | [gen_conditioning, gen_conditioning, null_conditioning, null_conditioning],
306 | dim=0) if do_classifier_free_guidance else gen_conditioning
307 |
308 | text_encoder_projection_dim = int(gen_conditioning_pooled.shape[-1])
309 |
310 | original_size = (height, width)
311 | target_size = (height, width)
312 |
313 | add_time_ids = self._get_add_time_ids(
314 | original_size,
315 | crops_coords_top_left,
316 | target_size,
317 | dtype=combined_prompt_embeds.dtype,
318 | text_encoder_projection_dim=text_encoder_projection_dim,
319 | )
320 |
321 | added_cond_kwargs = {"text_embeds": gen_conditioning_pooled.to(device),
322 | "time_ids": add_time_ids.to(device)}
323 |
324 | for i, t in enumerate(tqdm(timesteps)):
325 | latent_model_input = torch.cat([latents] * 4) if do_classifier_free_guidance else latents
326 | scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
327 | scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents, unet_cond_embed],
328 | dim=1)
329 |
330 | noise_pred = self.unet(
331 | scaled_latent_model_input,
332 | t,
333 | encoder_hidden_states=combined_prompt_embeds,
334 | added_cond_kwargs=added_cond_kwargs,
335 | return_dict=False
336 | )[0]
337 |
338 | if do_classifier_free_guidance:
339 | noise_pred_cond, noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(4,
340 | dim=0)
341 | noise_pred = (
342 | noise_pred_uncond
343 | + guidance_scale * (noise_pred_text - noise_pred_image)
344 | + cond_image_guidance_scale * (noise_pred_cond - noise_pred_text)
345 | + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
346 | )
347 |
348 | # step through scheduler
349 | latents = self.scheduler.step(noise_pred, t, latents, generator=generator, return_dict=False)[0]
350 |
351 | final_latents = latents / self.vae.config.scaling_factor
352 | image_tensor = self.vae.decode(final_latents, generator=generator).sample
353 | image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1)
354 | pil_image = Image.fromarray(
355 | (image_tensor[0].permute(1, 2, 0).cpu().float().numpy() * 255).astype("uint8"))
356 |
357 | generated_images.append(pil_image)
358 | original_images.append(pil_image)
359 | elif input_ids[0, -1] == bov_token_id:
360 | input_images.append(self.source_transform(generated_images[-1]))
361 | input_ids = torch.cat([input_ids, input_token_ids], dim=1)
362 |
363 | # resize generated images back to the input aspect ratio (ori_w x ori_h), keeping the shortest side at the generation resolution
364 | if ori_w < ori_h:
365 | target_size = (width, int(height * ori_h / ori_w))
366 | else:
367 | target_size = (int(width * ori_w / ori_h), height)
368 | generated_images = [img.resize(target_size) for img in generated_images]
369 |
370 | return {"got_text": output_text, "images": generated_images}
371 |
372 | @classmethod
373 | def from_pretrained(cls, mllm, output_projector, scheduler, vae, unet, pretrained_model_path=None, **kwargs):
374 | model = cls(mllm=mllm, output_projector=output_projector, scheduler=scheduler, vae=vae, unet=unet, **kwargs)
375 | if os.environ.get('DEBUG_FLAG', 'False') == 'True':
376 | return model
377 |
378 | if pretrained_model_path is not None:
379 | ckpt = torch.load(pretrained_model_path, map_location='cpu')
380 | logs = model.load_state_dict(ckpt, strict=False)
381 | print(logs)
382 | return model
383 |
--------------------------------------------------------------------------------
/got/models/peft_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from omegaconf import DictConfig
3 | import hydra
4 | from peft import (
5 | LoraConfig,
6 | PeftModel,
7 | LoraModel,
8 | PeftModelForCausalLM,
9 | get_peft_model,
10 | )
11 |
12 |
13 | def get_peft_model_without_resize_embedding(model, peft_config=None, torch_dtype='bf16'):
14 | if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
15 | torch_dtype = torch.bfloat16
16 | elif torch_dtype == 'fp16' or torch_dtype == 'float16':
17 | torch_dtype = torch.float16
18 | else:
19 | torch_dtype = torch.float32
20 |
21 | if isinstance(model, DictConfig):
22 | model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)
23 |
24 | print('peft config: ', peft_config)
25 | if isinstance(peft_config, DictConfig):
26 | peft_config = hydra.utils.instantiate(peft_config)
27 | peft_model = get_peft_model(model=model, peft_config=peft_config)
28 |
29 | # peft_model.print_trainable_parameters()
30 |
31 | return peft_model
32 |
--------------------------------------------------------------------------------
/got/models/projector.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class LinearProjector(nn.Module):
5 | def __init__(self, in_hidden_size, out_hidden_size, bias=True):
6 | super().__init__()
7 | self.projector = nn.Linear(in_hidden_size, out_hidden_size, bias=bias)
8 |
9 | def forward(self, feature):
10 | return self.projector(feature)
11 |
--------------------------------------------------------------------------------
/got/models/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | from transformers import StoppingCriteria
4 |
5 |
6 | BOI_TOKEN = '<|im_gen_start|>'
7 | EOI_TOKEN = '<|im_gen_end|>'
8 | IMG_TOKEN = '<|im_gen_{:04d}|>'
9 | EOS_TOKEN = '<|endoftext|>'
10 | BOV_TOKEN = '<|vision_start|>'
11 | EOV_TOKEN = '<|vision_end|>'
12 | IMG_PAD_TOKEN = '<|image_pad|>'
13 |
14 |
15 | def remove_mismatched_weights(model, pretrained_state_dict):
16 | own_state = model.state_dict()
17 | mismatch_keys = []
18 |
19 | for name in list(pretrained_state_dict.keys()):
20 | if name not in own_state or own_state[name].shape != pretrained_state_dict[name].shape:
21 | mismatch_keys.append(name)
22 | pretrained_state_dict.pop(name)
23 |
24 | return pretrained_state_dict, mismatch_keys
25 |
26 |
27 | def parse_coordinates_colors(cot_text):
28 | """
29 | Parse bounding box coordinates and their colors from the CoT text.
30 |
31 | Args:
32 | cot_text (str): Chain of Thought text containing bounding box information.
33 |
34 | Returns:
35 | list: A list of dictionaries with keys 'position' ([[x1, y1], [x2, y2]]) and 'color'.
36 | """
37 | # Regular expression to match bounding box and color patterns
38 | pattern = r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|> \((\w+)\)"
39 |
40 | # Parse all matches
41 | matches = re.findall(pattern, cot_text)
42 |
43 | # Extract bounding box coordinates and colors
44 | parsed_data = []
45 | for match in matches:
46 | x1, y1, x2, y2, color = match
47 | parsed_data.append({
48 | 'position': [[int(x1), int(y1)], [int(x2), int(y2)]],
49 | 'color': color
50 | })
51 |
52 | return parsed_data
53 |
54 |
55 | class StopOnToken(StoppingCriteria):
56 | def __init__(self, token_id):
57 | self.token_id = token_id
58 |
59 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
60 | # Check whether the last generated token matches the configured stop token id
61 | return input_ids[0, -1] == self.token_id
62 |
--------------------------------------------------------------------------------
/got/processer/qwen25_vl_processor.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoProcessor
2 |
3 |
4 | BOI_TOKEN = '<|im_gen_start|>'
5 | EOI_TOKEN = '<|im_gen_end|>'
6 | IMG_TOKEN = '<|im_gen_{:04d}|>'
7 |
8 |
9 | def get_processor(model_name, add_gen_token_num=64):
10 | processor = AutoProcessor.from_pretrained(model_name)
11 | add_token_list = [BOI_TOKEN, EOI_TOKEN]
12 | for i in range(add_gen_token_num):
13 | add_token_list.append(IMG_TOKEN.format(i))
14 | processor.tokenizer.add_tokens(add_token_list, special_tokens=True)
15 | return processor
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | torchvision==0.15.2
3 | hydra-core
4 | omegaconf
5 | transformers==4.49.0
6 | diffusers==0.29.0
7 | sentencepiece
8 | opencv-python
9 | peft==0.13.2
10 | pyrootutils
11 | xformers==0.0.22
12 | accelerate==1.3.0
13 | transformers_stream_generator
14 | tqdm
15 | notebook
16 | numpy==1.21.2
17 | huggingface_hub==0.29.3
--------------------------------------------------------------------------------