├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── blog.md ├── requirements.txt └── src └── finetune.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Qwen2-VL Model Finetuning 2 | 3 | This repository contains code for finetuning the Qwen2-VL vision-language model on custom datasets using HuggingFace datasets. It includes a Gradio web interface for easy interaction and a Python script for command-line execution. 
4 | 5 | ## Credits 6 | 7 | This project is based on the work from [https://github.com/zhangfaen/finetune-Qwen2-VL](https://github.com/zhangfaen/finetune-Qwen2-VL). I've significantly modified and extended their original fine-tuning script to work specifically with HuggingFace datasets and added a Gradio interface for ease of use. 8 | 9 | ## Features 10 | 11 | - Finetune Qwen2-VL model on custom HuggingFace datasets 12 | - Gradio web interface for interactive model training 13 | - Command-line script for batch processing 14 | - Customizable training parameters 15 | - Validation during training 16 | 17 | ## Installation 18 | 19 | 1. Clone this repository: 20 | ``` 21 | git clone https://github.com/wjbmattingly/qwen2-vl-finetuning-huggingface.git 22 | cd qwen2-vl-finetuning-huggingface 23 | ``` 24 | 25 | 2. Install the required packages: 26 | ``` 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ## Usage 31 | 32 | ### Gradio Web Interface 33 | 34 | To use the Gradio web interface: 35 | 36 | 1. Run the following command: 37 | ``` 38 | python app.py 39 | ``` 40 | 41 | 2. Open your web browser and navigate to the URL displayed in the console (usually `http://localhost:8083`). 42 | 43 | 3. Use the interface to select your dataset, set training parameters, and start the finetuning process. 44 | 45 | ### Command-line Finetuning 46 | 47 | To finetune the model without the web interface: 48 | 49 | 1. Import the `train_and_validate` function from `src/finetune.py`. 50 | 51 | 2. Call it with your desired parameters: 52 | ```python 53 | from src.finetune import train_and_validate 54 | 55 | train_and_validate( 56 | model_name="Qwen/Qwen2-VL-2B-Instruct", 57 | output_dir="/output", 58 | dataset_name="CATMuS/medieval", 59 | image_column="im", 60 | text_column="text", 61 | user_text="Convert this image to text", 62 | train_field="train", 63 | val_field="validation", 64 | num_accumulation_steps=2, 65 | eval_steps=1000, 66 | max_steps=10000, 67 | train_batch_size=1, 68 | val_batch_size=1, 69 | device="cuda" 70 | ) 71 | ``` 72 | 73 | This call will start the finetuning process with the specified parameters: 74 | 75 | - Using the Qwen2-VL-2B-Instruct model 76 | - Saving the output to "/output" 77 | - Using the "CATMuS/medieval" dataset 78 | - Using "im" as the image column and "text" as the text column 79 | - Setting a custom user prompt 80 | - Using "train" and "validation" splits for training and validation 81 | - Setting various training parameters like accumulation steps, evaluation frequency, and batch sizes 82 | 83 | 84 | ## Finetuning Process 85 | 86 | The finetuning process involves the following steps: 87 | 88 | 1. Loading the pre-trained Qwen2-VL model and processor 89 | 2. Preparing the dataset using the custom `HuggingFaceDataset` class 90 | 3. Setting up data loaders for training and validation 91 | 4. Training the model with gradient accumulation and periodic evaluation 92 | 5. Saving the finetuned model 93 | 94 | Key functions in `src/finetune.py`: 95 | 96 | - `train_and_validate`: Main function that orchestrates the finetuning process 97 | - `collate_fn`: Prepares batches of data for the model 98 | - `validate`: Performs validation on the model during training 99 | 100 | ## Customizing the Training 101 | 102 | To finetune on your own HuggingFace dataset: 103 | 104 | 1. Modify the `dataset_name`, `image_column`, and `text_column` parameters in `train_and_validate`. 105 | 2. 
Adjust other parameters such as `max_steps`, `eval_steps`, and batch sizes as needed. 106 | 107 | ## Roadmap 108 | 109 | 110 | Future improvements and features: 111 | 112 | - [ ] Implement distributed GPU fine-tuning for faster training on multiple GPUs 113 | - [ ] Add support for training on video datasets to leverage Qwen2-VL's video processing capabilities 114 | - [ ] Develop more complex and custom message structures to handle diverse tasks 115 | - [ ] Expand functionality to support tasks beyond Handwritten Text Recognition (HTR) 116 | 117 | 118 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torch 3 | from src.finetune import train_and_validate 4 | import json 5 | from datasets import load_dataset 6 | 7 | def finetune_model(model_name, output_dir, dataset_name, image_column, text_column, user_text, num_accumulation_steps, eval_steps, max_steps, train_batch_size, val_batch_size, train_select_start, train_select_end, val_select_start, val_select_end, train_field, val_field, min_pixel, max_pixel, image_factor, device): 8 | 9 | # Call the train_and_validate function with the provided parameters 10 | train_and_validate( 11 | model_name=model_name, 12 | output_dir=output_dir, 13 | dataset_name=dataset_name, 14 | image_column=image_column, 15 | text_column=text_column, 16 | device=device, 17 | user_text=user_text, 18 | num_accumulation_steps=num_accumulation_steps, 19 | eval_steps=eval_steps, 20 | max_steps=max_steps, 21 | train_batch_size=train_batch_size, 22 | val_batch_size=val_batch_size, 23 | train_select_start=train_select_start, 24 | train_select_end=train_select_end, 25 | val_select_start=val_select_start, 26 | val_select_end=val_select_end, 27 | train_field=train_field, 28 | val_field=val_field, 29 | min_pixel=min_pixel, 30 | max_pixel=max_pixel, 31 | image_factor=image_factor 32 | ) 33 | 34 | return f"Training completed. 
Model saved in {output_dir}" 35 | 36 | # Create the Gradio interface 37 | def load_dataset_sample(dataset_name): 38 | dataset = load_dataset(dataset_name, streaming=True) 39 | sample = list(dataset['train'].take(5)) 40 | return sample, list(sample[0].keys()) 41 | 42 | def update_fields(dataset_name): 43 | sample, fields = load_dataset_sample(dataset_name) 44 | return gr.Dropdown(choices=fields, label="Image Column"), gr.Dropdown(choices=fields, label="Text Column"), gr.DataFrame(value=[list(s.values()) for s in sample], headers=list(sample[0].keys())) 45 | 46 | def preview_message_structure(dataset_name, image_column, text_column, user_text): 47 | sample, _ = load_dataset_sample(dataset_name) 48 | image = sample[0][image_column] 49 | assistant_text = sample[0][text_column] 50 | message_structure = { 51 | "messages": [ 52 | { 53 | "role": "user", 54 | "content": [ 55 | {"type": "image", "image": "Image data (not shown)"}, 56 | {"type": "text", "text": user_text} 57 | ] 58 | }, 59 | { 60 | "role": "assistant", 61 | "content": [ 62 | {"type": "text", "text": assistant_text} 63 | ] 64 | } 65 | ] 66 | } 67 | return json.dumps(message_structure, indent=2) 68 | 69 | with gr.Blocks() as iface: 70 | gr.Markdown("# Qwen2-VL Model Finetuning") 71 | gr.Markdown("Finetune the Qwen2-VL model on a specified dataset.") 72 | 73 | with gr.Row(): 74 | dataset_name = gr.Textbox(label="Dataset Name") 75 | load_button = gr.Button("Load Dataset") 76 | 77 | with gr.Row(): 78 | image_column = gr.Dropdown(label="Image Column") 79 | text_column = gr.Dropdown(label="Text Column") 80 | train_field = gr.Dropdown(label="Train Field", choices=["train", "validation", "test"]) 81 | val_field = gr.Dropdown(label="Validation Field", choices=["train", "validation", "test"]) 82 | 83 | sample_data = gr.DataFrame(label="Sample Data") 84 | 85 | load_button.click(update_fields, inputs=[dataset_name], outputs=[image_column, text_column, sample_data]) 86 | 87 | model_name = gr.Dropdown( 88 | label="Model Name", 89 | choices=["Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2-VL-7B-Instruct"], 90 | value="Qwen/Qwen2-VL-2B-Instruct" 91 | ) 92 | 93 | 94 | user_text = gr.Textbox(label="User Instructions", value="Convert this image to text") 95 | preview_button = gr.Button("Preview Message Structure") 96 | message_preview = gr.JSON(label="Message Structure Preview") 97 | 98 | preview_button.click(preview_message_structure, inputs=[dataset_name, image_column, text_column, user_text], outputs=[message_preview]) 99 | 100 | gr.Markdown("## Model Configuration") 101 | with gr.Column(): 102 | device = gr.Dropdown(label="Device", choices=["cuda", "cpu", "mps"], value="cuda") 103 | output_dir = gr.Textbox(label="Output Directory") 104 | 105 | gr.Markdown("## Training Parameters") 106 | with gr.Row(): 107 | with gr.Column(): 108 | gr.Markdown("### Training Steps and Evaluation") 109 | num_accumulation_steps = gr.Number(label="Number of Accumulation Steps", value=2) 110 | eval_steps = gr.Number(label="Evaluation Steps", value=10000) 111 | max_steps = gr.Number(label="Max Steps", value=100000) 112 | with gr.Column(): 113 | gr.Markdown("### Batch Sizes") 114 | train_batch_size = gr.Number(label="Training Batch Size", value=1) 115 | val_batch_size = gr.Number(label="Validation Batch Size", value=1) 116 | with gr.Column(): 117 | gr.Markdown("### Training Data Selection") 118 | train_select_start = gr.Number(label="Training Select Start", value=0) 119 | train_select_end = gr.Number(label="Training Select End", value=100000) 120 | with gr.Column(): 121 | 
gr.Markdown("### Validation Data Selection") 122 | val_select_start = gr.Number(label="Validation Select Start", value=0) 123 | val_select_end = gr.Number(label="Validation Select End", value=10000) 124 | 125 | gr.Markdown("## Image Processing Settings") 126 | with gr.Row(): 127 | with gr.Column(): 128 | min_pixel = gr.Number(label="Minimum Pixel Size", value=256, precision=0) 129 | max_pixel = gr.Number(label="Maximum Pixel Size", value=384, precision=0) 130 | image_factor = gr.Number(label="Image Factor", value=28, precision=0) 131 | finetune_button = gr.Button("Start Finetuning") 132 | result = gr.Textbox(label="Result") 133 | 134 | finetune_button.click( 135 | finetune_model, 136 | inputs=[model_name,output_dir, dataset_name, image_column, text_column, user_text, num_accumulation_steps, eval_steps, max_steps, train_batch_size, val_batch_size, train_select_start, train_select_end, val_select_start, val_select_end, train_field, val_field, min_pixel, max_pixel, image_factor, device], 137 | outputs=[result] 138 | ) 139 | 140 | # Launch the app 141 | if __name__ == "__main__": 142 | iface.launch(server_port=8083, server_name="compute-50-01") 143 | -------------------------------------------------------------------------------- /blog.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning Qwen2-VL: A Powerful Vision-Language Model for Handwritten Text Recognition 2 | 3 | ## Introduction 4 | 5 | In the rapidly evolving field of AI and computer vision, the Qwen2-VL model stands out as a game-changer. Developed by the Qwen team, this state-of-the-art vision-language model offers unprecedented capabilities in understanding and processing visual information, including handwritten text. Today, we'll explore how to fine-tune Qwen2-VL for specific tasks, with a focus on its potential for Handwritten Text Recognition (HTR). 6 | 7 | ## Why Qwen2-VL Matters 8 | 9 | Qwen2-VL represents a significant leap forward in multimodal AI. It can understand images of various resolutions and aspect ratios, process videos over 20 minutes long, and even operate mobile devices and robots based on visual input and text instructions. For HTR specifically, Qwen2-VL's ability to handle complex visual information makes it one of the best models available, rivaling even closed-source alternatives. 10 | 11 | Key features that make Qwen2-VL exceptional for HTR: 12 | 13 | 1. State-of-the-art performance on visual understanding benchmarks 14 | 2. Ability to handle arbitrary image resolutions 15 | 3. Multilingual support, including understanding text in images across various languages 16 | 4. Advanced architecture with Naive Dynamic Resolution and Multimodal Rotary Position Embedding (M-ROPE) 17 | 18 | ## Step-by-Step Guide to Fine-tuning Qwen2-VL 19 | 20 | Let's walk through the process of fine-tuning Qwen2-VL for your specific HTR task: 21 | 22 | ### Step 1: Setup 23 | 24 | First, ensure you have the necessary dependencies: 25 | 26 | ```bash 27 | pip install git+https://github.com/huggingface/transformers 28 | pip install qwen-vl-utils 29 | ``` 30 | 31 | Next, clone this repository and move into the new folder. 
32 | 33 | ```bash 34 | git clone https://github.com/wjbmattingly/qwen2-vl-finetuning-huggingface.git 35 | cd qwen2-vl-finetuning-huggingface 36 | ``` 37 | 38 | Finally, install the required packages for fine-tuning: 39 | 40 | ```bash 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | ### Step 2: Prepare Your Dataset 45 | 46 | For HTR, you'll need a dataset of handwritten text images paired with their transcriptions. Organize your data into a format compatible with HuggingFace datasets. For an example, please see [Catmus Medieval](https://huggingface.co/datasets/CATMuS/medieval). 47 | 48 | 54 | 55 | Here, we can see that we have the images in the `im` field and our transcription is in the `text` field. For this tutorial, we will be fine-tuning Qwen 2 VL 2B on this line-level HTR dataset. 56 | 57 | ### Step 3: Load the Model and Processor 58 | 59 | Now that we have a dataset prepared, let's go ahead and start the fine-tuning process. Over the next few sections, we will cover all the steps performed by the function in src.finetune.py, `train_and_validate()`. If you would rather just do all these steps as a single command in Python, please skip ahead to the final section of this blog. 60 | 61 | ```python 62 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor 63 | 64 | model = Qwen2VLForConditionalGeneration.from_pretrained( 65 | "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto" 66 | ) 67 | processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") 68 | ``` 69 | 70 | This step initializes the Qwen2-VL model and its associated processor. The `torch_dtype="auto"` and `device_map="auto"` parameters ensure optimal performance based on your hardware. 71 | 72 | ### Step 4: Prepare the Training Data 73 | 74 | We'll use a custom dataset class that formats our data for Qwen2-VL: 75 | 76 | ```python 77 | from torch.utils.data import Dataset 78 | 79 | class HuggingFaceDataset(Dataset): 80 | def __init__(self, dataset, image_column, text_column, user_text="Convert this image to text"): 81 | self.dataset = dataset 82 | self.image_column = image_column 83 | self.text_column = text_column 84 | self.user_text = user_text 85 | 86 | def __len__(self): 87 | return len(self.dataset) 88 | 89 | def __getitem__(self, idx): 90 | item = self.dataset[idx] 91 | image = item[self.image_column] 92 | assistant_text = item[self.text_column] 93 | 94 | return { 95 | "messages": [ 96 | { 97 | "role": "user", 98 | "content": [ 99 | {"type": "image", "image": image}, 100 | {"type": "text", "text": self.user_text} 101 | ] 102 | }, 103 | { 104 | "role": "assistant", 105 | "content": [ 106 | {"type": "text", "text": str(assistant_text)} 107 | ] 108 | } 109 | ] 110 | } 111 | ``` 112 | 113 | This custom `HuggingFaceDataset` class is crucial for preparing our data in a format suitable for fine-tuning the Qwen2-VL model. Let's break down its components: 114 | 115 | 1. **Initialization**: The `__init__` method sets up the dataset with the necessary columns for images and text, as well as a default user prompt. 116 | 117 | 2. **Length**: The `__len__` method returns the total number of items in the dataset, allowing us to iterate over it. 118 | 119 | 3. **Item Retrieval**: The `__getitem__` method is the core of this class. For each index: 120 | - It retrieves an item from the dataset. 121 | - Extracts the image and text from the specified columns. 122 | - Formats the data into a conversation-like structure that Qwen2-VL expects. 123 | 124 | 4. 
**Conversation Format**: The returned dictionary mimics a conversation with: 125 | - A "user" message containing both the image and a text prompt. 126 | - An "assistant" message containing the transcription (ground truth). 127 | 128 | This structure is essential because it allows the model to learn the connection between the input (image + prompt) and the expected output (transcription). During training, the model will learn to generate the assistant's response based on the user's input, effectively learning to transcribe handwritten text from images. 129 | 130 | This class wraps our HuggingFace dataset, formatting each item as a conversation with a user query (including an image) and an assistant response. 131 | 132 | ### Step 5: Image Processing 133 | 134 | To ensure consistent image processing, we use the `ensure_pil_image` function: 135 | 136 | ```python 137 | from PIL import Image 138 | import base64 139 | from io import BytesIO 140 | 141 | def ensure_pil_image(image, min_size=256): 142 | if isinstance(image, Image.Image): 143 | pil_image = image 144 | elif isinstance(image, str): 145 | if image.startswith('data:image'): 146 | image = image.split(',')[1] 147 | image_data = base64.b64decode(image) 148 | pil_image = Image.open(BytesIO(image_data)) 149 | else: 150 | raise ValueError(f"Unsupported image type: {type(image)}") 151 | 152 | if pil_image.width < min_size or pil_image.height < min_size: 153 | scale = max(min_size / pil_image.width, min_size / pil_image.height) 154 | new_width = int(pil_image.width * scale) 155 | new_height = int(pil_image.height * scale) 156 | pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS) 157 | 158 | return pil_image 159 | ``` 160 | 161 | This function ensures that the input image is a PIL Image object and meets a minimum size requirement. It's particularly useful in our pipeline for several reasons: 162 | 163 | 1. **Flexibility in Input Types**: It can handle different types of image inputs: 164 | - PIL Image objects are used directly. 165 | - Base64-encoded strings (common in web applications) are decoded into images. 166 | - It raises an error for unsupported types, helping to catch potential issues early. 167 | 168 | 2. **Minimum Size Enforcement**: The function ensures that images meet a minimum size (default 256x256 pixels). This is crucial because: 169 | - Many vision models have minimum input size requirements. 170 | - Consistent image sizes can improve training stability and model performance. 171 | - It preserves the aspect ratio while resizing, maintaining image integrity. 172 | 173 | 3. **Quality Preservation**: When resizing is necessary, it uses the LANCZOS algorithm, which is known for producing high-quality resized images. 174 | 175 | 4. **Error Handling**: The function includes error checking and raises informative exceptions, making debugging easier. 176 | 177 | 5. **Integration with Data Pipeline**: This function can be easily integrated into our data loading and preprocessing pipeline, ensuring all images are properly formatted before being fed into the model. 178 | 179 | By using `ensure_pil_image`, we standardize our image inputs, which is crucial for consistent model training and inference. This is especially important when dealing with datasets that might contain images of varying formats and sizes. 
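To make the resizing behaviour concrete, the short check below can be run from the repository root. This is a minimal sketch: the synthetic 100x40 test image is invented purely for illustration, and the import assumes the `src/finetune.py` module covered later in this post.

```python
import base64
from io import BytesIO

from PIL import Image

from src.finetune import ensure_pil_image

# A 100x40 image is below the 256px minimum, so it is upscaled by a factor of 6.4
small = Image.new("RGB", (100, 40), color="white")
print(ensure_pil_image(small).size)  # (640, 256)

# The same image passed as a base64 data URI is decoded first, then resized the same way
buffer = BytesIO()
small.save(buffer, format="PNG")
data_uri = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()
print(ensure_pil_image(data_uri).size)  # (640, 256)
```

Images that already meet the minimum size are returned unchanged.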
180 | 181 | 182 | ### Step 6: Collate Function 183 | 184 | The `collate_fn` is crucial for processing batches of data: 185 | 186 | ```python 187 | def collate_fn(batch, processor, device): 188 | messages = [item['messages'] for item in batch] 189 | texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=False) for msg in messages] 190 | images = [ensure_pil_image(msg[0]['content'][0]['image']) for msg in messages] 191 | 192 | inputs = processor( 193 | text=texts, 194 | images=images, 195 | padding=True, 196 | return_tensors="pt", 197 | ).to(device) 198 | 199 | input_ids_lists = inputs['input_ids'].tolist() 200 | labels_list = [] 201 | for ids_list in input_ids_lists: 202 | label_ids = [-100] * len(ids_list) 203 | for begin_end_indexs in find_assistant_content_sublist_indexes(ids_list): 204 | label_ids[begin_end_indexs[0]+2:begin_end_indexs[1]+1] = ids_list[begin_end_indexs[0]+2:begin_end_indexs[1]+1] 205 | labels_list.append(label_ids) 206 | 207 | labels_ids = torch.tensor(labels_list, dtype=torch.int64) 208 | 209 | return inputs, labels_ids 210 | ``` 211 | 212 | This `collate_fn` function plays a crucial role in preparing batches of data for training the Qwen2-VL model. Here's a detailed breakdown of its functionality: 213 | 214 | 1. **Batch Processing**: 215 | - It takes a batch of items, where each item contains 'messages' with user and assistant interactions. 216 | - The function processes these messages to create input suitable for the model. 217 | 218 | 2. **Text Processing**: 219 | - It applies a chat template to each message, converting the structured conversation into a format the model can understand. 220 | - The `tokenize=False` parameter ensures we get the formatted text, not tokenized IDs at this stage. 221 | 222 | 3. **Image Processing**: 223 | - For each message, it extracts the image and ensures it's in the correct format using the `ensure_pil_image` function. 224 | - This step standardizes all images, regardless of their original format or size. 225 | 226 | 4. **Model Input Creation**: 227 | - The processor (likely a Qwen2-VL specific processor) is used to create model inputs. 228 | - It combines the processed texts and images, applies padding, and converts everything to PyTorch tensors. 229 | - The resulting inputs are moved to the specified device (CPU or GPU). 230 | 231 | 5. **Label Creation**: 232 | - It creates labels for training, which is crucial for supervised learning. 233 | - The `find_assistant_content_sublist_indexes` function is used to identify the portions of the input that correspond to the assistant's responses. 234 | - Labels are set to -100 for non-assistant parts (which will be ignored during loss calculation) and to the actual token IDs for the assistant's responses. 235 | 236 | 6. **Output**: 237 | - The function returns two elements: 238 | 1. The processed inputs ready for the model. 239 | 2. The labels tensor, aligned with the inputs, for calculating the loss during training. 240 | 241 | This function is essential for transforming raw data into a format that the Qwen2-VL model can use for training, ensuring that both the visual and textual components are properly integrated and aligned. 
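For reference, the `find_assistant_content_sublist_indexes` helper used above is defined in `src/finetune.py`. It simply scans the token IDs for the pair that marks the start of an assistant turn (151644, 77091) and the token that closes it (151645):

```python
def find_assistant_content_sublist_indexes(l):
    """Return (start_index, end_index) pairs for each assistant turn in a token list."""
    start_indexes = []
    end_indexes = []

    for i in range(len(l) - 1):
        # The pair 151644, 77091 marks the start of an assistant turn
        if l[i] == 151644 and l[i + 1] == 77091:
            start_indexes.append(i)
            # The first 151645 after the start marks the end of that turn
            for j in range(i + 2, len(l)):
                if l[j] == 151645:
                    end_indexes.append(j)
                    break

    return list(zip(start_indexes, end_indexes))
```

With these index pairs, `collate_fn` masks everything except the assistant's answer, so the loss is computed only on the transcription tokens.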
242 | 243 | ### Step 7: Validation Function 244 | 245 | We include a validation step to monitor model performance: 246 | 247 | ```python 248 | def validate(model, val_loader): 249 | model.eval() 250 | total_val_loss = 0 251 | with torch.no_grad(): 252 | for batch in tqdm(val_loader, desc="Validating"): 253 | inputs, labels = batch 254 | outputs = model(**inputs, labels=labels) 255 | loss = outputs.loss 256 | total_val_loss += loss.item() 257 | 258 | avg_val_loss = total_val_loss / len(val_loader) 259 | model.train() 260 | return avg_val_loss 261 | ``` 262 | 263 | This function calculates the average validation loss across all batches. 264 | 265 | ### Step 8: Training and Validation Loop 266 | 267 | The main training function, `train_and_validate`, brings everything together: 268 | 269 | ```python 270 | def train_and_validate( 271 | model_name, 272 | output_dir, 273 | dataset_name, 274 | image_column, 275 | text_column, 276 | device="cuda", 277 | user_text="Convert this image to text", 278 | min_pixel=256, 279 | max_pixel=384, 280 | image_factor=28, 281 | num_accumulation_steps=2, 282 | eval_steps=10000, 283 | max_steps=100000, 284 | train_select_start=0, 285 | train_select_end=1000, 286 | val_select_start=0, 287 | val_select_end=1000, 288 | train_batch_size=1, 289 | val_batch_size=1, 290 | train_field="train", 291 | val_field="validation" 292 | ): 293 | # Load model and processor 294 | model = Qwen2VLForConditionalGeneration.from_pretrained( 295 | model_name, torch_dtype=torch.bfloat16, device_map=device 296 | ) 297 | processor = AutoProcessor.from_pretrained(model_name, min_pixels=min_pixel*image_factor*image_factor, max_pixels=max_pixel*image_factor*image_factor, padding_side="right") 298 | 299 | # Load and prepare dataset 300 | dataset = load_dataset(dataset_name) 301 | train_dataset = dataset[train_field].shuffle(seed=42).select(range(train_select_start, train_select_end)) 302 | val_dataset = dataset[val_field].shuffle(seed=42).select(range(val_select_start, val_select_end)) 303 | 304 | train_dataset = HuggingFaceDataset(train_dataset, image_column, text_column, user_text) 305 | val_dataset = HuggingFaceDataset(val_dataset, image_column, text_column, user_text) 306 | 307 | # Create data loaders 308 | train_loader = DataLoader( 309 | train_dataset, 310 | batch_size=train_batch_size, 311 | collate_fn=partial(collate_fn, processor=processor, device=device), 312 | shuffle=True 313 | ) 314 | val_loader = DataLoader( 315 | val_dataset, 316 | batch_size=val_batch_size, 317 | collate_fn=partial(collate_fn, processor=processor, device=device) 318 | ) 319 | 320 | # Set up optimizer 321 | model.train() 322 | optimizer = AdamW(model.parameters(), lr=1e-5) 323 | 324 | # Training loop 325 | global_step = 0 326 | progress_bar = tqdm(total=max_steps, desc="Training") 327 | 328 | while global_step < max_steps: 329 | for batch in train_loader: 330 | global_step += 1 331 | inputs, labels = batch 332 | outputs = model(**inputs, labels=labels) 333 | 334 | loss = outputs.loss / num_accumulation_steps 335 | loss.backward() 336 | 337 | if global_step % num_accumulation_steps == 0: 338 | optimizer.step() 339 | optimizer.zero_grad() 340 | 341 | progress_bar.update(1) 342 | progress_bar.set_postfix({"loss": loss.item() * num_accumulation_steps}) 343 | 344 | # Evaluation and model saving 345 | if global_step % eval_steps == 0 or global_step == max_steps: 346 | avg_val_loss = validate(model, val_loader) 347 | save_dir = os.path.join(output_dir, f"model_step_{global_step}") 348 | os.makedirs(save_dir, 
exist_ok=True) 349 | model.save_pretrained(save_dir) 350 | processor.save_pretrained(save_dir) 351 | model.train() 352 | 353 | if global_step >= max_steps: 354 | save_dir = os.path.join(output_dir, "final") 355 | model.save_pretrained(save_dir) 356 | processor.save_pretrained(save_dir) 357 | break 358 | 359 | progress_bar.close() 360 | ``` 361 | 362 | This function handles the entire training process, including data loading, model training, validation, and saving checkpoints. 363 | 364 | ### Step 9: Running the Fine-tuning Process 365 | 366 | To start the fine-tuning process, you can now call the `train_and_validate` function with your specific parameters: 367 | 368 | ```python 369 | from src.finetune import train_and_validate 370 | 371 | train_and_validate( 372 | model_name="Qwen/Qwen2-VL-2B-Instruct", 373 | output_dir="/output", 374 | dataset_name="CATMuS/medieval", 375 | image_column="im", 376 | text_column="text", 377 | user_text="Transcribe this handwritten text.", 378 | train_field="train", 379 | val_field="validation", 380 | num_accumulation_steps=2, 381 | eval_steps=1000, 382 | max_steps=10000, 383 | train_batch_size=1, 384 | val_batch_size=1, 385 | device="cuda" 386 | ) 387 | ``` 388 | 389 | This will start the fine-tuning process on the specified dataset, saving model checkpoints at regular intervals. 390 | 391 | ## Running Inference with Your Fine-Tuned Model 392 | 393 | To use your fine-tuned model, you can load it like this: 394 | 395 | ```python 396 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor 397 | from qwen_vl_utils import process_vision_info 398 | model = Qwen2VLForConditionalGeneration.from_pretrained("/output/final") 399 | processor = AutoProcessor.from_pretrained("/output/final") 400 | 401 | # Now you can use the model for inference 402 | ``` 403 | 404 | To use the model, you can follow the supplied documentation from Qwen2-VL (the `process_vision_info` helper comes from the `qwen-vl-utils` package installed in Step 1): 405 | 406 | ```python 407 | messages = [ 408 | { 409 | "role": "user", 410 | "content": [ 411 | { 412 | "type": "image", 413 | "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", 414 | }, 415 | {"type": "text", "text": "Describe this image."}, 416 | ], 417 | } 418 | ] 419 | 420 | # Preparation for inference 421 | text = processor.apply_chat_template( 422 | messages, tokenize=False, add_generation_prompt=True 423 | ) 424 | image_inputs, video_inputs = process_vision_info(messages) 425 | inputs = processor( 426 | text=[text], 427 | images=image_inputs, 428 | videos=video_inputs, 429 | padding=True, 430 | return_tensors="pt", 431 | ) 432 | # inputs = inputs.to("cuda")  # move the inputs to the device, if necessary 433 | 434 | # Inference: Generation of the output 435 | generated_ids = model.generate(**inputs, max_new_tokens=128) # Increase max_new_tokens if you want to generate longer outputs (good for longer texts) 436 | generated_ids_trimmed = [ 437 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 438 | ] 439 | output_text = processor.batch_decode( 440 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 441 | ) 442 | print(output_text) 443 | ``` 444 | 445 | ## Conclusion 446 | 447 | By following these steps, you can fine-tune the Qwen2-VL model for your specific HTR task. The process involves preparing your dataset, setting up the model and processor, and running the training loop with validation steps. 448 | 449 | Remember to monitor the training progress and adjust hyperparameters as needed. 
Once fine-tuning is complete, you can use the saved model for inference on new handwritten text images. 450 | 451 | Happy transcribing! -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gradio 2 | torch 3 | git+https://github.com/huggingface/transformers 4 | qwen-vl-utils 5 | datasets 6 | Pillow 7 | tqdm 8 | -------------------------------------------------------------------------------- /src/finetune.py: -------------------------------------------------------------------------------- 1 | """ 2 | The original version of this fine-tuning script came from this source: https://github.com/zhangfaen/finetune-Qwen2-VL. I modified it to work specifically with HuggingFace datasets, and I designed it to work with the Gradio app in the main directory, app.py. I also added a validation step to the training loop. I am deeply indebted and grateful for their work. Without this code, this project would have been substantially more difficult. 3 | """ 4 | import os 5 | 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.optim import AdamW 9 | 10 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor 11 | 12 | from datasets import load_dataset 13 | 14 | from PIL import Image 15 | import base64 16 | from io import BytesIO 17 | 18 | from functools import partial 19 | from tqdm import tqdm 20 | 21 | 22 | def find_assistant_content_sublist_indexes(l): 23 | """ 24 | Find the start and end indexes of assistant content sublists within a given list. 25 | 26 | This function searches for specific token sequences that indicate the beginning and end 27 | of assistant content in a tokenized list. It identifies pairs of start and end indexes 28 | for each occurrence of assistant content. 29 | 30 | Args: 31 | l (list): A list of tokens to search through. 32 | 33 | Returns: 34 | list of tuples: A list of (start_index, end_index) pairs indicating the positions 35 | of assistant content sublists within the input list. 36 | 37 | Note: 38 | - The start of assistant content is identified by the sequence [151644, 77091]. 39 | - The end of assistant content is marked by the token 151645. 40 | - This function assumes that each start sequence has a corresponding end token. 41 | """ 42 | start_indexes = [] 43 | end_indexes = [] 44 | 45 | # Iterate through the list to find starting points 46 | for i in range(len(l) - 1): 47 | # Check if the current and next element form the start sequence 48 | if l[i] == 151644 and l[i + 1] == 77091: 49 | start_indexes.append(i) 50 | # Now look for the first 151645 after the start 51 | for j in range(i + 2, len(l)): 52 | if l[j] == 151645: 53 | end_indexes.append(j) 54 | break # Move to the next start after finding the end 55 | 56 | return list(zip(start_indexes, end_indexes)) 57 | 58 | class HuggingFaceDataset(Dataset): 59 | """ 60 | A custom Dataset class for handling HuggingFace datasets with image and text pairs. 61 | 62 | This class is designed to work with datasets that contain image-text pairs, 63 | specifically for use in vision-language models. It processes the data to create 64 | a format suitable for models like Qwen2-VL, structuring each item as a conversation 65 | with a user query (including an image) and an assistant response. 66 | 67 | Attributes: 68 | dataset: The HuggingFace dataset to be wrapped. 69 | image_column (str): The name of the column containing image data. 
70 | text_column (str): The name of the column containing text data. 71 | user_text (str): The default user query text to pair with each image. 72 | 73 | """ 74 | def __init__(self, dataset, image_column, text_column, user_text="Convert this image to text"): 75 | self.dataset = dataset 76 | self.image_column = image_column 77 | self.text_column = text_column 78 | self.user_text = user_text 79 | 80 | def __len__(self): 81 | return len(self.dataset) 82 | 83 | def __getitem__(self, idx): 84 | item = self.dataset[idx] 85 | image = item[self.image_column] 86 | assistant_text = item[self.text_column] 87 | 88 | return { 89 | "messages": [ 90 | { 91 | "role": "user", 92 | "content": [ 93 | {"type": "image", "image": image}, 94 | {"type": "text", "text": self.user_text} 95 | ] 96 | }, 97 | { 98 | "role": "assistant", 99 | "content": [ 100 | {"type": "text", "text": str(assistant_text)} 101 | ] 102 | } 103 | ] 104 | } 105 | 106 | def ensure_pil_image(image, min_size=256): 107 | """ 108 | Ensures that the input image is a PIL Image object and meets a minimum size requirement. 109 | 110 | This function handles different input types: 111 | - If the input is already a PIL Image, it's used directly. 112 | - If the input is a string, it's assumed to be a base64-encoded image and is decoded. 113 | - For other input types, a ValueError is raised. 114 | 115 | The function also resizes the image if it's smaller than the specified minimum size, 116 | maintaining the aspect ratio. 117 | 118 | Args: 119 | image (Union[PIL.Image.Image, str]): The input image, either as a PIL Image object 120 | or a base64-encoded string. 121 | min_size (int, optional): The minimum size (in pixels) for both width and height. 122 | Defaults to 256. 123 | 124 | Returns: 125 | PIL.Image.Image: A PIL Image object meeting the size requirements. 126 | 127 | Raises: 128 | ValueError: If the input image type is not supported. 129 | """ 130 | if isinstance(image, Image.Image): 131 | pil_image = image 132 | elif isinstance(image, str): 133 | # Assuming it's a base64 string 134 | if image.startswith('data:image'): 135 | image = image.split(',')[1] 136 | image_data = base64.b64decode(image) 137 | pil_image = Image.open(BytesIO(image_data)) 138 | else: 139 | raise ValueError(f"Unsupported image type: {type(image)}") 140 | 141 | # Check if the image is smaller than the minimum size 142 | if pil_image.width < min_size or pil_image.height < min_size: 143 | # Calculate the scaling factor 144 | scale = max(min_size / pil_image.width, min_size / pil_image.height) 145 | new_width = int(pil_image.width * scale) 146 | new_height = int(pil_image.height * scale) 147 | 148 | # Resize the image 149 | pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS) 150 | 151 | return pil_image 152 | 153 | def collate_fn(batch, processor, device): 154 | """ 155 | Collate function for processing batches of data for the Qwen2-VL model. 156 | 157 | This function prepares the input data for training or inference by processing 158 | the messages, applying chat templates, ensuring images are in the correct format, 159 | and creating input tensors for the model. 160 | 161 | Args: 162 | batch (List[Dict]): A list of dictionaries, each containing 'messages' with text and image data. 163 | processor (AutoProcessor): The processor for the Qwen2-VL model, used for tokenization and image processing. 164 | device (torch.device): The device (CPU or GPU) to which the tensors should be moved. 
165 | 166 | Returns: 167 | Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing: 168 | - inputs: A dictionary of input tensors for the model (e.g., input_ids, attention_mask). 169 | - labels_ids: A tensor of label IDs for training, with -100 for non-assistant tokens. 170 | 171 | Note: 172 | This function assumes that each message in the batch contains both text and image data, 173 | and that the first content item in each message is an image. 174 | """ 175 | messages = [item['messages'] for item in batch] 176 | texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=False) for msg in messages] 177 | 178 | # Ensure all images are PIL Image objects 179 | images = [ensure_pil_image(msg[0]['content'][0]['image']) for msg in messages] 180 | 181 | # Process the text and images using the processor 182 | inputs = processor( 183 | text=texts, 184 | images=images, 185 | padding=True, 186 | return_tensors="pt", 187 | ) 188 | 189 | # Move the inputs to the specified device (CPU or GPU) 190 | inputs = inputs.to(device) 191 | 192 | # Convert input IDs to a list of lists for easier processing 193 | input_ids_lists = inputs['input_ids'].tolist() 194 | labels_list = [] 195 | for ids_list in input_ids_lists: 196 | # Initialize label IDs with -100 (ignored in loss calculation) 197 | label_ids = [-100] * len(ids_list) 198 | # Find the indexes of assistant content in the input IDs 199 | for begin_end_indexs in find_assistant_content_sublist_indexes(ids_list): 200 | # Set the label IDs for assistant content, skipping the first two tokens 201 | label_ids[begin_end_indexs[0]+2:begin_end_indexs[1]+1] = ids_list[begin_end_indexs[0]+2:begin_end_indexs[1]+1] 202 | labels_list.append(label_ids) 203 | 204 | # Convert the labels list to a tensor 205 | labels_ids = torch.tensor(labels_list, dtype=torch.int64) 206 | 207 | # Return the processed inputs and label IDs 208 | return inputs, labels_ids 209 | 210 | def validate(model, val_loader): 211 | """ 212 | Validate the model on the validation dataset. 213 | 214 | Args: 215 | model (nn.Module): The model to validate. 216 | val_loader (DataLoader): DataLoader for the validation dataset. 217 | 218 | Returns: 219 | float: The average validation loss. 220 | 221 | This function sets the model to evaluation mode, performs a forward pass 222 | on the validation data without gradient computation, calculates the loss, 223 | and returns the average validation loss across all batches. 224 | """ 225 | model.eval() 226 | total_val_loss = 0 227 | with torch.no_grad(): 228 | for batch in tqdm(val_loader, desc="Validating"): 229 | inputs, labels = batch 230 | outputs = model(**inputs, labels=labels) 231 | loss = outputs.loss 232 | total_val_loss += loss.item() 233 | 234 | avg_val_loss = total_val_loss / len(val_loader) 235 | model.train() 236 | return avg_val_loss 237 | 238 | def train_and_validate( 239 | model_name, 240 | output_dir, 241 | dataset_name, 242 | image_column, 243 | text_column, 244 | device="cuda", 245 | user_text="Convert this image to text", 246 | min_pixel=256, 247 | max_pixel=384, 248 | image_factor=28, 249 | num_accumulation_steps=2, 250 | eval_steps=10000, 251 | max_steps=100000, 252 | train_select_start=0, 253 | train_select_end=1000, 254 | val_select_start=0, 255 | val_select_end=1000, 256 | train_batch_size=1, 257 | val_batch_size=1, 258 | train_field="train", 259 | val_field="validation" 260 | ): 261 | """ 262 | Train and validate a Qwen2VL model on a specified dataset. 
263 | 264 | Args: 265 | model_name (str): Name of the pre-trained model to use. 266 | output_dir (str): Directory to save the trained model. 267 | dataset_name (str): Name of the dataset to use for training and validation. 268 | image_column (str): Name of the column containing image data in the dataset. 269 | text_column (str): Name of the column containing text data in the dataset. 270 | device (str): Device to use for training ('cuda' or 'cpu'). 271 | user_text (str): Default text prompt for the user input. 272 | min_pixel (int): Minimum pixel size for image processing. 273 | max_pixel (int): Maximum pixel size for image processing. 274 | image_factor (int): Factor for image size calculation. 275 | num_accumulation_steps (int): Number of steps for gradient accumulation. 276 | eval_steps (int): Number of steps between evaluations. 277 | max_steps (int): Maximum number of training steps. 278 | train_select_start (int): Starting index for selecting training data. 279 | train_select_end (int): Ending index for selecting training data. 280 | val_select_start (int): Starting index for selecting validation data. 281 | val_select_end (int): Ending index for selecting validation data. 282 | train_batch_size (int): Batch size for training. 283 | val_batch_size (int): Batch size for validation. 284 | train_field (str): Field name for training data in the dataset. 285 | val_field (str): Field name for validation data in the dataset. 286 | 287 | Returns: 288 | None 289 | """ 290 | model = Qwen2VLForConditionalGeneration.from_pretrained( 291 | model_name, torch_dtype=torch.bfloat16, 292 | device_map=device 293 | ) 294 | 295 | processor = AutoProcessor.from_pretrained(model_name, min_pixels=min_pixel*image_factor*image_factor, max_pixels=max_pixel*image_factor*image_factor, padding_side="right") 296 | 297 | # Load and split the dataset 298 | dataset = load_dataset(dataset_name) 299 | train_dataset = dataset[train_field].shuffle(seed=42).select(range(train_select_start, train_select_end)) 300 | val_dataset = dataset[val_field].shuffle(seed=42).select(range(val_select_start, val_select_end)) 301 | 302 | train_dataset = HuggingFaceDataset(train_dataset, image_column, text_column, user_text) 303 | val_dataset = HuggingFaceDataset(val_dataset, image_column, text_column, user_text) 304 | 305 | train_loader = DataLoader( 306 | train_dataset, 307 | batch_size=train_batch_size, 308 | collate_fn=partial(collate_fn, processor=processor, device=device), 309 | shuffle=True 310 | ) 311 | 312 | val_loader = DataLoader( 313 | val_dataset, 314 | batch_size=val_batch_size, 315 | collate_fn=partial(collate_fn, processor=processor, device=device) 316 | ) 317 | 318 | model.train() 319 | optimizer = AdamW(model.parameters(), lr=1e-5) 320 | 321 | global_step = 0 322 | 323 | progress_bar = tqdm(total=max_steps, desc="Training") 324 | 325 | while global_step < max_steps: 326 | for batch in train_loader: 327 | global_step += 1 328 | inputs, labels = batch 329 | outputs = model(**inputs, labels=labels) 330 | 331 | loss = outputs.loss / num_accumulation_steps 332 | loss.backward() 333 | 334 | if global_step % num_accumulation_steps == 0: 335 | optimizer.step() 336 | optimizer.zero_grad() 337 | 338 | progress_bar.update(1) 339 | progress_bar.set_postfix({"loss": loss.item() * num_accumulation_steps}) 340 | 341 | # Perform evaluation and save model every EVAL_STEPS 342 | if global_step % eval_steps == 0 or global_step == max_steps: 343 | avg_val_loss = validate(model, val_loader) 344 | 345 | # Save the model and processor 346 | 
save_dir = os.path.join(output_dir, f"model_step_{global_step}") 347 | os.makedirs(save_dir, exist_ok=True) 348 | model.save_pretrained(save_dir) 349 | processor.save_pretrained(save_dir) 350 | 351 | model.train() # Set the model back to training mode 352 | 353 | if global_step >= max_steps: 354 | save_dir = os.path.join(output_dir, "final") 355 | model.save_pretrained(save_dir) 356 | processor.save_pretrained(save_dir) 357 | break 358 | 359 | progress_bar.close() --------------------------------------------------------------------------------