├── images ├── gpt4o-01.JPG ├── gpt4o-02.JPG ├── gpt4o-03.JPG ├── images_llava_qwen_2b_chinese-clip.jpg ├── llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16 │ ├── 1.PNG │ ├── 2.PNG │ ├── 3.PNG │ ├── 4.PNG │ ├── 5.PNG │ ├── 6.PNG │ ├── 7.PNG │ ├── 8.PNG │ └── 9.PNG ├── llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16 │ ├── llava-qwen2-7b-openaiclipL14-336-fp16-01.PNG │ ├── llava-qwen2-7b-openaiclipL14-336-fp16-02.PNG │ ├── llava-qwen2-7b-openaiclipL14-336-fp16-03.PNG │ ├── llava-qwen2-7b-openaiclipL14-336-fp16-04.PNG │ ├── llava-qwen2-7b-openaiclipL14-336-fp16-05.PNG │ └── llava-qwen2-7b-openaiclipL14-336-fp16-06.PNG └── llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16 │ ├── llava-qwen2-7b-OFA-Syschinese-clip-fp16-01.PNG │ ├── llava-qwen2-7b-OFA-Syschinese-clip-fp16-02.PNG │ ├── llava-qwen2-7b-OFA-Syschinese-clip-fp16-03.PNG │ ├── llava-qwen2-7b-OFA-Syschinese-clip-fp16-04.PNG │ ├── llava-qwen2-7b-OFA-Syschinese-clip-fp16-05.PNG │ └── llava-qwen2-7b-OFA-Syschinese-clip-fp16-06.PNG ├── llava-Qwen2-7B-Instruct-Chinese-CLIP训练手册.xlsx ├── train_llava ├── custom_trainer.py ├── util.py ├── data.py └── data_websend.py ├── ds_z2_config.json ├── merge_lora.py ├── web_dataset_backend.py ├── run.sh ├── infer.py ├── run.py ├── README.md └── LICENSE /images/gpt4o-01.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/gpt4o-01.JPG -------------------------------------------------------------------------------- /images/gpt4o-02.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/gpt4o-02.JPG -------------------------------------------------------------------------------- /images/gpt4o-03.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/gpt4o-03.JPG -------------------------------------------------------------------------------- /images/images_llava_qwen_2b_chinese-clip.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/images_llava_qwen_2b_chinese-clip.jpg -------------------------------------------------------------------------------- /llava-Qwen2-7B-Instruct-Chinese-CLIP训练手册.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/llava-Qwen2-7B-Instruct-Chinese-CLIP训练手册.xlsx -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/1.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/2.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/2.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/3.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/4.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/4.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/5.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/5.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/6.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/6.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/7.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/7.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/8.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/8.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/9.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-chineseOCR_pri_fly_SWH_memechinese_lora_0716_warmup0_1_fp16/9.PNG -------------------------------------------------------------------------------- 
/images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-01.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-01.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-02.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-02.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-03.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-03.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-04.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-04.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-05.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-05.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-06.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-openaiclip-memechinesebqb_merged_0709_fp16/llava-qwen2-7b-openaiclipL14-336-fp16-06.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-01.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-01.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-02.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-02.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-03.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-03.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-04.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-04.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-05.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-05.PNG -------------------------------------------------------------------------------- /images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-06.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reilxlx/llava-Qwen2-7B-Instruct-Chinese-CLIP/HEAD/images/llava-qwen-2-7b-OFA-Syschinese-clip-memechinesebqb_merged_0708_fp16/llava-qwen2-7b-OFA-Syschinese-clip-fp16-06.PNG -------------------------------------------------------------------------------- /train_llava/custom_trainer.py: -------------------------------------------------------------------------------- 1 | from transformers import Trainer 2 | from transformers.trainer_pt_utils import ShardSampler 3 | from typing import Optional 4 | import torch 5 | from transformers.trainer_utils import has_length 6 | 7 | 8 | class WebTrainer(Trainer): 9 | 10 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 11 | if self.train_dataset is None or not has_length(self.train_dataset): 12 | return None 13 | 14 | return ShardSampler(self.train_dataset) 15 | -------------------------------------------------------------------------------- /ds_z2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 
| "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } -------------------------------------------------------------------------------- /merge_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import PeftModel, LoraConfig 3 | from transformers import LlavaForConditionalGeneration 4 | model_name = "/替换为你的基础模型路径" 5 | LORA_R = 32 6 | LORA_ALPHA = 64 7 | LORA_DROPOUT = 0.05 8 | TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] 9 | lora_config = LoraConfig( 10 | r=LORA_R, 11 | lora_alpha=LORA_ALPHA, 12 | target_modules=TARGET_MODULES, 13 | lora_dropout=LORA_DROPOUT, 14 | bias="none", 15 | task_type="CAUSAL_LM", 16 | modules_to_save=["multi_modal_projector"], 17 | ) 18 | model = LlavaForConditionalGeneration.from_pretrained(model_name) 19 | model = PeftModel.from_pretrained(model, "/替换为你的lora模型路径", config=lora_config, adapter_name='lora') 20 | 21 | model.cpu() 22 | model.eval() 23 | base_model = model.get_base_model() 24 | base_model.eval() 25 | model.merge_and_unload() 26 | 27 | base_model.save_pretrained("/保存的完整模型路径") -------------------------------------------------------------------------------- /web_dataset_backend.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any 2 | 3 | import pandas as pd 4 | import uvicorn 5 | from fastapi import FastAPI, HTTPException 6 | from fastapi.middleware.cors import CORSMiddleware 7 | 8 | from train_llava.data_websend import SendDatasetByWeb 9 | 10 | 11 | webdatasetsend = SendDatasetByWeb( 12 | model_name_or_path="test_model/model001", 13 | dataset_dir="data/liuhaotian/LLaVA-CC3M-Pretrain-595K", 14 | cache_dir="data/cache_data", 15 | num_proc=10 16 | ) 17 | 18 | 19 | app = FastAPI() 20 | 21 | origins = ["*"] 22 | 23 | app.add_middleware( 24 | CORSMiddleware, 25 | allow_origins=origins, 26 | allow_credentials=True, 27 | allow_methods=["*"], 28 | allow_headers=["*"], 29 | ) 30 | 31 | 32 | @app.get("/len") 33 | async def get_len(): 34 | return len(webdatasetsend) 35 | 36 | 37 | @app.get("/slice") 38 | async def get_slice(index: int) -> dict[Any, Any]: 39 | return webdatasetsend[index] 40 | 41 | 42 | if __name__ == "__main__": 43 | uvicorn.run(app=app, host="0.0.0.0", port=7001, reload=False) 44 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | deepspeed --include localhost:0,1,2,3,4,5,6,7 /run.py \ 2 | --deepspeed ds_z2_config.json \ 3 | --model_name_or_path /home/llava_model/llava-qwen2-7b-OFA-Syschinese-clip/ \ 4 | --train_type use_lora \ 5 | --data_path /home/data/meme_chineseBRB \ 6 | --json_name meme_chineseBQB.json \ 7 | --image_folder images \ 8 | --gpu_nums 8 \ 9 | --lora_r 32 \ 10 | --lora_alpha 64 \ 11 | --remove_unused_columns False \ 12 | --web_host_ip "0.0.0.0" \ 13 | --build_data_from_web False \ 14 | --bf16 True \ 15 | --tf32 True \ 16 | --output_dir /home/lora/llava-qwen-2-7b-clip_original-memechinesebrb_lora_0709_warmup0_1/ \ 17 | --num_train_epochs 5 \ 18 | --per_device_train_batch_size 1 \ 19 | --per_device_eval_batch_size 1 \ 20 | --gradient_accumulation_steps 8 \ 21 | --warmup_ratio 0.1 \ 22 | --evaluation_strategy "no" \ 23 | --save_strategy "epoch" \ 24 | --save_total_limit 1 \ 25 | --report_to none \ 26 | --high_lr 
1e-3 \ 27 | --low_lr 2e-5 \ 28 | --logging_steps 1 \ 29 | --model_max_length 2048 30 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | from transformers import LlavaForConditionalGeneration, AutoProcessor 2 | import torch 3 | from PIL import Image 4 | 5 | raw_model_name_or_path = "/保存的完整模型路径" 6 | model = LlavaForConditionalGeneration.from_pretrained(raw_model_name_or_path, device_map="cuda:0", torch_dtype=torch.bfloat16) 7 | processor = AutoProcessor.from_pretrained(raw_model_name_or_path) 8 | model.eval() 9 | 10 | def build_model_input(model, processor): 11 | messages = [ 12 | {"role": "system", "content": "You are a helpful assistant."}, 13 | {"role": "user", "content": "\n 使用中文描述图片中的信息"} 14 | ] 15 | prompt = processor.tokenizer.apply_chat_template( 16 | messages, tokenize=False, add_generation_prompt=True 17 | ) 18 | image = Image.open("01.PNG") 19 | inputs = processor(text=prompt, images=image, return_tensors="pt", return_token_type_ids=False) 20 | 21 | for tk in inputs.keys(): 22 | inputs[tk] = inputs[tk].to(model.device) 23 | generate_ids = model.generate(**inputs, max_new_tokens=200) 24 | 25 | generate_ids = [ 26 | oid[len(iids):] for oid, iids in zip(generate_ids, inputs.input_ids) 27 | ] 28 | gen_text = processor.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0] 29 | return gen_text 30 | build_model_input(model, processor) -------------------------------------------------------------------------------- /train_llava/util.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | # copy code from https://github.com/huggingface/peft/blob/2f5360a7da22a236b5ad4c059572fff5321c867c/src/peft/peft_model.py#L617 4 | def get_nb_trainable_parameters(model:nn.Module) -> tuple[int, int]: 5 | r""" 6 | Returns the number of trainable parameters and the number of all parameters in the model. 7 | """ 8 | trainable_params = 0 9 | all_param = 0 10 | for _, param in model.named_parameters(): 11 | num_params = param.numel() 12 | # if using DS Zero 3 and the weights are initialized empty 13 | if num_params == 0 and hasattr(param, "ds_numel"): 14 | num_params = param.ds_numel 15 | 16 | # Due to the design of 4bit linear layers from bitsandbytes 17 | # one needs to multiply the number of parameters by 2 to get 18 | # the correct number of parameters 19 | if param.__class__.__name__ == "Params4bit": 20 | if hasattr(param, "element_size"): 21 | num_bytes = param.element_size() 22 | elif not hasattr(param, "quant_storage"): 23 | num_bytes = 1 24 | else: 25 | num_bytes = param.quant_storage.itemsize 26 | num_params = num_params * 2 * num_bytes 27 | 28 | all_param += num_params 29 | if param.requires_grad: 30 | trainable_params += num_params 31 | 32 | return trainable_params, all_param 33 | 34 | 35 | # copy code from https://github.com/huggingface/peft/blob/2f5360a7da22a236b5ad4c059572fff5321c867c/src/peft/peft_model.py#L647 36 | def print_trainable_parameters(model: nn.Module) -> None: 37 | """ 38 | Prints the number of trainable parameters in the model. 39 | 40 | Note: print_trainable_parameters() uses get_nb_trainable_parameters() which is different from 41 | num_parameters(only_trainable=True) from huggingface/transformers. get_nb_trainable_parameters() returns 42 | (trainable parameters, all parameters) of the Peft Model which includes modified backbone transformer model. 
43 | For techniques like LoRA, the backbone transformer model is modified in place with LoRA modules. However, for 44 | prompt tuning, the backbone transformer model is unmodified. num_parameters(only_trainable=True) returns number 45 | of trainable parameters of the backbone transformer model which can be different. 46 | """ 47 | trainable_params, all_param = get_nb_trainable_parameters(model) 48 | 49 | print( 50 | f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}" 51 | ) 52 | 53 | 54 | def print_trainable_parameters_name(model: nn.Module) -> None: 55 | """ 56 | Prints the number of trainable parameters in the model. 57 | 58 | Note: print_trainable_parameters() uses get_nb_trainable_parameters() which is different from 59 | num_parameters(only_trainable=True) from huggingface/transformers. get_nb_trainable_parameters() returns 60 | (trainable parameters, all parameters) of the Peft Model which includes modified backbone transformer model. 61 | For techniques like LoRA, the backbone transformer model is modified in place with LoRA modules. However, for 62 | prompt tuning, the backbone transformer model is unmodified. num_parameters(only_trainable=True) returns number 63 | of trainable parameters of the backbone transformer model which can be different. 64 | """ 65 | trainable_params, all_param = get_nb_trainable_parameters(model) 66 | 67 | print( 68 | f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}" 69 | ) 70 | print("\nTrainable parameter names:") 71 | for name, param in model.named_parameters(): 72 | if param.requires_grad: 73 | print(name) 74 | 75 | -------------------------------------------------------------------------------- /train_llava/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from glob import glob 4 | from pathlib import Path 5 | from typing import Dict, List, Tuple 6 | 7 | import pandas as pd 8 | import torch 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | from transformers import AutoProcessor 12 | import os 13 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 14 | 15 | 16 | @dataclass 17 | class QaImageOutput: 18 | q_input_ids: torch.long 19 | pixel_values: torch.long 20 | a_input_ids: torch.long 21 | 22 | 23 | class LlavaDataset(Dataset): 24 | def __init__(self, data_path: str, json_name: str, image_folder: str) -> None: 25 | super().__init__() 26 | 27 | self.chat_data, self.image_dir = self.build_dataset(data_path, json_name, image_folder) 28 | 29 | def build_dataset(self, data_path: str, json_name: str, image_folder: str) -> Tuple[List[Dict], Path]: 30 | data_dir = Path(data_path) 31 | chat_file = data_dir.joinpath(json_name) 32 | image_dir = data_dir.joinpath(image_folder) 33 | 34 | chat_data = pd.read_json(chat_file).to_dict(orient="records") 35 | 36 | return chat_data, image_dir 37 | 38 | def __len__(self): 39 | return len(self.chat_data) 40 | 41 | def __getitem__(self, index) -> Tuple[str, str, Path]: 42 | cur_data = self.chat_data[index] 43 | conversations = cur_data.get("conversations") 44 | 45 | human_input = conversations[0].get("value") 46 | chatbot_output = conversations[1].get("value") 47 | 48 | image_path = self.image_dir.joinpath(cur_data.get("image")) 49 | return human_input, chatbot_output, image_path 50 | 51 | 52 | def build_qaimage( 53 | processor: AutoProcessor, q_text: 
str, a_text: str, image_path: Path, model_max_length: int 54 | ): 55 | messages = [ 56 | {"role": "system", "content": "You are a helpful assistant."}, 57 | {"role": "user", "content": q_text}, 58 | ] 59 | prompt = processor.tokenizer.apply_chat_template( 60 | messages, tokenize=False, add_generation_prompt=True 61 | ) 62 | image_file = image_path 63 | raw_image = Image.open(image_file).convert("RGB") 64 | inputs = processor(prompt, raw_image, return_tensors="pt") 65 | 66 | a_input_ids = processor.tokenizer( 67 | a_text, 68 | return_tensors="pt", 69 | padding="longest", 70 | truncation=True, 71 | max_length=model_max_length, 72 | )["input_ids"].long() 73 | 74 | q_input_ids = inputs.get("input_ids")[:, :model_max_length] 75 | pixel_values = inputs.get("pixel_values") 76 | 77 | res = QaImageOutput( 78 | q_input_ids=q_input_ids, 79 | pixel_values=pixel_values, 80 | a_input_ids=a_input_ids, 81 | ) 82 | return res 83 | 84 | 85 | class TrainLLavaModelCollator: 86 | def __init__(self, processor: AutoProcessor, IGNORE_INDEX: int, model_max_length: int) -> None: 87 | self.processor = processor 88 | self.ingnore_index = IGNORE_INDEX 89 | self.model_max_length = model_max_length 90 | 91 | def convert_one_piece( 92 | self, 93 | q_input_ids: torch.long, 94 | a_input_ids: torch.long, 95 | ): 96 | input_ids = torch.concat( 97 | [ 98 | q_input_ids, 99 | a_input_ids, 100 | torch.tensor(self.processor.tokenizer.eos_token_id, dtype=torch.long).reshape(1, -1), 101 | ], 102 | axis=1, 103 | ) 104 | labels = torch.concat( 105 | [ 106 | torch.full(q_input_ids.shape, self.ingnore_index, dtype=torch.long), 107 | a_input_ids, 108 | torch.tensor(self.processor.tokenizer.eos_token_id, dtype=torch.long).reshape(1, -1), 109 | ], 110 | axis=1, 111 | ) 112 | input_ids = input_ids[:, : self.model_max_length] 113 | labels = labels[:, : self.model_max_length] 114 | 115 | return input_ids, labels 116 | 117 | def __call__(self, features: List) -> Dict[str, torch.Tensor]: 118 | input_ids_list = [] 119 | labels_list = [] 120 | pixel_values = [] 121 | max_input_len_list = [] 122 | image_paths = [] 123 | 124 | for feature in features: 125 | qaimage_output = build_qaimage( 126 | self.processor, feature[0], feature[1], feature[2], self.model_max_length 127 | ) 128 | temp_input_ids, temp_labels = self.convert_one_piece( 129 | qaimage_output.q_input_ids, qaimage_output.a_input_ids 130 | ) 131 | max_input_len_list.append(temp_input_ids.shape[1]) 132 | input_ids_list.append(temp_input_ids) 133 | labels_list.append(temp_labels) 134 | pixel_values.append(qaimage_output.pixel_values) 135 | image_paths.append(feature[2]) 136 | 137 | final_input_ids = torch.concat( 138 | [ 139 | torch.concat( 140 | [ 141 | torch.full( 142 | (1, self.model_max_length - max_input_len_list[index]), 143 | self.processor.tokenizer.pad_token_id,dtype=torch.long 144 | ), 145 | value, 146 | ], 147 | axis=1, 148 | ) 149 | for index, value in enumerate(input_ids_list) 150 | ] 151 | ) 152 | final_labels = torch.concat( 153 | [ 154 | torch.concat( 155 | [ 156 | torch.full( 157 | (1, self.model_max_length - max_input_len_list[index]), 158 | self.ingnore_index, dtype=torch.long 159 | ), 160 | value, 161 | ], 162 | axis=1, 163 | ) 164 | for index, value in enumerate(labels_list) 165 | ] 166 | ) 167 | final_pixel_values = torch.concat(pixel_values, axis=0) 168 | attention_mask = torch.ones_like(final_input_ids) 169 | attention_mask[final_input_ids.long() == self.processor.tokenizer.pad_token_id] = 0 170 | return { 171 | "input_ids": final_input_ids, 172 | "labels": 
final_labels, 173 | "pixel_values": final_pixel_values, 174 | "attention_mask": attention_mask, 175 | } 176 | 177 | 178 | if __name__ == "__main__": 179 | data_path = "/home/models/weight/data/liuhaotianLLaVA-Pretrain" 180 | json_name = "blip_laion_cc_sbu_558k_cleaned_v2.json" 181 | image_folder = "images" 182 | 183 | llavadataset = LlavaDataset(data_path, json_name, image_folder) 184 | print(len(llavadataset)) 185 | print(llavadataset[100]) 186 | -------------------------------------------------------------------------------- /train_llava/data_websend.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from pathlib import Path 3 | from typing import Any, Dict, List 4 | 5 | from datasets import load_dataset 6 | from transformers import LlavaProcessor, AutoProcessor 7 | 8 | from .data import build_qaimage, TrainLLavaModelCollator, QaImageOutput 9 | import requests 10 | import torch 11 | import random 12 | 13 | 14 | def preprocess_sub_data(dataset_dir, examples): 15 | image_path = examples["image"] 16 | 17 | image_path = str(dataset_dir.joinpath("images_dl").joinpath(image_path)) 18 | 19 | conversations = [i for i in examples["conversations"]] 20 | human_input = conversations[0].get("value") 21 | chatbot_output = conversations[1].get("value") 22 | 23 | examples["image_path"] = image_path 24 | examples["human_input"] = human_input 25 | examples["chatbot_output"] = chatbot_output 26 | 27 | return examples 28 | 29 | 30 | def preprocess_convert2vector(model_processor, model_max_length, examples): 31 | result = build_qaimage( 32 | model_processor, 33 | examples["human_input"], 34 | examples["chatbot_output"], 35 | examples["image_path"], model_max_length, 36 | ) 37 | examples["q_input_ids"] = result.q_input_ids 38 | examples["pixel_values"] = result.pixel_values 39 | examples["a_input_ids"] = result.a_input_ids 40 | 41 | return examples 42 | 43 | 44 | class SendDatasetByWeb: 45 | def __init__( 46 | self, model_name_or_path: str, dataset_dir: str, cache_dir: str, num_proc: int, model_max_length: int = 2048 47 | ) -> None: 48 | self.model_max_length = model_max_length  # build_qaimage requires a max length; the 2048 default mirrors run.sh, adjust to your training config 49 | # load the default processor 50 | self.model_processor = LlavaProcessor.from_pretrained(model_name_or_path) 51 | 52 | # resolve the dataset directory 53 | self.dataset_dir = Path(dataset_dir) 54 | 55 | # build the preprocessed dataset 56 | self.rawdataset_sub_data1 = self.build_dataset(cache_dir, num_proc) 57 | 58 | # build a random index map (so the data does not need to be shuffled again during training) 59 | self.random_map = self.build_random_idmap() 60 | 61 | def build_random_idmap(self): 62 | random.seed(42) 63 | data_size = len(self) 64 | random_id_list = random.choices(range(data_size), k=data_size) 65 | random_map = {k: v for k, v in enumerate(random_id_list)} 66 | return random_map 67 | 68 | 69 | 70 | def build_dataset(self, cache_dir, num_proc): 71 | rawdataset = load_dataset( 72 | "json", 73 | data_files={"train": [str(self.dataset_dir.joinpath("chat.json"))]}, 74 | cache_dir=cache_dir, # "data/cache_data", 75 | ) 76 | 77 | preprocess_convert2vector_partial = partial( 78 | preprocess_convert2vector, self.model_processor, self.model_max_length 79 | ) 80 | 81 | preprocess_sub_data_partial = partial(preprocess_sub_data, self.dataset_dir) 82 | 83 | rawdataset = rawdataset["train"] # .select(range(10)) 84 | 85 | rawdataset_sub_data1 = rawdataset.map( 86 | function=preprocess_sub_data_partial, num_proc=num_proc, batch_size=3 87 | ) 88 | 89 | rawdataset_sub_data1 = rawdataset_sub_data1.map( # chain the second map onto the result of the first 90 | function=preprocess_convert2vector_partial, num_proc=num_proc, batch_size=3 91 | ) 92 | return rawdataset_sub_data1 93 | 94 | def __len__(self) -> int: 95 | return len(self.rawdataset_sub_data1) 96 | 97 |
def __getitem__(self, index) -> Dict: 98 | return self.rawdataset_sub_data1[index] 99 | 100 | 101 | class DatasetReceiveByWeb: 102 | def __init__(self, host_ip: str = "0.0.0.0"): 103 | self.host_ip = host_ip 104 | 105 | def __len__(self): 106 | return self.get_len_from_web(self.host_ip) 107 | 108 | def __getitem__(self, index): 109 | data = self.get_slice_from_web(index, self.host_ip) 110 | return data 111 | 112 | def get_slice_from_web(self, index: int, host: str = "0.0.0.0"): # 113 | 114 | web = requests.get(url=f"http://{host}:7001/slice", params={"index": index}) 115 | json_data = web.json() 116 | return json_data 117 | 118 | def get_len_from_web(self, host: str = "0.0.0.0"): 119 | web = requests.get(url=f"http://{host}:7001/len") 120 | # json_data = web.json() 121 | return web.json() 122 | 123 | 124 | class TrainLlavaModelCollatorByWeb(TrainLLavaModelCollator): 125 | def __call__(self, features: List) -> Dict[str, torch.Tensor]: 126 | input_ids_list = [] 127 | labels_list = [] 128 | pixel_values = [] 129 | max_input_len_list = [] 130 | 131 | for feature in features: 132 | qaimage_output = QaImageOutput( 133 | q_input_ids=torch.tensor(feature["q_input_ids"]), 134 | pixel_values=torch.tensor(feature["pixel_values"]), 135 | a_input_ids=torch.tensor(feature["a_input_ids"]), 136 | ) 137 | 138 | # build_qaimage( 139 | # self.processor, feature[0], feature[1], feature[2] 140 | # ) 141 | temp_input_ids, temp_labels = self.convert_one_piece( 142 | qaimage_output.q_input_ids, qaimage_output.a_input_ids 143 | ) 144 | max_input_len_list.append(temp_input_ids.shape[1]) 145 | input_ids_list.append(temp_input_ids) 146 | labels_list.append(temp_labels) 147 | pixel_values.append(qaimage_output.pixel_values) 148 | 149 | max_input_len = max(max_input_len_list) 150 | 151 | final_input_ids = torch.concat( 152 | [ 153 | torch.concat( 154 | [ 155 | torch.full( 156 | (1, max_input_len - max_input_len_list[index]), 157 | self.processor.tokenizer.pad_token_id, 158 | ), 159 | value, 160 | ], 161 | axis=1, 162 | ) 163 | for index, value in enumerate(input_ids_list) 164 | ] 165 | ) 166 | final_labels = torch.concat( 167 | [ 168 | torch.concat( 169 | [ 170 | torch.full( 171 | (1, max_input_len - max_input_len_list[index]), 172 | self.ingnore_index, 173 | ), 174 | value, 175 | ], 176 | axis=1, 177 | ) 178 | for index, value in enumerate(labels_list) 179 | ] 180 | ) 181 | final_pixel_values = torch.concat(pixel_values, axis=0) 182 | attention_mask = torch.ones_like(final_input_ids) 183 | attention_mask[final_input_ids == self.processor.tokenizer.pad_token_id] = 0 184 | return { 185 | "input_ids": final_input_ids, 186 | "labels": final_labels, 187 | "pixel_values": final_pixel_values, 188 | "attention_mask": attention_mask, 189 | } 190 | 191 | 192 | __all__ = ["TrainLlavaModelCollatorByWeb", "SendDatasetByWeb", "DatasetReceiveByWeb"] 193 | 194 | if __name__ == "__main__": 195 | from train_llava.data_websend import ( 196 | DatasetReceiveByWeb, 197 | TrainLlavaModelCollatorByWeb, 198 | ) 199 | from transformers import LlavaProcessor 200 | 201 | processor = LlavaProcessor.from_pretrained("test_model/model001") 202 | 203 | web_dataset = DatasetReceiveByWeb("10.136.0.65") 204 | len(web_dataset) 205 | 206 | tlmcw = TrainLlavaModelCollatorByWeb(processor, -100) 207 | result = tlmcw([web_dataset[0], web_dataset[1]]) 208 | print(result) 209 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import copy 
2 | import logging 3 | import os 4 | from dataclasses import dataclass, field 5 | from functools import partial 6 | from typing import Dict, List, Optional, Sequence 7 | 8 | import torch 9 | import transformers 10 | from datasets import load_dataset 11 | from torch.utils.data import Dataset 12 | from tqdm import tqdm 13 | from transformers import ( 14 | AutoProcessor, 15 | DataCollatorForSeq2Seq, 16 | LlavaForConditionalGeneration, 17 | LlavaProcessor, 18 | Trainer, 19 | TrainingArguments, 20 | get_linear_schedule_with_warmup, 21 | ) 22 | 23 | from train_llava.custom_trainer import WebTrainer 24 | from train_llava.data import LlavaDataset, TrainLLavaModelCollator 25 | from train_llava.data_websend import DatasetReceiveByWeb, TrainLlavaModelCollatorByWeb 26 | from train_llava.util import print_trainable_parameters, print_trainable_parameters_name 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | @dataclass 31 | class ModelArguments: 32 | model_name_or_path: Optional[str] = field(default="test_model/model001") 33 | train_type: Optional[str] = field( 34 | default="none", 35 | metadata={ 36 | "help": """ 37 | 1. use_lora:使用lora训练, 38 | 2. none:全量参数训练; 39 | 3. freeze_vision:只冻结vision_tower进行训练 40 | """ 41 | }, 42 | ) 43 | lora_r: int = field(default=8) 44 | lora_alpha: int = field(default=16) 45 | low_lr: Optional[float] = field( 46 | default=2e-5, 47 | metadata={ 48 | "help": "使用lora进行训练模型的lr" 49 | } 50 | ) 51 | high_lr: Optional[float] = field( 52 | default=1e-3,metadata={ 53 | "help": "multi_modal_projector层训练的lr" 54 | } 55 | ) 56 | 57 | 58 | @dataclass 59 | class DataArguments: 60 | build_data_from_web: bool = field( 61 | default=False, metadata={"help": "是否使用web获得数据"} 62 | ) 63 | data_path: str = field( 64 | default=None, metadata={"help": "Path to the training data."} 65 | ) 66 | web_host_ip: str = field(default="0.0.0.0", metadata={"help": "web端的数据ip"}) 67 | model_max_length: int = field(default=4096) 68 | json_name: str = field( 69 | default=None, metadata={"help": "Path to the training data json path."} 70 | ) 71 | image_folder: str = field( 72 | default=None, metadata={"help": "Path to the training data images path."} 73 | ) 74 | gpu_nums: int = field( 75 | default=8, metadata={"help": "Number of GPUs to use for training."} 76 | ) 77 | 78 | @dataclass 79 | class CustomTrainingArguments(TrainingArguments): 80 | gradient_accumulation_steps: int = field( 81 | default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."} 82 | ) 83 | warmup_steps: int = field( 84 | default=190, 85 | metadata={"help": "Number of warmup steps for the learning rate scheduler."} 86 | ) 87 | warmup_ratio: float = field( 88 | default=0.1, 89 | metadata={"help": "Warmup ratio for the learning rate scheduler."} 90 | ) 91 | learning_rate: float = field( 92 | default=5e-5, metadata={"help": "The initial learning rate for the optimizer."} 93 | ) 94 | 95 | def load_model_processor(modelargs: ModelArguments): 96 | model = LlavaForConditionalGeneration.from_pretrained( 97 | modelargs.model_name_or_path, 98 | torch_dtype=torch.bfloat16, 99 | # low_cpu_mem_usage=True, 100 | ) 101 | processor = LlavaProcessor.from_pretrained(modelargs.model_name_or_path) 102 | optimizer_grouped_parameters = None  # stays None unless use_lora builds the two-lr parameter groups below 103 | if modelargs.train_type == "use_lora": 104 | logging.warning("Loading model to Lora") 105 | 106 | from peft import LoraConfig, get_peft_model 107 | 108 | LORA_R = modelargs.lora_r 109 | LORA_ALPHA = modelargs.lora_alpha 110 | LORA_DROPOUT = 0.05 111 | TARGET_MODULES = ["q_proj", "k_proj",
"v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] 112 | 113 | config = LoraConfig( 114 | r=LORA_R, 115 | lora_alpha=LORA_ALPHA, 116 | target_modules=TARGET_MODULES, 117 | lora_dropout=LORA_DROPOUT, 118 | bias="none", 119 | task_type="CAUSAL_LM", 120 | modules_to_save=["multi_modal_projector"], 121 | ) 122 | model = get_peft_model(model, config) 123 | high_lr = modelargs.high_lr 124 | low_lr = modelargs.low_lr 125 | optimizer_grouped_parameters = [ 126 | { 127 | "params": [ 128 | p 129 | for n, p in model.named_parameters() 130 | if any(nd in n for nd in config.modules_to_save) 131 | ], 132 | "lr": high_lr, 133 | }, 134 | { 135 | "params": [ 136 | p 137 | for n, p in model.named_parameters() 138 | if not any(nd in n for nd in config.modules_to_save) 139 | ], 140 | "lr": low_lr, 141 | }, 142 | ] 143 | 144 | elif modelargs.train_type == "none": 145 | logging.warning("使用全量参数进行训练") 146 | pass 147 | 148 | elif modelargs.train_type == "freeze_vision": 149 | logging.warning("冻结vision_tower网络层,剩下的网络权重进行训练") 150 | 151 | for param in model.vision_tower.parameters(): 152 | param.requires_grad = False 153 | 154 | print_trainable_parameters(model) 155 | return model, processor, optimizer_grouped_parameters 156 | 157 | 158 | def load_dataset_collator(processor, data_args: DataArguments): 159 | if data_args.build_data_from_web: 160 | llava_dataset = DatasetReceiveByWeb( 161 | data_args.web_host_ip, 162 | ) 163 | logging.warning("从网络层进行数据初始化") 164 | 165 | if len(llava_dataset) <= 0: 166 | raise ValueError("数据出现问题,无法进行web数据初始化") 167 | data_collator = TrainLlavaModelCollatorByWeb(processor, -100, model_max_length=data_args.model_max_length) 168 | else: 169 | 170 | llava_dataset = LlavaDataset( 171 | data_path = data_args.data_path, 172 | json_name = data_args.json_name, 173 | image_folder = data_args.image_folder 174 | ) 175 | data_collator = TrainLLavaModelCollator(processor, -100, model_max_length=data_args.model_max_length) 176 | 177 | return llava_dataset, data_collator 178 | 179 | class CustomTrainer(Trainer): 180 | def __init__(self, *args, **kwargs): 181 | self.optimizer_grouped_parameters = kwargs.pop("optimizer_grouped_parameters", None) 182 | super().__init__(*args, **kwargs) 183 | 184 | def create_optimizer_and_scheduler(self, num_training_steps): 185 | if self.optimizer_grouped_parameters is not None: 186 | if self.args.warmup_ratio is not None: 187 | warmup_steps = int(self.args.warmup_ratio * num_training_steps) 188 | else: 189 | warmup_steps = self.args.warmup_steps 190 | 191 | self.optimizer = torch.optim.AdamW(self.optimizer_grouped_parameters, lr=self.args.learning_rate) 192 | self.lr_scheduler = get_linear_schedule_with_warmup( 193 | self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps 194 | ) 195 | else: 196 | super().create_optimizer_and_scheduler(num_training_steps) 197 | 198 | def training_step(self, model: torch.nn.Module, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: 199 | loss = super().training_step(model, inputs) 200 | high_lr = self.optimizer.param_groups[0]['lr'] 201 | low_lr = self.optimizer.param_groups[1]['lr'] 202 | self.log({"high_lr": high_lr, "low_lr": low_lr}) 203 | return loss 204 | 205 | def train(): 206 | parser = transformers.HfArgumentParser( 207 | (ModelArguments, DataArguments, TrainingArguments) 208 | ) 209 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 210 | model, processor, optimizer_grouped_parameters = load_model_processor(model_args) 211 | train_dataset, data_collator = 
load_dataset_collator(processor, data_args) 212 | 213 | num_training_steps = len(train_dataset) * training_args.num_train_epochs // (training_args.per_device_train_batch_size * 214 | training_args.gradient_accumulation_steps * data_args.gpu_nums) 215 | 216 | if data_args.build_data_from_web: 217 | trainer = CustomTrainer( 218 | model=model, 219 | args=training_args, 220 | train_dataset=train_dataset, 221 | eval_dataset=None, 222 | data_collator=data_collator, 223 | optimizer_grouped_parameters=optimizer_grouped_parameters, 224 | ) 225 | else: 226 | trainer = CustomTrainer( 227 | model=model, 228 | args=training_args, 229 | train_dataset=train_dataset, 230 | eval_dataset=None, 231 | data_collator=data_collator, 232 | optimizer_grouped_parameters=optimizer_grouped_parameters, 233 | ) 234 | 235 | trainer.train() 236 | trainer.save_state() 237 | if model_args.train_type == "use_lora": 238 | model.save_pretrained(training_args.output_dir) 239 | else: 240 | trainer.save_model(output_dir=training_args.output_dir) 241 | 242 | if __name__ == "__main__": 243 | logging.basicConfig( 244 | format="%(asctime)s %(levelname)s [%(name)s] %(message)s", 245 | level=logging.INFO, 246 | datefmt="%Y-%m-%d %H:%M:%S", 247 | ) 248 | train() 249 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## llava-Qwen2-7B-Instruct-Chinese-CLIP: enhanced Chinese text recognition and meme understanding, reaching the level of gpt4o and claude-3.5-sonnet! 2 | logo 3 | 4 | ### Changelog 5 | [24/07/22] Updated the "llava-Qwen2-7B-Instruct-Chinese-CLIP训练手册" (training manual), summarizing the different training combinations tried recently and their image-recognition results. Two combinations are highlighted: 6 | 7 | 1. Experiment 13: 8 | - Model: Qwen2-7B-Instruct 9 | - Vision encoder: OFA-Sys/chinese-clip-vit-huge-patch14 10 | - Dataset: improved dataset v1 (dataset v1 plus Claude 3.5 Sonnet descriptions of the meme and ChineseBQB images) 11 | - Result: best meme question-answering output, but only mediocre recognition of plain-text images 12 | 13 | 2. Experiment 12 (two-stage training):
14 | Stage 1: 15 | - Model: Qwen2-7B-Instruct 16 | - Vision encoder: OFA-Sys/chinese-clip-vit-large-patch14-336px 17 | - Dataset: dataset v2 18 | 19 | Stage 2: 20 | - Fine-tune the stage-1 model on the improved dataset v1 21 | - Result: good meme question-answering output and the best plain-text recognition. However, because most text in the training data is single-line, the model only returns the first line when reading multi-line text images. 22 | 23 | See "llava-Qwen2-7B-Instruct-Chinese-CLIP训练手册.xlsx" for details. 24 | 25 | [24/07/19] Uploaded model v2: https://huggingface.co/REILX/v1llava-Qwen2-7B-Instruct-Chinese-CLIP-v2. Compared with the previous generation, REILX/llava-Qwen2-7B-Instruct-Chinese-CLIP, the training data grows to four Chinese datasets, with 47x more images and 5x more text. 26 | 27 | [24/07/09] Uploaded model v1: https://huggingface.co/REILX/llava-Qwen2-7B-Instruct-Chinese-CLIP 28 | 29 | ### Model architecture
30 | llava-Qwen2-7B-Instruct-Chinese-CLIP = Qwen/Qwen2-7B-Instruct + multi_modal_projector + OFA-Sys/chinese-clip-vit-large-patch14-336px
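The upstream train_llava project (linked under "Project code" below) contains the full model-assembly script. As a rough, hedged sketch only, the composition above could be expressed with the transformers API along the following lines; the dtype, the output path and the exact assembly steps are assumptions, and a transformers version that can host a `chinese_clip_vision_model` as the LLaVA vision tower is assumed as well:

```python
# Hedged sketch: assemble the composite checkpoint described above.
# Model IDs come from this README; everything else here is an assumption, not the authors' script.
import torch
from transformers import (
    AutoModelForCausalLM,
    ChineseCLIPVisionModel,
    LlavaConfig,
    LlavaForConditionalGeneration,
)

llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct", torch_dtype=torch.bfloat16)
vision = ChineseCLIPVisionModel.from_pretrained(
    "OFA-Sys/chinese-clip-vit-large-patch14-336px", torch_dtype=torch.bfloat16
)

# Build an empty LLaVA-style wrapper whose sub-configs match the two donor models,
# then drop the pretrained modules in; only multi_modal_projector remains randomly initialized.
config = LlavaConfig(vision_config=vision.config, text_config=llm.config)
model = LlavaForConditionalGeneration(config)
model.vision_tower = vision
model.language_model = llm
model.save_pretrained("/path/to/llava-qwen2-7b-chinese-clip-base")  # placeholder output path

# The tokenizer side also needs an "<image>" token and a saved LlavaProcessor;
# see the upstream train_llava project for that part of the setup.
```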
31 | 32 | ### Fine-tuned modules 33 | - LoRA training on the q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj layers of the vision_tower and language_model
34 | - Full-parameter training of the mmp (multi_modal_projector) layer (a condensed sketch of this setup follows below)
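Condensed from run.py in this repo, the setup looks roughly like this: LoRA adapters on the projection layers listed above, the projector kept fully trainable via `modules_to_save`, and two learning-rate groups (a higher `high_lr` for the projector, a lower `low_lr` for the LoRA weights). The model path is a placeholder and the learning rates are the v1 values from the next section:

```python
# Sketch of the fine-tuning setup used by run.py (condensed; not a drop-in replacement for it).
import torch
from peft import LoraConfig, get_peft_model
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "/path/to/llava-qwen2-7b-chinese-clip-base", torch_dtype=torch.bfloat16  # placeholder path
)

lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["multi_modal_projector"],  # the projector is trained in full, not via LoRA
)
model = get_peft_model(model, lora_config)

high_lr, low_lr = 1e-3, 2e-5  # v1 values from the "Fine-tuning" section below
optimizer = torch.optim.AdamW(
    [
        {"params": [p for n, p in model.named_parameters()
                    if "multi_modal_projector" in n and p.requires_grad], "lr": high_lr},
        {"params": [p for n, p in model.named_parameters()
                    if "multi_modal_projector" not in n and p.requires_grad], "lr": low_lr},
    ]
)
```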
35 | 36 | ### Fine-tuning 37 | #### v1 38 | - lora_r=32, lora_alpha=64, num_train_epochs=5, per_device_train_batch_size=1, gradient_accumulation_steps=8, high_lr=1e-3, low_lr=2e-5, model_max_length=2048.
39 | - Hardware: 8*A800
40 | - Training time: 5 hours 12 minutes 41 | 42 | #### v2 43 | - lora_r=32, lora_alpha=64, num_train_epochs=3, per_device_train_batch_size=1, gradient_accumulation_steps=8, high_lr=5e-4, low_lr=1e-5, model_max_length=2048.
44 | - Hardware: 8*A800
45 | - Training time: 68 hours 06 minutes 46 |
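For reference, both configurations train with an effective global batch size of per_device_train_batch_size × gradient_accumulation_steps × number of GPUs = 1 × 8 × 8 = 64 on the 8*A800 machines listed above.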
47 | ### Datasets 48 | #### v1 49 | - The emo-visual-data and ChineseBQB datasets, captioned with the gemini-1.5-pro, gemini-1.5-flash, yi-vision, gpt4o and claude-3.5-sonnet models.
50 | The text descriptions can be downloaded from [text-description-of-the-meme](https://huggingface.co/datasets/REILX/text-description-of-the-meme).
51 | The images can be downloaded from [emo-visual-data](https://github.com/LLM-Red-Team/emo-visual-data) and [ChineseBQB](https://github.com/zhaoolee/ChineseBQB).
52 | About 1.8 GB of images in total (~10,835 Chinese meme images) and 42 MB of text (~24,332 image-text description pairs). 53 | 54 | #### v2 55 | - The emo-visual-data and ChineseBQB datasets, captioned with the gemini-1.5-pro, gemini-1.5-flash, yi-vision, gpt4o and claude-3.5-sonnet models.
56 | The text descriptions can be downloaded from [text-description-of-the-meme](https://huggingface.co/datasets/REILX/text-description-of-the-meme).
57 | The images can be downloaded from [emo-visual-data](https://github.com/LLM-Red-Team/emo-visual-data) and [ChineseBQB](https://github.com/zhaoolee/ChineseBQB).
58 | About 1.8 GB of images (~10,835 Chinese meme images) and 42 MB of text (~24,332 image-text description pairs). 59 | - [priyank-m/chinese_text_recognition](https://huggingface.co/datasets/priyank-m/chinese_text_recognition)
60 | About 2.0 GB of images (~500,000 images) and 207 MB of text (~500,000 image-text description pairs). 61 | - [SWHL/ChineseOCRBench](https://huggingface.co/datasets/SWHL/ChineseOCRBench)
62 | About 134 MB of images (~3,410 images) and 1.3 MB of text (~3,410 image-text description pairs). 63 | - [fly0331/ChineseTest](https://huggingface.co/datasets/fly0331/ChineseTest)
64 | About 530 MB of images (~6,247 images) and 5.4 MB of text (~6,247 image-text description pairs). 65 | 66 |
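All of these sources are converted into the LLaVA-style chat JSON that `train_llava/data.py` reads (the `--data_path`/`--json_name`/`--image_folder` arguments in run.sh). As a hedged illustration, a single training record has roughly the following shape; the filename and text below are made-up examples, not real dataset entries:

```python
# Illustrative record from the training JSON read by LlavaDataset (train_llava/data.py).
# The dataset file is a JSON array of such objects; "image" is a filename under --image_folder.
example_record = {
    "image": "0001.jpg",  # made-up filename
    "conversations": [
        # the question, which should carry the "<image>" placeholder the LLaVA processor expects
        {"from": "human", "value": "<image>\n请描述这张表情包表达的含义。"},
        # the reference answer the model is trained to produce
        {"from": "gpt", "value": "图中一只卡通猫瞪大眼睛,配文“震惊”,表达惊讶的情绪。"},
    ],
}
# LlavaDataset only consumes record["image"], conversations[0]["value"] and conversations[1]["value"];
# the "from" keys follow the common LLaVA convention and are not read by the loader.
```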
67 | ### Results 68 | The test results below show that the model can read the text in an image and correctly identify the meaning a meme is trying to convey. For comparison, [REILX/llava-1.5-7b-hf-meme-lora](https://huggingface.co/REILX/llava-1.5-7b-hf-meme-lora) also reports the output of the original llava-1.5-7b-hf model, which cannot correctly recognize the text in these images.
69 | **The following 9 images show the recognition results of the llava-Qwen2-7B-Instruct-Chinese-CLIP-v2 model**
70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | **The following 6 images show the recognition results of the llava-Qwen2-7B-Instruct-Chinese-CLIP model**
81 | 82 | 83 | 84 | 85 | 86 | 87 |
88 | 89 | **The following 6 images show the recognition results after training Qwen/Qwen2-7B-Instruct + multi_modal_projector + openai/clip-vit-large-patch14-336**
90 | 91 | 92 | 93 | 94 | 95 | 96 |
97 | 98 | **The following 3 images show gpt4o's recognition results**
99 | 100 | 101 | 102 | 103 | ## Project code
104 | The base-model assembly code and training code can be found in the upstream project: https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/train_llava
105 | This project builds on that code with the following main improvements: 106 | 1. **Added warmup and differentiated learning rates for different modules, which speeds up the convergence of multi_modal_projector.** This noticeably improves training efficiency. 107 | 2. **Parts of the code were reworked to fix errors that occurred during training, making the code more robust.** 108 | 3. **Training can be launched directly by just editing the parameters in run.sh.** This simplifies the workflow. 109 | 110 | Model-merging code. After merging, copy the added_tokens.json, merges.txt, preprocessor_config.json, special_tokens_map.json, tokenizer.json and vocab.json files into "/保存的完整模型路径" (the merged-model output directory). 111 | ```python 112 | import torch 113 | from peft import PeftModel, LoraConfig 114 | from transformers import LlavaForConditionalGeneration 115 | model_name = "/替换为你的基础模型路径" 116 | LORA_R = 32 117 | LORA_ALPHA = 64 118 | LORA_DROPOUT = 0.05 119 | TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] 120 | lora_config = LoraConfig( 121 | r=LORA_R, 122 | lora_alpha=LORA_ALPHA, 123 | target_modules=TARGET_MODULES, 124 | lora_dropout=LORA_DROPOUT, 125 | bias="none", 126 | task_type="CAUSAL_LM", 127 | modules_to_save=["multi_modal_projector"], 128 | ) 129 | model = LlavaForConditionalGeneration.from_pretrained(model_name) 130 | model = PeftModel.from_pretrained(model, "/替换为你的lora模型路径", config=lora_config, adapter_name='lora') 131 | 132 | model.cpu() 133 | model.eval() 134 | base_model = model.get_base_model() 135 | base_model.eval() 136 | model.merge_and_unload() 137 | base_model.save_pretrained("/保存的完整模型路径") 138 | ``` 139 | 140 | Inference code 141 | ```python 142 | from transformers import LlavaForConditionalGeneration, AutoProcessor 143 | import torch 144 | from PIL import Image 145 | 146 | raw_model_name_or_path = "/保存的完整模型路径" 147 | model = LlavaForConditionalGeneration.from_pretrained(raw_model_name_or_path, device_map="cuda:0", torch_dtype=torch.bfloat16) 148 | processor = AutoProcessor.from_pretrained(raw_model_name_or_path) 149 | model.eval() 150 | 151 | def build_model_input(model, processor): 152 | messages = [ 153 | {"role": "system", "content": "You are a helpful assistant."}, 154 | {"role": "user", "content": "<image>\n 使用中文描述图片中的信息"} 155 | ] 156 | prompt = processor.tokenizer.apply_chat_template( 157 | messages, tokenize=False, add_generation_prompt=True 158 | ) 159 | image = Image.open("01.PNG") 160 | inputs = processor(text=prompt, images=image, return_tensors="pt", return_token_type_ids=False) 161 | 162 | for tk in inputs.keys(): 163 | inputs[tk] = inputs[tk].to(model.device) 164 | generate_ids = model.generate(**inputs, max_new_tokens=200) 165 | 166 | generate_ids = [ 167 | oid[len(iids):] for oid, iids in zip(generate_ids, inputs.input_ids) 168 | ] 169 | gen_text = processor.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0] 170 | return gen_text 171 | build_model_input(model, processor) 172 | ``` 173 | 174 | ### TODO 175 | - [x] The LLaVA project trains in two stages: Pretraining first, then Fine-tuning. Pretraining trains multi_modal_projector; 176 | Fine-tuning trains language_model + multi_modal_projector. Optimize this project along the same lines and compare the two approaches. 177 | (Completed on 7/22; see the "llava-Qwen2-7B-Instruct-Chinese-CLIP训练手册" for details.) 178 | - [ ] Run a more comprehensive evaluation of the llava-Qwen2-7B-Instruct-Chinese-CLIP model. 179 | 180 | ### Acknowledgements 181 | 182 | This project benefits from [train_llava](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/train_llava), [ChineseBQB](https://github.com/zhaoolee/ChineseBQB) and [emo-visual-data](https://github.com/LLM-Red-Team/emo-visual-data); many thanks to their authors. 183 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS
FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------