├── .dockerignore ├── .editorconfig ├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── cog.yaml ├── docs ├── LLaVA-NeXT-Interleave.md ├── LLaVA-NeXT-Video.md ├── LLaVA-NeXT-Video_0716.md ├── LLaVA-NeXT.md ├── LLaVA_OneVision.md ├── LLaVA_OneVision_Chat.md ├── LLaVA_OneVision_Tutorials.ipynb ├── LLaVA_Video_1003.md ├── README.md ├── jobs.mp4 ├── onevision_trial.py └── ov_chat_images │ ├── chat_results.png │ ├── example1_tree.png │ └── example2_dog.jpg ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── evaluate_interleave.py │ └── model_vqa.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_gemma.py │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ ├── llava_mixtral.py │ │ ├── llava_mpt.py │ │ ├── llava_qwen.py │ │ ├── llava_qwen_moe.py │ │ └── modeling_llama.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── dev_eva_clip │ │ │ ├── eva_clip │ │ │ │ ├── __init__.py │ │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ │ ├── constants.py │ │ │ │ ├── eva_vit_model.py │ │ │ │ ├── factory.py │ │ │ │ ├── hf_configs.py │ │ │ │ ├── hf_model.py │ │ │ │ ├── loss.py │ │ │ │ ├── model.py │ │ │ │ ├── model_configs │ │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ │ │ ├── modified_resnet.py │ │ │ │ ├── openai.py │ │ │ │ ├── pretrained.py │ │ │ │ ├── rope.py │ │ │ │ ├── timm_model.py │ │ │ │ ├── tokenizer.py │ │ │ │ ├── transform.py │ │ │ │ ├── transformer.py │ │ │ │ └── utils.py │ │ │ └── eva_vit.py │ │ ├── eva_clip │ │ │ ├── eva_clip_encoder.py │ │ │ ├── eva_clip_processors.py │ │ │ ├── eva_vit.py │ │ │ ├── factory.py │ │ │ └── model_configs │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ ├── hf_vision.py │ │ ├── imagebind.py │ │ ├── mlcd │ │ │ └── vit_rope2d_hf.py │ │ ├── mlcd_encoder.py │ │ ├── open_clip_encoder.py │ │ └── siglip_encoder.py │ ├── multimodal_projector │ │ ├── builder.py │ │ └── pooler_projector.py │ ├── multimodal_resampler │ │ ├── builder.py │ │ ├── masked_drop.py │ │ ├── perceiver.py │ │ ├── qformer.py │ │ └── spatial_pool.py │ └── utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_multi_image.py │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── sglang_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── llava_trainer_eval.py │ ├── train.py │ ├── train_dpo.py │ └── train_mem.py └── utils.py ├── playground ├── 2d_hist.py ├── data_checker.py ├── demo │ ├── 
video_demo.py │ └── xU25MMA2N4aVtYay.mp4 ├── equal_splitter.py ├── remove_mid_ckpt.py ├── sgl_llava_inference_multinode.py └── upload_data.py ├── predict.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── archived │ ├── convert_gqa_for_eval.py │ ├── convert_mmvet_for_eval.py │ ├── convert_sqa_to_llava.py │ ├── convert_sqa_to_llava_base_prompt.py │ ├── convert_vizwiz_for_submission.py │ ├── convert_vqav2_for_submission.py │ ├── data_info.py │ ├── dpo_data_info.py │ ├── entry_cmd.sh │ ├── finetune.sh │ ├── finetune_1.5.sh │ ├── finetune_full_schedule.sh │ ├── finetune_lora.sh │ ├── finetune_mixtral.sh │ ├── finetune_mixtral_1.5.sh │ ├── finetune_mixtral_1.6_336px_anyres.sh │ ├── finetune_mixtral_1.6_336px_anyres_freeze_vision.sh │ ├── finetune_mixtral_1.6_336px_anyres_lmms_eval.sh │ ├── finetune_mixtral_copy.sh │ ├── finetune_qlora.sh │ ├── finetune_sqa.sh │ ├── merge_lora_weights.py │ ├── pretrain.sh │ ├── quick_check.py │ ├── sqa_eval_batch.sh │ └── sqa_eval_gather.sh ├── interleave │ ├── eval_all.sh │ ├── eval_interleave_3d.sh │ └── eval_multiprocess.sh ├── qwen.py ├── summarize_data.py ├── train │ ├── README.md │ ├── direct_finetune_clip.sh │ ├── direct_finetune_siglip_a4.sh │ ├── dpo.sh │ ├── dpo_ov7b.sh │ ├── finetune_ov.sh │ ├── finetune_si.sh │ ├── mid_stage.yaml │ ├── onevision.yaml │ ├── pretrain_clip.sh │ ├── pretrain_siglip.sh │ └── single_image.yaml ├── video │ ├── demo │ │ └── video_demo.sh │ ├── eval │ │ ├── activitynet_eval.sh │ │ ├── video_chatgpt_benchmark_eval_shard.sh │ │ ├── video_description_from_t2v.sh │ │ ├── video_detail_description_eval_only.sh │ │ └── video_detail_description_eval_shard.sh │ └── train │ │ ├── SO400M_Qwen2_72B_ov_to_video_am9.sh │ │ ├── SO400M_Qwen2_7B_ov_to_video_am9.sh │ │ └── exp.yaml ├── zero2.json ├── zero2_fused_adamw.json ├── zero2_offload.json ├── zero3.json ├── zero3_offload.json └── zero3pp.json └── trl ├── __init__.py ├── core.py ├── environment ├── __init__.py └── base_environment.py ├── extras ├── __init__.py ├── best_of_n_sampler.py └── dataset_formatting.py ├── import_utils.py ├── models ├── __init__.py ├── modeling_base.py ├── modeling_sd_base.py ├── modeling_value_head.py └── utils.py └── trainer ├── __init__.py ├── base.py ├── ddpo_config.py ├── ddpo_trainer.py ├── dpo_trainer.py ├── iterative_sft_trainer.py ├── model_config.py ├── ppo_config.py ├── ppo_trainer.py ├── reward_config.py ├── reward_trainer.py ├── sft_trainer.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 
2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | 19 | # Exclude some weights 20 | /openai 21 | /liuhaotian 22 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | # Unix-style newlines with a newline ending every file 4 | [*] 5 | end_of_line = lf 6 | insert_final_newline = true 7 | trim_trailing_whitespace = true 8 | charset = utf-8 9 | 10 | # 4 space indentation 11 | [*.{py,json}] 12 | indent_style = space 13 | indent_size = 4 14 | 15 | # 2 space indentation 16 | [*.{md,sh,yaml,yml}] 17 | indent_style = space 18 | indent_size = 2 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # https://git-scm.com/docs/gitattributes 2 | 3 | # Set the default behavior, in case people don't have core.autocrlf set. 4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion 5 | * text=auto 6 | 7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes 8 | # Source files 9 | # ============ 10 | *.pxd text diff=python 11 | *.py text diff=python 12 | *.py3 text diff=python 13 | *.pyw text diff=python 14 | *.pyx text diff=python 15 | *.pyz text diff=python 16 | *.pyi text diff=python 17 | 18 | # Binary files 19 | # ============ 20 | *.db binary 21 | *.p binary 22 | *.pkl binary 23 | *.pickle binary 24 | *.pyc binary export-ignore 25 | *.pyo binary export-ignore 26 | *.pyd binary 27 | 28 | # Jupyter notebook 29 | *.ipynb text eol=lf 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | # *.json 11 | # *.jsonl 12 | 13 | # Data 14 | !**/alpaca-data-conversation.json 15 | # Editor 16 | .idea 17 | *.swp 18 | .vscode 19 | 20 | # Other 21 | .DS_Store 22 | wandb 23 | output 24 | llavavid 25 | 26 | checkpoints 27 | project_checkpoints 28 | debug_checkpoints 29 | playground/data 30 | playground/cc3m_llava34b_cap 31 | ckpts* 32 | 33 | .ipynb_checkpoints 34 | chunyl_scripts 35 | *.ipynb 36 | 37 | # DevContainer 38 | !.devcontainer/* 39 | 40 | # Demo 41 | serve_images/ 42 | notebooks/ 43 | logs 44 | scripts/dist_* 45 | logs/ 46 | submissions/ 47 | cn_scripts/ 48 | internal_project_checkpoints/ 49 | work_dirs 50 | scripts/i18n/* 51 | playground/.nfs028b000000010add00000001 52 | HIP 53 | playground/.nfs028b0000017bff2c00000012 54 | scripts/qwen 55 | scripts/vicuna 56 | scripts/mistral 57 | scripts/baseline_rep 58 | scripts/cn_boli01_hl 59 | scripts/cn_boli01_lf 60 | scripts/cn_lf 61 | scripts/cn_lq 62 | scripts/cn_yg 63 | scripts/cn_yg_hao 64 | scripts/eva_encoder 65 | scripts/i18n 66 | scripts/i18n_higher_res 67 | scripts/multi-images 68 | scratchpad 69 | build/ 70 | playground/*.json 71 | mlx_configs/ 72 | data_processing/ 73 | # demo/ 74 | -------------------------------------------------------------------------------- /cog.yaml: 
-------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | 7 | python_version: "3.11" 8 | 9 | python_packages: 10 | - "torch==2.0.1" 11 | - "accelerate==0.21.0" 12 | - "bitsandbytes==0.41.0" 13 | - "deepspeed==0.9.5" 14 | - "einops-exts==0.0.4" 15 | - "einops==0.6.1" 16 | - "gradio==3.35.2" 17 | - "gradio_client==0.2.9" 18 | - "httpx==0.24.0" 19 | - "markdown2==2.4.10" 20 | - "numpy==1.26.0" 21 | - "peft==0.4.0" 22 | - "scikit-learn==1.2.2" 23 | - "sentencepiece==0.1.99" 24 | - "shortuuid==1.0.11" 25 | - "timm==0.6.13" 26 | - "tokenizers==0.13.3" 27 | - "torch==2.0.1" 28 | - "torchvision==0.15.2" 29 | - "transformers==4.31.0" 30 | - "wandb==0.15.12" 31 | - "wavedrom==2.0.3.post3" 32 | - "Pygments==2.16.1" 33 | run: 34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget 35 | 36 | # predict.py defines how predictions are run on your model 37 | predict: "predict.py:Predictor" 38 | -------------------------------------------------------------------------------- /docs/LLaVA-NeXT-Interleave.md: -------------------------------------------------------------------------------- 1 | 2 | # LLaVA-NeXT: Tackling Multi-image, Video, and 3D in Large Multimodal Models 3 | 4 | ## Contents 5 | - [Demo](#demo) 6 | - [Evaluation](#evaluation) 7 | 8 | ## Demo 9 | 10 | > Make sure you have installed the LLaVA-NeXT model files by following the instructions in the top-level README.md. 11 | 12 | 1. **Example model:** `lmms-lab/llava-next-interleave-7b` 13 | 14 | 15 | To run a demo, execute: 16 | ```bash 17 | # If you hit errors when running the demo, make sure the checkpoint path contains 'qwen'. 18 | # For example, rename it with 'mv llava-next-interleave-7b llava-next-interleave-qwen-7b'. 19 | python playground/demo/interleave_demo.py --model_path path/to/ckpt 20 | ``` 21 | 22 | ## Evaluation 23 | 24 | ### Preparation 25 | 26 | Please download the evaluation data and its metadata from the following links: 27 | 28 | 1. **llava-interleave-bench:** [here](https://huggingface.co/datasets/lmms-lab/llava-interleave-bench). 29 | 30 | Unzip eval_images.zip; it contains Split1 and Split2. 31 | Organize the downloaded data into the following structure: 32 | ``` 33 | 34 | interleave_data 35 | ├── Split1 36 | │ ├── ... 37 | │ └── ... 38 | | 39 | ├── Split2 40 | | ├── ... 41 | │ └── ... 42 | ├── multi_image_in_domain.json 43 | ├── multi_image_out_domain.json 44 | └── multi_view_in_domain.json 45 | ``` 46 | 47 | ### Inference and Evaluation 48 | Example: 49 | First edit /path/to/ckpt to your checkpoint path and /path/to/images to the path of "interleave_data" in scripts/interleave/eval_all.sh, then run 50 | ```bash 51 | bash scripts/interleave/eval_all.sh 52 | ``` 53 | 54 | -------------------------------------------------------------------------------- /docs/LLaVA-NeXT-Video.md: -------------------------------------------------------------------------------- 1 | 2 | # LLaVA-NeXT: A Strong Zero-shot Video Understanding Model 3 | 4 | ## Contents 5 | - [Demo](#demo) 6 | - [Evaluation](#evaluation) 7 | 8 | ## Demo 9 | 10 | > Make sure you have installed the LLaVA-NeXT model files by following the instructions in the top-level README.md. 11 | 12 | 1. **Example model:** `lmms-lab/LLaVA-NeXT-Video-7B-DPO` 13 | 14 | 2. **Prompt mode:** `vicuna_v1` (use `mistral_direct` for `lmms-lab/LLaVA-NeXT-Video-34B-DPO`) 15 | 16 | 3. 
**Sampled frames:** `32` (Defines how many frames to sample from the video.) 17 | 18 | 4. **Spatial pooling stride:** `2` (With original tokens for one frame at 24x24, if stride=2, then the tokens for one frame are 12x12.) 19 | 20 | 5. **Spatial pooling mode:** `average` (Options: `average`, `max`.) 21 | 22 | 6. **Local video path:** `./data/llava_video/video-chatgpt/evaluation/Test_Videos/v_Lf_7RurLgp0.mp4` 23 | 24 | To run a demo, execute: 25 | ```bash 26 | bash scripts/video/demo/video_demo.sh ${Example model} ${Prompt mode} ${Sampled frames} ${Spatial pooling stride} ${Spatial pooling mode} grid True ${Video path at local} 27 | ``` 28 | Example: 29 | ```bash 30 | bash scripts/video/demo/video_demo.sh lmms-lab/LLaVA-NeXT-Video-7B-DPO vicuna_v1 32 2 average no_token True playground/demo/xU25MMA2N4aVtYay.mp4 31 | ``` 32 | 33 | **IMPORTANT** Please refer to [Latest video model](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/inference/docs/LLaVA-NeXT-Video_0716.md) for the runnning of the latest model. 34 | 35 | ## Evaluation 36 | 37 | ### Preparation 38 | 39 | Please download the evaluation data and its metadata from the following links: 40 | 41 | 1. **video-chatgpt:** [here](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/quantitative_evaluation/README.md#video-based-generative-performance-benchmarking). 42 | 2. **video_detail_description:** [here](https://mbzuaiac-my.sharepoint.com/personal/hanoona_bangalath_mbzuai_ac_ae/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fhanoona%5Fbangalath%5Fmbzuai%5Fac%5Fae%2FDocuments%2FVideo%2DChatGPT%2FData%5FCode%5FModel%5FRelease%2FQuantitative%5FEvaluation%2Fbenchamarking%2FTest%5FHuman%5FAnnotated%5FCaptions%2Ezip&parent=%2Fpersonal%2Fhanoona%5Fbangalath%5Fmbzuai%5Fac%5Fae%2FDocuments%2FVideo%2DChatGPT%2FData%5FCode%5FModel%5FRelease%2FQuantitative%5FEvaluation%2Fbenchamarking&ga=1). 43 | 3. **activity_qa:** [here](https://mbzuaiac-my.sharepoint.com/personal/hanoona_bangalath_mbzuai_ac_ae/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fhanoona%5Fbangalath%5Fmbzuai%5Fac%5Fae%2FDocuments%2FVideo%2DChatGPT%2FData%5FCode%5FModel%5FRelease%2FData%2FActivityNet%5FTest%2D1%2D3%5Fvideos%2Ezip&parent=%2Fpersonal%2Fhanoona%5Fbangalath%5Fmbzuai%5Fac%5Fae%2FDocuments%2FVideo%2DChatGPT%2FData%5FCode%5FModel%5FRelease%2FData&ga=1) and [here](https://github.com/MILVLG/activitynet-qa/tree/master/dataset). 
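Before running the evaluation scripts, it can help to sanity-check that the downloads landed where the scripts expect them; the expected layout is shown just below. The following optional Python sketch only uses paths taken from that structure and assumes it is run from the LLaVA-NeXT repository root.

```python
# Optional sanity check for the evaluation data layout (paths mirror the
# structure shown below; adjust `root` if your data lives elsewhere).
from pathlib import Path

root = Path("data/llava_video")
expected = [
    root / "video-chatgpt" / "Test_Videos",
    root / "video-chatgpt" / "consistency_qa.json",
    root / "video_detail_description" / "Test_Human_Annotated_Captions",
    root / "ActivityNet-QA" / "all_test",
    root / "ActivityNet-QA" / "test_a.json",
    root / "ActivityNet-QA" / "test_b.json",
]
missing = [str(p) for p in expected if not p.exists()]
if missing:
    print("Missing evaluation files/directories:", *missing, sep="\n  ")
else:
    print("All expected evaluation files were found.")
```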
44 | 45 | Organize the downloaded data into the following structure: 46 | ``` 47 | LLaVA-NeXT 48 | ├── llava 49 | ├── scripts 50 | └── data 51 | └── llava_video 52 | ├── video-chatgpt 53 | │ ├── Test_Videos 54 | │ ├── consistency_qa.json 55 | │ ├── consistency_qa_test.json 56 | │ ├── consistency_qa_train.json 57 | ├── video_detail_description 58 | │ └── Test_Human_Annotated_Captions 59 | └── ActivityNet-QA 60 | ├── all_test 61 | ├── test_a.json 62 | └── test_b.json 63 | ``` 64 | 65 | ### Inference and Evaluation 66 | 67 | Example for video detail description evaluation (additional scripts are available in `scripts/video/eval`): 68 | ```bash 69 | bash scripts/video/eval/video_detail_description_eval_shard.sh ${Example model} ${Prompt mode} ${Sampled frames} ${Spatial pooling stride} True 8 70 | ``` 71 | Example: 72 | ```bash 73 | bash scripts/video/eval/video_detail_description_eval_shard.sh liuhaotian/llava-v1.6-vicuna-7b vicuna_v1 32 2 True 8 74 | ``` 75 | 76 | ### GPT Evaluation Example (Optional if the above step has been completed) 77 | 78 | Assuming you have `pred.json` (model-generated predictions) for model `llava-v1.6-vicuna-7b` at `./work_dirs/eval_video_detail_description/llava-v1.6-vicuna-7b_vicuna_v1_frames_32_stride_2`: 79 | ```bash 80 | bash scripts/video/eval/video_detail_description_eval_only.sh llava-v1.6-vicuna-7b_vicuna_v1_frames_32_stride_2 81 | ``` 82 | -------------------------------------------------------------------------------- /docs/LLaVA-NeXT-Video_0716.md: -------------------------------------------------------------------------------- 1 | ## LLaVA-NeXT-Video is upgraded 🚀 2 | 3 | In our [LLaVA-Video blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/) released this April, we shared two key observations: 4 | - 🎬 AnyRes provides a shared and flexible representation between images and videos, and thus accommodates capability transfer between the two most common vision signals. Therefore, stronger image LMMs can naturally lead to stronger zero-shot video LMMs. 5 | - 🗂️ There is a lack of high-quality language-video data, including video instruction-following data, so naive tuning on the public data available at that time resulted in performance degradation. Therefore, there is an urgent need to build high-quality video caption and QA datasets to train LMMs for improved video performance. 6 | 7 | Based on these insights, the new LLaVA-NeXT-Video in this release improves in two aspects: 8 | 9 | - 🎬 A stronger image LMM ([LLaVA-NeXT-32B-Qwen](https://huggingface.co/lmms-lab/llava-next-qwen-32b)), built by initializing from the Qwen-1.5 32B LLM. We further initialize our video training from this image checkpoint. 10 | - 🗂️ A new high-quality video dataset with 830k samples. It is combined with the LLaVA-1.6 image training data, and applying the same image-video mixed training procedure leads to the new video model. 11 | The new model achieves the best open-source performance on several video benchmarks, including [Video-MME](https://video-mme.github.io/home_page.html#leaderboard). 
12 | 13 | ### Resources 14 | - **Model Card**: [LLaVA-NeXT-Video-32B-Qwen on Hugging Face](https://huggingface.co/lmms-lab/LLaVA-NeXT-Video-32B-Qwen) 15 | - **Inference Script**: 16 | ```bash 17 | bash scripts/video/demo/video_demo.sh lmms-lab/LLaVA-NeXT-Video-32B-Qwen qwen_1_5 32 2 average grid True playground/demo/xU25MMA2N4aVtYay.mp4 18 | ``` 19 | 20 | ### Evaluation Results 21 | | Model | NextQA-MC | Video-MME (overall) | | EgoSchema | Perception Test (val) | 22 | |-----------------------------|-----------|--------------------|--------|----------|------------------------| 23 | | | | w/o subs | w/ subs | | | 24 | | **Proprietary** | | | | | | 25 | | GPT-4o | - | 71.9 | 77.2 | 72.2 | - | 26 | | Gemini 1.5 Pro | - | 75.0 | 81.3 | 72.2 | - | 27 | | **Open-Source** | | | | | | 28 | | VideoLLaMA 2 (8x7B) | 76.3* | 47.9 | 50.3 | 53.3 | 51.2* | 29 | | VILA-1.5-34B | 67.89* | 60.1 | 61.1 | 58.04* | 54 | 30 | | LLaVA-NeXT-Video (Qwen-32B) | 77.31 | 60.2 | 63.0 | 60.85 | 59.38 | 31 | 32 | _*Results reproduced with [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval); please refer to lmms-eval to reproduce them._ 33 | 34 | ### Citations 35 | ```bibtex 36 | @misc{zhang2024llavanextvideo, 37 | title={LLaVA-NeXT: A Strong Zero-shot Video Understanding Model}, 38 | url={https://llava-vl.github.io/blog/2024-04-30-llava-next-video/}, 39 | author={Zhang, Yuanhan and Li, Bo and Liu, Haotian and Lee, Yong Jae and Gui, Liangke and Fu, Di and Feng, Jiashi and Liu, Ziwei and Li, Chunyuan}, 40 | month={April}, 41 | year={2024} 42 | } 43 | ``` -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # LLaVA-NeXT Documentation 2 | 3 | Welcome to the LLaVA-NeXT documentation. This guide provides an overview of the different components and features of LLaVA-NeXT. Please refer to the following documents for detailed information on specific topics: 4 | 5 | 1. [LLaVA OneVision](LLaVA_OneVision.md): Learn about the most advanced and unified version: LLaVA OneVision. 6 | - [LLaVA OneVision: Inference Tutorials](LLaVA_OneVision_Tutorials.ipynb): Learn how to use LLaVA OneVision for inference. 7 | - [LLaVA OneVision Chat](LLaVA_OneVision_Chat.md): Improving Chat with Preference Learning. 8 | 9 | 2. [LLaVA-NeXT Interleave](LLaVA-NeXT-Interleave.md): Explore the interleaved training approach used in LLaVA-NeXT. 10 | 11 | 3. [LLaVA-NeXT Video (0716)](LLaVA-NeXT-Video_0716.md): Discover the video processing capabilities of LLaVA-NeXT (version 0716). 12 | 13 | 4. [LLaVA-NeXT Video](LLaVA-NeXT-Video.md): Get information about the latest video processing features in LLaVA-NeXT. 14 | 15 | 5. [LLaVA-NeXT Overview](LLaVA-NeXT.md): Read a comprehensive overview of the LLaVA-NeXT project, including its architecture, features, and capabilities. 16 | 17 | These documents provide in-depth information on various aspects of LLaVA-NeXT. Please refer to them for detailed explanations, implementation details, and usage instructions.
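As a quick-start companion to the OneVision inference tutorial linked above, here is a minimal single-image inference sketch. It follows the general pattern of the repository's loading and preprocessing utilities (`load_pretrained_model`, `process_images`, `tokenizer_image_token`); the checkpoint name and conversation template are assumptions, and exact arguments may differ from the tutorial notebook, so treat this as a sketch rather than a substitute for the notebook.

```python
# Minimal LLaVA-OneVision single-image inference sketch.
# Assumptions: the `lmms-lab/llava-onevision-qwen2-7b-ov` checkpoint and the
# "qwen_1_5" conversation template; see LLaVA_OneVision_Tutorials.ipynb for
# the authoritative walkthrough.
import copy

import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

pretrained = "lmms-lab/llava-onevision-qwen2-7b-ov"  # assumed checkpoint name
tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "llava_qwen", device_map="auto")
model.eval()

# Preprocess the image with the model's own image processor (AnyRes-aware).
image = Image.open("docs/ov_chat_images/example2_dog.jpg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = [t.to(dtype=torch.float16, device=model.device) for t in image_tensor]

# Build a prompt containing the image placeholder, then tokenize it so the
# placeholder becomes IMAGE_TOKEN_INDEX.
conv = copy.deepcopy(conv_templates["qwen_1_5"])
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?")
conv.append_message(conv.roles[1], None)
input_ids = tokenizer_image_token(conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image.size],
        do_sample=False,
        max_new_tokens=256,
    )
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```

The per-topic documents above describe the multi-image and video variants of this flow.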
-------------------------------------------------------------------------------- /docs/jobs.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLaVA-VL/LLaVA-NeXT/b42941ceba259d5df18f8df8193a3897296a0449/docs/jobs.mp4 -------------------------------------------------------------------------------- /docs/ov_chat_images/chat_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLaVA-VL/LLaVA-NeXT/b42941ceba259d5df18f8df8193a3897296a0449/docs/ov_chat_images/chat_results.png -------------------------------------------------------------------------------- /docs/ov_chat_images/example1_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLaVA-VL/LLaVA-NeXT/b42941ceba259d5df18f8df8193a3897296a0449/docs/ov_chat_images/example1_tree.png -------------------------------------------------------------------------------- /docs/ov_chat_images/example2_dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLaVA-VL/LLaVA-NeXT/b42941ceba259d5df18f8df8193a3897296a0449/docs/ov_chat_images/example2_dog.jpg -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | AVAILABLE_MODELS = { 4 | "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig", 5 | "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig", 6 | "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig", 7 | "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig", 8 | # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig", 9 | # Add other models as needed 10 | } 11 | 12 | for model_name, model_classes in AVAILABLE_MODELS.items(): 13 | try: 14 | exec(f"from .language_model.{model_name} import {model_classes}") 15 | except Exception as e: 16 | print(f"Failed to import {model_name} from llava.language_model.{model_name}. 
Error: {e}") 17 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava import LlavaLlamaForCausalLM 12 | 13 | 14 | def apply_delta(base_model_path, target_model_path, delta_path): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 33 | 34 | print("Saving target model") 35 | delta.save_pretrained(target_model_path) 36 | delta_tokenizer.save_pretrained(target_model_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--base-model-path", type=str, required=True) 42 | parser.add_argument("--target-model-path", type=str, required=True) 43 | parser.add_argument("--delta-path", type=str, required=True) 44 | 45 | args = parser.parse_args() 46 | 47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 48 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model import * 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def consolidate_ckpt(src_path, dst_path): 15 | print("Loading model") 16 | auto_upgrade(src_path) 17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 19 | src_model.save_pretrained(dst_path) 20 | src_tokenizer.save_pretrained(dst_path) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--src", type=str, required=True) 26 | parser.add_argument("--dst", type=str, required=True) 27 | 28 | args = parser.parse_args() 29 | 30 | consolidate_ckpt(args.src, args.dst) 31 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .imagebind import ImageBindWrapper 4 | from .open_clip_encoder import OpenCLIPVisionTower 5 | from .hf_vision import HFVisionTower 6 | from .siglip_encoder import SigLipVisionTower 7 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 8 | from .mlcd_encoder import MLCDVisionTower, MLCDVisionTowerS2 9 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower 10 | # from .dev_eva_clip.eva_vit import EvaViTWrapper 11 | 12 | 13 | def build_vision_tower(vision_tower_cfg, **kwargs): 14 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) 15 | is_absolute_path_exists = os.path.exists(vision_tower) 16 | use_s2 = getattr(vision_tower_cfg, "s2", False) 17 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 18 | if use_s2: 19 | return 
CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 20 | else: 21 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 22 | elif "siglip" in vision_tower: 23 | return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) 24 | elif vision_tower.startswith("hf:"): 25 | return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 26 | elif vision_tower in ["imagebind_huge"]: 27 | return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 28 | elif vision_tower.startswith("open_clip_hub"): 29 | return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 30 | elif "mlcd-vit-bigG-patch14" in vision_tower: 31 | if use_s2: 32 | return MLCDVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 33 | else: 34 | return MLCDVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 35 | 36 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): 37 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 38 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: 39 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 40 | 41 | raise ValueError(f"Unknown vision tower: {vision_tower}") 42 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 6 | from .openai import load_openai_model, list_openai_models 7 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 8 | from .tokenizer import SimpleTokenizer, tokenize 9 | from .transform import image_transform 10 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLaVA-VL/LLaVA-NeXT/b42941ceba259d5df18f8df8193a3897296a0449/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": 
"vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings", 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings", 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens", 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings", 54 | }, 55 | "pooler": "mean_pooler", 56 | }, 57 | } 58 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 
5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- 
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | 
"head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop 8 | 9 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 10 | 11 | 12 | class ResizeMaxSize(nn.Module): 13 | 14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): 15 | super().__init__() 16 | if not isinstance(max_size, int): 17 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 18 | self.max_size = max_size 19 | self.interpolation = interpolation 20 | self.fn = min if fn == "min" else min 21 | self.fill = fill 22 | 23 | def forward(self, img): 24 | if isinstance(img, torch.Tensor): 25 | height, width = img.shape[:2] 26 | else: 27 | width, height = img.size 28 | scale = self.max_size / float(max(height, width)) 29 | if scale != 1.0: 30 | new_size = tuple(round(dim * scale) for dim in (height, width)) 31 | img = F.resize(img, new_size, self.interpolation) 32 | pad_h = self.max_size - new_size[0] 33 | pad_w = self.max_size - new_size[1] 34 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) 35 | return img 36 | 37 | 38 | def _convert_to_rgb(image): 39 | return image.convert("RGB") 40 | 41 | 42 | # class CatGen(nn.Module): 43 | # def __init__(self, num=4): 44 | # self.num = num 45 | # def mixgen_batch(image, text): 46 | # batch_size = image.shape[0] 47 | # index = np.random.permutation(batch_size) 48 | 49 | # cat_images = [] 50 | # for i in range(batch_size): 51 | # # image mixup 52 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 53 | # # text concat 54 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 55 | # text = torch.stack(text) 56 | # return image, text 57 | 58 | 59 | def image_transform( 60 | image_size: int, 61 | is_train: bool, 62 | mean: Optional[Tuple[float, ...]] = None, 63 | std: Optional[Tuple[float, ...]] = None, 64 | resize_longest_max: bool = False, 65 | fill_color: int = 0, 66 | ): 67 | mean = mean or OPENAI_DATASET_MEAN 68 | if not isinstance(mean, (list, tuple)): 69 | mean = (mean,) * 3 70 | 71 | std = std or OPENAI_DATASET_STD 72 | if not isinstance(std, (list, tuple)): 73 | std = (std,) * 3 74 | 75 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 76 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 77 | image_size = image_size[0] 78 | 79 | normalize = Normalize(mean=mean, std=std) 80 | if is_train: 81 | return Compose( 82 | [ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ] 88 | ) 89 | else: 90 | if resize_longest_max: 91 | transforms = [ResizeMaxSize(image_size, fill=fill_color)] 92 | else: 93 | transforms = [ 94 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 95 | CenterCrop(image_size), 96 | ] 97 | transforms.extend( 98 | [ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ] 103 | ) 104 | return Compose(transforms) 105 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .eva_clip_processors import EvaClipImageTrainProcessor 5 | from .eva_vit import EVAEncoderWrapper 6 | from .factory import list_models, add_model_config, get_model_config 7 | 8 | from llava.utils import rank0_print 9 | 10 | 11 | class EvaClipVisionTower(nn.Module): 12 | def __init__(self, vision_tower, args, delay_load=False): 13 | super().__init__() 14 | 15 | self.is_loaded = False 16 | self.vision_tower_name = vision_tower 17 | self.vision_tower_pretrained = args.vision_tower_pretrained 18 | self.config = get_model_config(vision_tower) 19 | 20 | if not delay_load: 21 | rank0_print(f"Loading EVA ViT: {self.vision_tower_name}") 22 
| self.load_model() 23 | elif getattr(args, "unfreeze_mm_vision_tower", False): 24 | # TODO: better detector is needed. 25 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 26 | self.load_model() 27 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 28 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 29 | self.load_model() 30 | else: 31 | self.cfg_only = self.config 32 | 33 | def load_model(self, device_map=None): 34 | rank0_print(f"Pretrained: {self.vision_tower_pretrained}") 35 | self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"]) 36 | self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config) 37 | rank0_print(f"Loaded image processor: {self.image_processor}") 38 | self.vision_tower.requires_grad_(False) 39 | self.is_loaded = True 40 | 41 | def forward(self, images): 42 | if type(images) is list: 43 | image_features = [] 44 | for image in images: 45 | image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype) 49 | 50 | return image_features 51 | 52 | @property 53 | def dtype(self): 54 | return self.vision_tower.dtype 55 | 56 | @property 57 | def device(self): 58 | return self.vision_tower.device 59 | 60 | @property 61 | def hidden_size(self): 62 | return self.config["vision_cfg"]["width"] 63 | 64 | @property 65 | def num_patches(self): 66 | return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2 67 | 68 | @property 69 | def num_patches_per_side(self): 70 | return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"] 71 | 72 | @property 73 | def image_size(self): 74 | return self.config["vision_cfg"]["image_size"] 75 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP 3 | """ 4 | 5 | from torchvision import transforms 6 | from torchvision.transforms.functional import InterpolationMode 7 | from transformers.image_processing_utils import BatchFeature 8 | from PIL import Image 9 | from transformers.image_transforms import convert_to_rgb 10 | 11 | 12 | class BaseProcessor: 13 | def __init__(self): 14 | self.transform = lambda x: x 15 | return 16 | 17 | def __call__(self, item): 18 | return self.transform(item) 19 | 20 | 21 | class EvaClipImageBaseProcessor(BaseProcessor): 22 | def __init__(self, mean=None, std=None): 23 | self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean 24 | self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std 25 | 26 | self.normalize = transforms.Normalize(self.mean, self.std) 27 | 28 | @property 29 | def image_mean(self): 30 | return self.mean 31 | 32 | 33 | class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor): 34 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 35 | super().__init__(mean=mean, std=std) 36 | 37 | self.transform = transforms.Compose( 38 | [ 39 | convert_to_rgb, 40 | transforms.Resize( 41 | image_size, 42 | 
interpolation=InterpolationMode.BICUBIC, 43 | ), 44 | transforms.CenterCrop(image_size), 45 | transforms.ToTensor(), 46 | self.normalize, 47 | ] 48 | ) 49 | 50 | self.image_size = image_size 51 | 52 | def preprocess(self, images, return_tensors): 53 | if isinstance(images, Image.Image): 54 | images = [images] 55 | else: 56 | assert isinstance(images, list) 57 | 58 | transformed_images = [self.transform(image).numpy() for image in images] 59 | data = {"pixel_values": transformed_images} 60 | 61 | return BatchFeature(data=data, tensor_type=return_tensors) 62 | 63 | def __call__(self, item): 64 | return self.transform(item) 65 | 66 | @property 67 | def crop_size(self): 68 | return {"height": self.image_size, "width": self.image_size} 69 | 70 | @property 71 | def size(self): 72 | return {"shortest_edge": self.image_size} 73 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Optional, Tuple, Union, Dict, Any 9 | import torch 10 | 11 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 12 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 13 | 14 | 15 | def _natural_key(string_): 16 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 17 | 18 | 19 | def _rescan_model_configs(): 20 | global _MODEL_CONFIGS 21 | 22 | config_ext = (".json",) 23 | config_files = [] 24 | for config_path in _MODEL_CONFIG_PATHS: 25 | if config_path.is_file() and config_path.suffix in config_ext: 26 | config_files.append(config_path) 27 | elif config_path.is_dir(): 28 | for ext in config_ext: 29 | config_files.extend(config_path.glob(f"*{ext}")) 30 | 31 | for cf in config_files: 32 | with open(cf, "r", encoding="utf8") as f: 33 | model_cfg = json.load(f) 34 | if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): 35 | _MODEL_CONFIGS[cf.stem] = model_cfg 36 | 37 | _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) 38 | 39 | 40 | _rescan_model_configs() # initial populate of model config registry 41 | 42 | 43 | def list_models(): 44 | """enumerate available model architectures based on config files""" 45 | return list(_MODEL_CONFIGS.keys()) 46 | 47 | 48 | def add_model_config(path): 49 | """add model config path or file and update registry""" 50 | if not isinstance(path, Path): 51 | path = Path(path) 52 | _MODEL_CONFIG_PATHS.append(path) 53 | _rescan_model_configs() 54 | 55 | 56 | def get_model_config(model_name): 57 | if model_name in _MODEL_CONFIGS: 58 | return deepcopy(_MODEL_CONFIGS[model_name]) 59 | else: 60 | return None 61 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 
77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 
| "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | 
-------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/imagebind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPImageProcessor 5 | 6 | try: 7 | from imagebind.models import imagebind_model 8 | from imagebind.models.imagebind_model import ModalityType 9 | from imagebind.data import load_and_transform_audio_data 10 | except ImportError: 11 | pass 12 | 13 | 14 | class ImageBindWrapper(nn.Module): 15 | def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): 16 | super().__init__() 17 | 18 | self.is_loaded = False 19 | 20 | self.vision_tower_name = vision_tower 21 | self.select_layer = select_layer 22 | self.select_feature = select_feature 23 | 24 | if not delay_load: 25 | self.load_model() 26 | 27 | def load_model(self): 28 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 29 | self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) 30 | for p in self.vision_tower.parameters(): 31 | p.requires_grad = False 32 | self.vision_tower.eval() 33 | self.is_loaded = True 34 | 35 | def train(self, 
mode=True): 36 | self.training = mode 37 | 38 | if self.is_loaded: 39 | self.vision_tower.eval() 40 | 41 | @torch.no_grad() 42 | def forward(self, x): 43 | if type(x) == dict: 44 | if x["audios"] is not None: 45 | inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} 46 | embeddings = self.vision_tower(inputs) 47 | audio_embedding = embeddings[ModalityType.AUDIO] 48 | return audio_embedding.unsqueeze(1) 49 | else: 50 | inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} 51 | embeddings = self.vision_tower(inputs) 52 | vision_embedding = embeddings[ModalityType.VISION] 53 | if vision_embedding.ndim == 2: 54 | return vision_embedding.unsqueeze(1) 55 | if vision_embedding.shape[1] == 257: 56 | return vision_embedding[:, 1:] 57 | raise ValueError(f"Unexpected shape: {vision_embedding.shape}") 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.modality_preprocessors.vision.cls_token.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.modality_preprocessors.vision.cls_token.device 70 | 71 | @property 72 | def hidden_size(self): 73 | return 1024 74 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | from .pooler_projector import PoolerProjector 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": "identity"} 18 | 19 | 20 | class SimpleResBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.pre_norm = nn.LayerNorm(channels) 24 | 25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) 26 | 27 | def forward(self, x): 28 | x = self.pre_norm(x) 29 | return x + self.proj(x) 30 | 31 | 32 | def build_vision_projector(config, delay_load=False, **kwargs): 33 | projector_type = getattr(config, "mm_projector_type", "linear") 34 | 35 | if projector_type == "linear": 36 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | 38 | if projector_type == "pooler": 39 | return PoolerProjector(config, kwargs["vision_cfg"]) 40 | 41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) 42 | if mlp_gelu_match: 43 | mlp_depth = int(mlp_gelu_match.group(1)) 44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 45 | for _ in range(1, mlp_depth): 46 | modules.append(nn.GELU()) 47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 48 | return nn.Sequential(*modules) 49 | 50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) 51 | if mlp_gelu_resnet_match: 52 | mlp_depth = int(mlp_gelu_resnet_match.group(1)) 53 | res_depth = int(mlp_gelu_resnet_match.group(2)) 54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 55 | for _ in range(1, mlp_depth): 56 | modules.append(nn.GELU()) 57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 58 | for _ in range(res_depth): 59 | modules.append(SimpleResBlock(config.hidden_size)) 60 | return nn.Sequential(*modules) 61 | 62 | if projector_type == "identity": 63 | return IdentityMap() 
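# Added annotation (comment only, not in the original builder.py): the projector
# types recognized above are "linear", "pooler", "mlp{N}x_gelu",
# "mlp{N}x_res{M}x_gelu" and "identity"; any other mm_projector_type value falls
# through to the ValueError below. For example, "mlp2x_gelu" expands to
#   nn.Sequential(nn.Linear(config.mm_hidden_size, config.hidden_size), nn.GELU(),
#                 nn.Linear(config.hidden_size, config.hidden_size))
# i.e. a two-layer GELU MLP projecting vision features into the LM hidden size.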
64 | 65 | raise ValueError(f"Unknown projector type: {projector_type}") 66 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/pooler_projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | from transformers.models.clip.modeling_clip import CLIPVisionModel 7 | 8 | 9 | class PoolerProjector(nn.Module): 10 | def __init__(self, config, vision_cfg): 11 | super().__init__() 12 | self._config = config 13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size 14 | 15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) 16 | 17 | self.proj = nn.Sequential( 18 | nn.GELU(), 19 | nn.Linear(config.hidden_size, config.hidden_size), 20 | ) 21 | 22 | def forward(self, x, *args, **kwargs): 23 | height = width = self.hw 24 | assert height * width == x.shape[1] 25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) 26 | x = self.conv_pool(x) 27 | x = x.flatten(2).transpose(1, 2) 28 | x = self.proj(x) 29 | return x 30 | 31 | @property 32 | def config(self): 33 | return {"mm_projector_type": "pooler"} 34 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .masked_drop import MaskedDrop 4 | from .spatial_pool import SpatialPool 5 | from .perceiver import PerceiverResampler 6 | from .qformer import Qformer 7 | 8 | 9 | class IdentityMap(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | @property 17 | def config(self): 18 | return {"mm_resampler_type": None} 19 | 20 | 21 | def build_vision_resampler(model_args, delay_load=False, **kwargs): 22 | resampler_type = getattr(model_args, "mm_resampler_type", None) 23 | if resampler_type == "masked_drop": 24 | return MaskedDrop(model_args) 25 | elif resampler_type == "spatial_pool": 26 | return SpatialPool(model_args, **kwargs) 27 | elif resampler_type == "perceiver": 28 | return PerceiverResampler(model_args, **kwargs) 29 | elif resampler_type == "qformer": 30 | return Qformer(model_args, **kwargs) 31 | elif resampler_type is None: 32 | return IdentityMap() 33 | 34 | raise ValueError(f"Unknown resampler type: {resampler_type}") 35 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/masked_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import random 5 | 6 | 7 | class MaskedDrop(nn.Module): 8 | def __init__(self, model_args): 9 | super().__init__() 10 | 11 | self.mode = model_args.mm_mask_drop_mode 12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage 13 | self.ratio = model_args.mm_mask_drop_ratio 14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper 15 | self.ratio_lower = model_args.mm_mask_drop_ratio_lower 16 | 17 | def forward(self, image_features, *args, **kwargs): 18 | 19 | if not self.training: 20 | return image_features 21 | 22 | if self.skip_percentage > random.random(): 23 | return image_features 24 | 25 | masked_features = [] 26 | 27 | for image_feature in image_features: 28 | num_tokens = image_feature.shape[0] 29 | if self.mode == "fixed": 30 | num_keep = int(num_tokens * 
self.ratio) 31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) 32 | elif self.mode == "range": 33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) 34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) 35 | elif self.mode == "cls_only": 36 | masked_features.append(image_feature[0:1]) 37 | else: 38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}") 39 | 40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): 41 | masked_features = torch.stack(masked_features, dim=0) 42 | 43 | return masked_features 44 | 45 | @property 46 | def config(self): 47 | return { 48 | "mm_resampler_type": "masked_drop", 49 | "mm_mask_drop_mode": self.mode, 50 | "mm_mask_drop_skip_percentage": self.skip_percentage, 51 | "mm_mask_drop_ratio": self.ratio, 52 | "mm_mask_drop_ratio_upper": self.ratio_upper, 53 | "mm_mask_drop_ratio_lower": self.ratio_lower, 54 | } 55 | 56 | def random_masking(self, x, len_keep): 57 | """ 58 | Perform per-sample random masking by per-sample shuffling. 59 | Per-sample shuffling is done by argsort random noise. 60 | x: [N, L, D], sequence 61 | """ 62 | N, L, D = x.shape # batch, length, dim 63 | 64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 65 | 66 | # sort noise for each sample 67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 68 | ids_restore = torch.argsort(ids_shuffle, dim=1) 69 | 70 | # keep the first subset 71 | ids_keep = ids_shuffle[:, :len_keep] 72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 73 | 74 | # generate the binary mask: 0 is keep, 1 is remove 75 | mask = torch.ones([N, L], device=x.device) 76 | mask[:, :len_keep] = 0 77 | # unshuffle to get the binary mask 78 | mask = torch.gather(mask, dim=1, index=ids_restore) 79 | 80 | return x_masked, mask, ids_restore 81 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/spatial_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class SpatialPool(nn.Module): 7 | def __init__(self, model_args, vision_tower): 8 | super().__init__() 9 | 10 | self.mode = model_args.mm_spatial_pool_mode 11 | self.stride = model_args.mm_spatial_pool_stride 12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) 13 | 14 | if self.mode == "average": 15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) 16 | elif self.mode == "max": 17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) 18 | elif self.mode == "conv": 19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) 20 | else: 21 | raise ValueError(f"Unknown pooling mode: {self.pool}.") 22 | 23 | def forward(self, image_features, images, *args, **kwargs): 24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) 25 | ori_H = int(ori_W * images.shape[2] // images.shape[3]) 26 | 27 | B, _, F = image_features.shape 28 | 29 | image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2) 30 | image_features_spatial_pool = self.pool(image_features_spatial) 31 | 32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() 33 | 34 | 
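# Added shape note (comment only, not in the original spatial_pool.py), assuming a
# square 24x24 token grid (576 tokens) and a stride of 2: image_features of shape
# [B, 576, C] is viewed as [B, 24, 24, C], permuted to [B, C, 24, 24], pooled to
# [B, C_out, 12, 12], and returned flattened as [B, 144, C_out], i.e. a 4x
# reduction in visual tokens (C_out equals C for average/max pooling, or
# mm_spatial_pool_out_channels for the conv mode).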
@property 35 | def config(self): 36 | return { 37 | "mm_resampler_type": "spatial_pool", 38 | "mm_spatial_pool_stride": self.stride, 39 | "mm_spatial_pool_mode": self.mode, 40 | "mm_spatial_pool_out_channels": self.out_channels, 41 | } 42 | 43 | @property 44 | def hidden_size(self): 45 | return self.out_channels 46 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLaVA-VL/LLaVA-NeXT/b42941ceba259d5df18f8df8193a3897296a0449/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLaVA-VL/LLaVA-NeXT/b42941ceba259d5df18f8df8193a3897296a0449/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLaVA-VL/LLaVA-NeXT/b42941ceba259d5df18f8df8193a3897296a0449/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name}) 21 | worker_addr = ret.json()["address"] 22 | print(f"worker_addr: {worker_addr}") 23 | 24 | if worker_addr == "": 25 | return 26 | 27 | conv = default_conversation.copy() 28 | conv.append_message(conv.roles[0], args.message) 29 | prompt = conv.get_prompt() 30 | 31 | headers = {"User-Agent": "LLaVA Client"} 32 | pload = { 33 | "model": args.model_name, 34 | "prompt": prompt, 35 | "max_new_tokens": args.max_new_tokens, 36 | "temperature": 0.7, 37 | "stop": conv.sep, 38 | } 39 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True) 40 | 41 | print(prompt.replace(conv.sep, "\n"), end="") 42 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 43 | if chunk: 44 | data = json.loads(chunk.decode("utf-8")) 45 | output = data["text"].split(conv.sep)[-1] 46 | print(output, end="\r") 47 | print("") 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 53 | parser.add_argument("--worker-address", type=str) 54 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 55 | parser.add_argument("--max-new-tokens", type=int, default=32) 56 | parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.") 57 | args = parser.parse_args() 58 | 59 | main() 60 | -------------------------------------------------------------------------------- /llava/train/llava_trainer_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | from llava.train.llava_trainer import LLaVATrainer 5 | 6 | 7 | class LLaVAEvalTrainer(LLaVATrainer): 8 | def evaluate(self, evaluate_args): 9 | cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \ 10 | --model {evaluate_args.model} \ 11 | --model_args {evaluate_args.model_args} \ 12 | --tasks 
{evaluate_args.task_names} \ 13 | --batch_size {evaluate_args.batch_size} \ 14 | --log_samples_suffix {evaluate_args.log_samples_suffix} \ 15 | --output_path {evaluate_args.output_path}" 16 | if evaluate_args.limit: 17 | cmd += f" --limit {evaluate_args.limit}" 18 | if evaluate_args.num_fewshot: 19 | cmd += f" --num_fewshot {evaluate_args.num_fewshot}" 20 | if evaluate_args.gen_kwargs != "": 21 | cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}" 22 | if evaluate_args.log_samples: 23 | cmd += f" --log_samples" 24 | else: 25 | assert False, "Please log samples so that the result can be parsed" 26 | results = subprocess.run([cmd], shell=True, capture_output=True, text=True) 27 | try: 28 | result_file_index_start = results.stdout.index("Saved samples to ") 29 | result_file_index_end = results.stdout.index(f".json") 30 | result_file_index_start += len("Saved samples to ") 31 | file = results.stdout[result_file_index_start:result_file_index_end] 32 | except: 33 | result_file_index_start = results.stderr.index("Saved samples to ") 34 | result_file_index_end = results.stderr.index(f".json") 35 | result_file_index_start += len("Saved samples to ") 36 | file = results.stderr[result_file_index_start:result_file_index_end] 37 | file = file.split("/")[:-1] 38 | file = "/".join(file) + "/results.json" 39 | with open(file, "r") as f: 40 | lmms_eval_results = json.load(f) 41 | result_dict = {} 42 | tasks_list = evaluate_args.task_names.split(",") 43 | for task in tasks_list: 44 | task_results = lmms_eval_results["results"][task] 45 | for k, v in task_results.items(): 46 | if k != "alias" and "stderr" not in k: 47 | metric = k.split(",")[0] 48 | result_dict[f"{task}_{metric}"] = v 49 | return result_dict 50 | 51 | """def evaluate(self, evaluate_args): 52 | initialize_tasks() 53 | tasks_list = evaluate_args.task_names.split(",") 54 | result_dict = {} 55 | results = evaluator.simple_evaluate( 56 | model=evaluate_args.model, 57 | model_args=evaluate_args.model_args, 58 | tasks=tasks_list, 59 | num_fewshot=evaluate_args.num_fewshot, 60 | batch_size=evaluate_args.batch_size, 61 | device=evaluate_args.device, 62 | limit=evaluate_args.limit, 63 | check_integrity=evaluate_args.check_integrity, 64 | show_task_to_terminal=evaluate_args.show_task_to_terminal, 65 | log_samples=evaluate_args.log_samples, 66 | gen_kwargs=evaluate_args.gen_kwargs, 67 | cli_args=evaluate_args, 68 | ) 69 | for task in tasks_list: 70 | task_results = results["results"][task] 71 | for k,v in task_results.items(): 72 | if k != "alias" and "stderr" not in k: 73 | metric = k.split(",")[0] 74 | result_dict[f"{task}_{metric}"] = v 75 | 76 | return result_dict""" 77 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train import train 2 | 3 | if __name__ == "__main__": 4 | train() 5 | -------------------------------------------------------------------------------- /playground/demo/xU25MMA2N4aVtYay.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LLaVA-VL/LLaVA-NeXT/b42941ceba259d5df18f8df8193a3897296a0449/playground/demo/xU25MMA2N4aVtYay.mp4 -------------------------------------------------------------------------------- /playground/equal_splitter.py: -------------------------------------------------------------------------------- 1 | import json 2 | from math import ceil 3 | 4 | 5 | def split_json_file(input_file, 
n_splits): 6 | # Read the JSON file 7 | with open(input_file, "r") as file: 8 | data = json.load(file) 9 | 10 | # Calculate the size of each split 11 | total_items = len(data) 12 | items_per_split = ceil(total_items / n_splits) 13 | 14 | # Split the data and save into separate files 15 | for i in range(n_splits): 16 | start_index = i * items_per_split 17 | end_index = min((i + 1) * items_per_split, total_items) 18 | split_data = data[start_index:end_index] 19 | 20 | # Write the split data to a new JSON file 21 | with open(f"{input_file.split('.')[0]}_split_{i}.json", "w") as split_file: 22 | json.dump(split_data, split_file, indent=4) 23 | 24 | 25 | def main(): 26 | import argparse 27 | 28 | parser = argparse.ArgumentParser(description="Split a JSON file into multiple parts.") 29 | parser.add_argument("--input_file", type=str, help="The JSON file to split") 30 | parser.add_argument("--n_splits", type=int, help="The number of splits") 31 | 32 | args = parser.parse_args() 33 | 34 | split_json_file(args.input_file, args.n_splits) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /playground/remove_mid_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import glob 4 | 5 | 6 | def remove_checkpoints(directory, pattern): 7 | # Walk through the directory 8 | for root, dirs, files in os.walk(directory): 9 | # Use glob to find paths matching the pattern 10 | for file_path in glob.glob(os.path.join(root, pattern)): 11 | # Check if it is a directory 12 | if "llava-1.6-mistral-7b" in file_path: 13 | continue 14 | if os.path.isdir(file_path): 15 | # Remove the directory 16 | print(f"Removing {file_path}") 17 | input("Press Enter to continue...") 18 | shutil.rmtree(file_path) 19 | print(f"Removed directory: {file_path}") 20 | else: 21 | print(f"Removing {file_path}") 22 | input("Press Enter to continue...") 23 | # Remove the file 24 | os.remove(file_path) 25 | print(f"Removed file: {file_path}") 26 | 27 | 28 | # Directory containing the checkpoints 29 | directory = "/mnt/bn/vl-research/checkpoints/feng/" 30 | 31 | # Pattern to match in the file names 32 | pattern = "global_step*" 33 | 34 | # Call the function 35 | remove_checkpoints(directory, pattern) 36 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 240 3 | 4 | [build-system] 5 | requires = ["setuptools>=61.0"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "llava" 10 | version = "1.7.0.dev0" 11 | description = "LLaVA OneVision: The Next Generation of LLaVA with Better Image and Video Understanding Capabilities" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | ] 18 | 19 | [project.optional-dependencies] 20 | standalone = [ 21 | "shortuuid", 22 | "httpx==0.24.0", 23 | "einops", 24 | "ftfy", 25 | ] 26 | 27 | 28 | train = [ 29 | "llava[standalone]", 30 | "numpy==1.26.1", 31 | "open_clip_torch", 32 | "fastapi", 33 | "markdown2[all]", 34 | "numpy", 35 | "requests", 36 | "sentencepiece", 37 | "torch==2.1.2", 38 | "torchvision==0.16.2", 39 | "uvicorn", 40 | "wandb", 41 | "deepspeed==0.14.4", 42 | "peft==0.4.0", 43 | "accelerate>=0.29.1", 44 | "tokenizers~=0.15.2", 45 | 
"transformers@git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4", 46 | "bitsandbytes==0.41.0", 47 | "scikit-learn==1.2.2", 48 | "sentencepiece~=0.1.99", 49 | "einops==0.6.1", 50 | "einops-exts==0.0.4", 51 | "gradio_client==0.2.9", 52 | "urllib3<=2.0.0", 53 | "datasets==2.16.1", 54 | "pydantic==1.10.8", 55 | "timm", 56 | "hf_transfer", 57 | "opencv-python", 58 | "av", 59 | "decord", 60 | "tyro", 61 | "scipy", 62 | ] 63 | 64 | [project.urls] 65 | "Homepage" = "https://llava-vl.github.io" 66 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues" 67 | 68 | [tool.setuptools.packages.find] 69 | include = ["llava*", "trl*"] 70 | exclude = [ 71 | "assets*", 72 | "benchmark*", 73 | "docs", 74 | "dist*", 75 | "playground*", 76 | "scripts*", 77 | "tests*", 78 | "checkpoints*", 79 | "project_checkpoints*", 80 | "debug_checkpoints*", 81 | "mlx_configs*", 82 | "wandb*", 83 | "notebooks*", 84 | ] 85 | 86 | [tool.wheel] 87 | exclude = [ 88 | "assets*", 89 | "benchmark*", 90 | "docs", 91 | "dist*", 92 | "playground*", 93 | "scripts*", 94 | "tests*", 95 | "checkpoints*", 96 | "project_checkpoints*", 97 | "debug_checkpoints*", 98 | "mlx_configs*", 99 | "wandb*", 100 | "notebooks*", 101 | ] 102 | -------------------------------------------------------------------------------- /scripts/archived/convert_gqa_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | all_answers = [] 11 | for line_idx, line in enumerate(open(args.src)): 12 | res = json.loads(line) 13 | question_id = res["question_id"] 14 | text = res["text"].rstrip(".").lower() 15 | all_answers.append({"questionId": question_id, "prediction": text}) 16 | 17 | with open(args.dst, "w") as f: 18 | json.dump(all_answers, f) 19 | -------------------------------------------------------------------------------- /scripts/archived/convert_mmvet_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | cur_result = {} 11 | 12 | for line in open(args.src): 13 | data = json.loads(line) 14 | qid = data["question_id"] 15 | cur_result[f"v1_{qid}"] = data["text"] 16 | 17 | with open(args.dst, "w") as f: 18 | json.dump(cur_result, f, indent=2) 19 | -------------------------------------------------------------------------------- /scripts/archived/convert_sqa_to_llava.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import fire 4 | import re 5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot 6 | 7 | 8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"): 9 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 10 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 11 | 12 | split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False) 13 | 14 | target_format = [] 15 | for prob_id, (input, output) in split_problems.items(): 16 | if input.startswith("Question: "): 17 | input = input.replace("Question: ", "") 18 | if 
output.startswith("Answer: "): 19 | output = output.replace("Answer: ", "") 20 | 21 | raw_prob_data = problems[prob_id] 22 | if raw_prob_data["image"] is None: 23 | target_format.append( 24 | { 25 | "id": prob_id, 26 | "conversations": [ 27 | {"from": "human", "value": f"{input}"}, 28 | {"from": "gpt", "value": f"{output}"}, 29 | ], 30 | } 31 | ) 32 | 33 | else: 34 | target_format.append( 35 | { 36 | "id": prob_id, 37 | "image": os.path.join(prob_id, raw_prob_data["image"]), 38 | "conversations": [ 39 | {"from": "human", "value": f"{input}\n"}, 40 | {"from": "gpt", "value": f"{output}"}, 41 | ], 42 | } 43 | ) 44 | 45 | print(f"Number of samples: {len(target_format)}") 46 | 47 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f: 48 | json.dump(target_format, f, indent=2) 49 | 50 | 51 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"): 52 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 53 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 54 | 55 | split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False) 56 | 57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w") 58 | for prob_id, (input, output) in split_problems.items(): 59 | if input.startswith("Question: "): 60 | input = input.replace("Question: ", "") 61 | if output.startswith("Answer: "): 62 | output = output.replace("Answer: ", "") 63 | 64 | raw_prob_data = problems[prob_id] 65 | if raw_prob_data["image"] is None: 66 | data = { 67 | "id": prob_id, 68 | "instruction": f"{input}", 69 | "output": f"{output}", 70 | } 71 | 72 | else: 73 | data = { 74 | "id": prob_id, 75 | "image": os.path.join(prob_id, raw_prob_data["image"]), 76 | "instruction": f"{input}\n", 77 | "output": f"{output}", 78 | } 79 | writer.write(json.dumps(data) + "\n") 80 | writer.close() 81 | 82 | 83 | def main(task, **kwargs): 84 | globals()[task](**kwargs) 85 | 86 | 87 | if __name__ == "__main__": 88 | fire.Fire(main) 89 | -------------------------------------------------------------------------------- /scripts/archived/convert_vizwiz_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | from llava.eval.m4c_evaluator import EvalAIAnswerProcessor 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--annotation-file", type=str, required=True) 11 | parser.add_argument("--result-file", type=str, required=True) 12 | parser.add_argument("--result-upload-file", type=str, required=True) 13 | return parser.parse_args() 14 | 15 | 16 | if __name__ == "__main__": 17 | 18 | args = parse_args() 19 | 20 | os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True) 21 | 22 | results = [] 23 | error_line = 0 24 | for line_idx, line in enumerate(open(args.result_file)): 25 | try: 26 | results.append(json.loads(line)) 27 | except: 28 | error_line += 1 29 | results = {x["question_id"]: x["text"] for x in results} 30 | test_split = [json.loads(line) for line in open(args.annotation_file)] 31 | split_ids = set([x["question_id"] for x in test_split]) 32 | 33 | print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}") 34 | 35 | all_answers = [] 36 | 37 | answer_processor = EvalAIAnswerProcessor() 38 | 39 | for x in test_split: 40 | # import pdb; pdb.set_trace() 41 | assert x["question_id"] in results, 
print(x) 42 | all_answers.append({"image": x["image"], "answer": answer_processor(results[x["question_id"]])}) 43 | 44 | with open(args.result_upload_file, "w") as f: 45 | json.dump(all_answers, f) 46 | -------------------------------------------------------------------------------- /scripts/archived/convert_vqav2_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | from llava.eval.m4c_evaluator import EvalAIAnswerProcessor 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--dir", type=str, default="./playground/data/eval/vqav2") 11 | parser.add_argument("--ckpt", type=str, required=True) 12 | parser.add_argument("--split", type=str, required=True) 13 | return parser.parse_args() 14 | 15 | 16 | if __name__ == "__main__": 17 | 18 | args = parse_args() 19 | 20 | src = os.path.join(args.dir, "answers", args.split, args.ckpt, "merge.jsonl") 21 | test_split = os.path.join(args.dir, "llava_vqav2_mscoco_test2015.jsonl") 22 | dst = os.path.join(args.dir, "answers_upload", args.split, f"{args.ckpt}.json") 23 | os.makedirs(os.path.dirname(dst), exist_ok=True) 24 | 25 | results = [] 26 | error_line = 0 27 | for line_idx, line in enumerate(open(src)): 28 | try: 29 | results.append(json.loads(line)) 30 | except: 31 | error_line += 1 32 | 33 | results = {x["question_id"]: x["text"] for x in results} 34 | test_split = [json.loads(line) for line in open(test_split)] 35 | split_ids = set([x["question_id"] for x in test_split]) 36 | 37 | print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}") 38 | 39 | all_answers = [] 40 | 41 | answer_processor = EvalAIAnswerProcessor() 42 | 43 | for x in test_split: 44 | if x["question_id"] not in results: 45 | all_answers.append({"question_id": x["question_id"], "answer": ""}) 46 | else: 47 | all_answers.append({"question_id": x["question_id"], "answer": answer_processor(results[x["question_id"]])}) 48 | 49 | with open(dst, "w") as f: 50 | json.dump(all_answers, open(dst, "w")) 51 | -------------------------------------------------------------------------------- /scripts/archived/dpo_data_info.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | json_path = "/mnt/bn/vl-research/workspace/boli01/projects/sft_data_workspace/vlfeedback_80k.jsonl" 5 | 6 | with open(json_path, "r") as f: 7 | data = f.readlines() 8 | 9 | data = [json.loads(d) for d in data] 10 | 11 | 12 | def convert_format(original_data, dimension="Visual Faithfulness"): 13 | converted_data = [] 14 | for item in original_data: 15 | # Assuming the best response is the one with the highest helpfulness rating 16 | best_completion = max(item["completions"], key=lambda x: int(x["annotations"]["Helpfulness"]["Rating"])) 17 | best_response = best_completion["response"] 18 | best_model = best_completion["model"] 19 | 20 | if "†source" in best_response: 21 | print(best_response) 22 | # Regex pattern to match the pattern 【digit†source】 23 | pattern = r"【\d+†source】" 24 | # Replace the matched patterns with an empty string 25 | cleaned_text = re.sub(pattern, "", best_response) 26 | best_response = cleaned_text 27 | print(f"*****************************************") 28 | print(best_response) 29 | 30 | # Assuming the worst response is the one with the lowest helpfulness rating 31 | worst_completion = min(item["completions"], key=lambda x: 
int(x["annotations"]["Helpfulness"]["Rating"])) 32 | worst_response = worst_completion["response"] 33 | 34 | if "†source" in worst_response: 35 | print(worst_response) 36 | # Regex pattern to match the pattern ��digit†source】 37 | pattern = r"【\d+†source】" 38 | # Replace the matched patterns with an empty string 39 | cleaned_text = re.sub(pattern, "", worst_response) 40 | worst_response = cleaned_text 41 | print(f"*****************************************") 42 | print(worst_response) 43 | 44 | # Extract scores 45 | best_score = int(best_completion["annotations"][dimension]["Rating"]) 46 | worst_score = int(worst_completion["annotations"][dimension]["Rating"]) 47 | 48 | # Construct the new format 49 | new_item = { 50 | "id": item["id"], 51 | "prompt": item["prompt"], 52 | "answer": "", 53 | "image": f"silkie_dpo/{item['id']}.jpg", # Assuming the video ID is the last part of the original ID 54 | "chosen": best_response, 55 | "rejected": worst_response, 56 | "chosen_score": best_score, 57 | "rejected_score": worst_score, 58 | } 59 | converted_data.append(new_item) 60 | 61 | return converted_data 62 | 63 | 64 | for dimension in ["Visual Faithfulness", "Helpfulness", "Ethical Considerations"]: 65 | converted_data = convert_format(data, dimension=dimension) 66 | with open(f"/mnt/bn/vl-research/data/llava_instruct/dpo_data/silkie_dpo_data_{dimension.replace(' ', '_').lower()}_{len(converted_data)}.json", "w") as f: 67 | json.dump(converted_data, f, indent=4) 68 | -------------------------------------------------------------------------------- /scripts/archived/entry_cmd.sh: -------------------------------------------------------------------------------- 1 | python3 -m pip install --upgrade pip; 2 | 3 | export http_proxy=http://sys-proxy-rd-relay.byted.org:8118; 4 | export https_proxy=http://sys-proxy-rd-relay.byted.org:8118; 5 | 6 | export HF_HOME=/mnt/bn/vl-research-boli01-cn/.cache/huggingface; 7 | export HF_TOKEN="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"; 8 | export HF_HUB_ENABLE_HF_TRANSFER="1"; 9 | 10 | cd /mnt/bn/vl-research-boli01-cn/projects/zzz/lmms-eval; 11 | pip install -e .; 12 | 13 | cd /mnt/bn/vl-research-boli01-cn/projects/zzz/LLaVA_Next; 14 | pip install -e .; 15 | 16 | python3 -m pip install ninja; 17 | python3 -m pip install flash-attn --no-build-isolation; 18 | 19 | bash /mnt/bn/vl-research-boli01-cn/projects/zzz/LLaVA_Next/cn_scripts/vicuna/internal0.6m_finetune_llava1.6mix_7b_v0.2_unfreeze.sh 20 | 21 | 22 | accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \ 23 | --model llava \ 24 | --model_args pretrained="/mnt/bn/vl-research-boli01-cn/projects/zzz/LLaVA_Next/internal_project_checkpoints/llavanext-lmsys_vicuna-7b-v1.5-clip-vit-large-patch14-336-mlp2x_gelu-pretrain_internal0.6m_vicuna_v1_finetune_llava1.6_datamix_unfreezeVIS_1e" \ 25 | --tasks ok_vqa,textcaps_val,mme_test,mmmu,cmmmu,coco2017_cap_val,vizwiz_vqa_val,ai2d,chartqa,pope \ 26 | --batch_size 1 \ 27 | --log_samples \ 28 | --log_samples_suffix debug \ 29 | --output_path ./logs/ \ 30 | --wandb_args 'project=llava-next-lmms-eval,job_type=eval'; -------------------------------------------------------------------------------- /scripts/archived/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA 4 | 5 | # Install yolk3k if not installed 6 | if ! 
pip show yolk3k > /dev/null 2>&1; then 7 | pip install yolk3k 8 | fi 9 | 10 | # Get the installed version of transformers 11 | installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) 12 | 13 | # Get the latest version of transformers from PyPI 14 | latest_version=$(yolk -V transformers | cut -d ' ' -f 2) 15 | 16 | # Check if the installed version is not the latest 17 | if [ "$installed_version" != "$latest_version" ]; then 18 | pip install -U transformers 19 | fi 20 | 21 | # Get the installed version of deepspeed 22 | installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) 23 | 24 | # Get the latest version of deepspeed from PyPI 25 | latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) 26 | 27 | # Check if the installed version is not the latest 28 | # pip install deepspeed==0.12.2 29 | if [ "$installed_version" != "$latest_version" ]; then 30 | pip install deepspeed==0.12.2 31 | fi 32 | 33 | # Install flash-attn if not installed 34 | if ! pip show flash-attn > /dev/null 2>&1; then 35 | pip install flash-attn --no-build-isolation 36 | fi 37 | 38 | ################## VICUNA ################## 39 | PROMPT_VERSION=v1 40 | MODEL_VERSION="vicuna-7b-v1-5" 41 | ################## VICUNA ################## 42 | 43 | 44 | ################## project ################## 45 | PROJECT_NAME="ds_llava-vicuna-7b-v1-5-mlp2x_gelu-pretrain_blip558k_plain" 46 | 47 | ################## data ################## 48 | DATA_NAME="mixtral_instruct_158K_V1" 49 | 50 | # wandb configure 51 | export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" 52 | wandb login $WANDB_API_KEY 53 | 54 | export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME 55 | 56 | export WANDB_PROJECT=LLaVA_Mixtral 57 | 58 | export WANDB_MODE=online 59 | 60 | # wandb online 61 | 62 | deepspeed --master_port 26000 \ 63 | llava/train/train_mem.py \ 64 | --deepspeed ./scripts/zero2.json \ 65 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 66 | --version $PROMPT_VERSION \ 67 | --data_path ./playground/data/$DATA_NAME.json \ 68 | --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data/coco/train2017 \ 69 | --vision_tower openai/clip-vit-large-patch14 \ 70 | --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ 71 | --mm_vision_select_layer -2 \ 72 | --mm_projector_type mlp2x_gelu \ 73 | --mm_use_im_start_end False \ 74 | --mm_use_im_patch_token False \ 75 | --bf16 True \ 76 | --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ 77 | --num_train_epochs 1 \ 78 | --per_device_train_batch_size 16 \ 79 | --per_device_eval_batch_size 4 \ 80 | --gradient_accumulation_steps 1 \ 81 | --evaluation_strategy "no" \ 82 | --save_strategy "steps" \ 83 | --save_steps 50000 \ 84 | --save_total_limit 1 \ 85 | --learning_rate 2e-5 \ 86 | --weight_decay 0. 
\ 87 | --warmup_ratio 0.03 \ 88 | --lr_scheduler_type "cosine" \ 89 | --logging_steps 1 \ 90 | --tf32 True \ 91 | --model_max_length 2048 \ 92 | --gradient_checkpointing True \ 93 | --dataloader_num_workers 16 \ 94 | --lazy_preprocess True \ 95 | --report_to wandb 96 | -------------------------------------------------------------------------------- /scripts/archived/finetune_1.5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dataset_name=$1 3 | 4 | # Uncomment and set the following variables correspondingly to run this script: 5 | 6 | cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA 7 | 8 | # Install yolk3k if not installed 9 | if ! pip show yolk3k > /dev/null 2>&1; then 10 | pip install yolk3k 11 | fi 12 | 13 | # Get the installed version of transformers 14 | installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) 15 | 16 | # Get the latest version of transformers from PyPI 17 | latest_version=$(yolk -V transformers | cut -d ' ' -f 2) 18 | 19 | # Check if the installed version is not the latest 20 | if [ "$installed_version" != "$latest_version" ]; then 21 | pip install -U transformers 22 | fi 23 | 24 | # Get the installed version of deepspeed 25 | installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) 26 | 27 | # Get the latest version of deepspeed from PyPI 28 | latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) 29 | 30 | # Check if the installed version is not the latest 31 | if [ "$installed_version" != "$latest_version" ]; then 32 | pip install deepspeed==0.12.2 33 | fi 34 | 35 | # Install yolk3k if not installed 36 | if ! pip show flash-attn > /dev/null 2>&1; then 37 | pip install flash-attn --no-build-isolation 38 | fi 39 | 40 | 41 | ################## VICUNA ################## 42 | PROMPT_VERSION=v1 43 | MODEL_VERSION="vicuna-7b-v1-5" 44 | ################## VICUNA ################## 45 | 46 | ################## project ################## 47 | PROJECT_NAME="ds_llava-vicuna-7b-v1-5-mlp2x_gelu-pretrain_blip558k_plain" 48 | 49 | ################## data ################## 50 | DATA_NAME=$dataset_name 51 | 52 | 53 | # wandb configure 54 | export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" 55 | wandb login $WANDB_API_KEY 56 | 57 | export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME 58 | 59 | export WANDB_PROJECT=LLaVA_Mixtral 60 | 61 | export WANDB_MODE=online 62 | 63 | wandb online 64 | 65 | 66 | deepspeed --master_port 26000 \ 67 | llava/train/train_mem.py \ 68 | --deepspeed ./scripts/zero2.json \ 69 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 70 | --version $PROMPT_VERSION \ 71 | --data_path ./playground/data/$DATA_NAME.json \ 72 | --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data \ 73 | --vision_tower openai/clip-vit-large-patch14 \ 74 | --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ 75 | --mm_vision_select_layer -2 \ 76 | --mm_projector_type mlp2x_gelu \ 77 | --mm_use_im_start_end False \ 78 | --mm_use_im_patch_token False \ 79 | --bf16 True \ 80 | --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ 81 | --num_train_epochs 1 \ 82 | --per_device_train_batch_size 16 \ 83 | --per_device_eval_batch_size 4 \ 84 | --gradient_accumulation_steps 1 \ 85 | --evaluation_strategy "no" \ 86 | --save_strategy "steps" \ 87 | --save_steps 50000 \ 88 | --save_total_limit 1 \ 89 | --learning_rate 2e-5 \ 90 | --weight_decay 0. 
\ 91 | --warmup_ratio 0.03 \ 92 | --lr_scheduler_type "cosine" \ 93 | --logging_steps 1 \ 94 | --tf32 True \ 95 | --model_max_length 2048 \ 96 | --gradient_checkpointing True \ 97 | --dataloader_num_workers 16 \ 98 | --lazy_preprocess True \ 99 | --report_to wandb 100 | -------------------------------------------------------------------------------- /scripts/archived/finetune_full_schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | # PROMPT_VERSION=v1 7 | # MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | deepspeed llava/train/train_mem.py \ 16 | --deepspeed ./scripts/zero2.json \ 17 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 18 | --version $PROMPT_VERSION \ 19 | --data_path ./playground/data/llava_instruct_158k.json \ 20 | --image_folder /path/to/coco/train2017 \ 21 | --vision_tower openai/clip-vit-large-patch14 \ 22 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 23 | --mm_vision_select_layer -2 \ 24 | --mm_use_im_start_end False \ 25 | --mm_use_im_patch_token False \ 26 | --bf16 True \ 27 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \ 28 | --num_train_epochs 3 \ 29 | --per_device_train_batch_size 16 \ 30 | --per_device_eval_batch_size 4 \ 31 | --gradient_accumulation_steps 1 \ 32 | --evaluation_strategy "no" \ 33 | --save_strategy "steps" \ 34 | --save_steps 50000 \ 35 | --save_total_limit 1 \ 36 | --learning_rate 2e-5 \ 37 | --weight_decay 0. 
\ 38 | --warmup_ratio 0.03 \ 39 | --lr_scheduler_type "cosine" \ 40 | --logging_steps 1 \ 41 | --tf32 True \ 42 | --model_max_length 2048 \ 43 | --gradient_checkpointing True \ 44 | --dataloader_num_workers 16 \ 45 | --lazy_preprocess True \ 46 | --report_to wandb 47 | -------------------------------------------------------------------------------- /scripts/archived/finetune_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | # PROMPT_VERSION=v1 7 | # MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | deepspeed llava/train/train_mem.py \ 16 | --deepspeed ./scripts/zero2.json \ 17 | --lora_enable True \ 18 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 19 | --version $PROMPT_VERSION \ 20 | --data_path ./playground/data/llava_instruct_80k.json \ 21 | --image_folder /path/to/coco/train2017 \ 22 | --vision_tower openai/clip-vit-large-patch14 \ 23 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 24 | --mm_vision_select_layer -2 \ 25 | --mm_use_im_start_end False \ 26 | --mm_use_im_patch_token False \ 27 | --bf16 True \ 28 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \ 29 | --num_train_epochs 1 \ 30 | --per_device_train_batch_size 16 \ 31 | --per_device_eval_batch_size 4 \ 32 | --gradient_accumulation_steps 1 \ 33 | --evaluation_strategy "no" \ 34 | --save_strategy "steps" \ 35 | --save_steps 50000 \ 36 | --save_total_limit 1 \ 37 | --learning_rate 2e-5 \ 38 | --weight_decay 0. \ 39 | --warmup_ratio 0.03 \ 40 | --lr_scheduler_type "cosine" \ 41 | --logging_steps 1 \ 42 | --tf32 True \ 43 | --model_max_length 2048 \ 44 | --gradient_checkpointing True \ 45 | --lazy_preprocess True \ 46 | --dataloader_num_workers 16 \ 47 | --report_to wandb 48 | -------------------------------------------------------------------------------- /scripts/archived/finetune_mixtral.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA 4 | 5 | # Install yolk3k if not installed 6 | if ! pip show yolk3k > /dev/null 2>&1; then 7 | pip install yolk3k 8 | fi 9 | 10 | # Get the installed version of transformers 11 | installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) 12 | 13 | # Get the latest version of transformers from PyPI 14 | latest_version=$(yolk -V transformers | cut -d ' ' -f 2) 15 | 16 | # Check if the installed version is not the latest 17 | if [ "$installed_version" != "$latest_version" ]; then 18 | pip install -U transformers 19 | fi 20 | 21 | # Get the installed version of deepspeed 22 | installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) 23 | 24 | # Get the latest version of deepspeed from PyPI 25 | latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) 26 | 27 | # Check if the installed version is not the latest 28 | if [ "$installed_version" != "$latest_version" ]; then 29 | pip install deepspeed==0.12.2 30 | fi 31 | 32 | # Install yolk3k if not installed 33 | if ! 
pip show flash-attn > /dev/null 2>&1; then 34 | pip install flash-attn --no-build-isolation 35 | fi 36 | 37 | 38 | ################## MISTRAL ################## 39 | PROMPT_VERSION=mistral_instruct 40 | MODEL_VERSION="Mistral-7B-Instruct-v0.2" 41 | ################## MISTRAL ################## 42 | 43 | 44 | ################## project ################## 45 | PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-mlp2x_gelu-pretrain_blip558k_plain" 46 | 47 | ################## data ################## 48 | DATA_NAME="mixtral_instruct_158K_V1" 49 | 50 | # wandb configure 51 | export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" 52 | wandb login $WANDB_API_KEY 53 | 54 | export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME 55 | 56 | export WANDB_PROJECT=LLaVA_Mixtral 57 | 58 | export WANDB_MODE=online 59 | 60 | wandb online 61 | 62 | 63 | deepspeed --master_port 26000 \ 64 | llava/train/train_mem.py \ 65 | --deepspeed ./scripts/zero2.json \ 66 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 67 | --version $PROMPT_VERSION \ 68 | --data_path ./playground/data/$DATA_NAME.json \ 69 | --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data/coco/train2017 \ 70 | --vision_tower openai/clip-vit-large-patch14 \ 71 | --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ 72 | --mm_vision_select_layer -2 \ 73 | --mm_projector_type mlp2x_gelu \ 74 | --mm_use_im_start_end False \ 75 | --mm_use_im_patch_token False \ 76 | --bf16 True \ 77 | --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ 78 | --num_train_epochs 1 \ 79 | --per_device_train_batch_size 16 \ 80 | --per_device_eval_batch_size 4 \ 81 | --gradient_accumulation_steps 1 \ 82 | --evaluation_strategy "no" \ 83 | --save_strategy "steps" \ 84 | --save_steps 50000 \ 85 | --save_total_limit 1 \ 86 | --learning_rate 2e-5 \ 87 | --weight_decay 0. \ 88 | --warmup_ratio 0.03 \ 89 | --lr_scheduler_type "cosine" \ 90 | --logging_steps 1 \ 91 | --tf32 True \ 92 | --model_max_length 2048 \ 93 | --gradient_checkpointing True \ 94 | --dataloader_num_workers 16 \ 95 | --lazy_preprocess True \ 96 | --report_to wandb 97 | -------------------------------------------------------------------------------- /scripts/archived/finetune_mixtral_1.5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dataset_name=$1 3 | 4 | cd /mnt/bn/vl-research/workspace/yhzhang/LLaVA 5 | 6 | # Install yolk3k if not installed 7 | if ! pip show yolk3k > /dev/null 2>&1; then 8 | pip install yolk3k 9 | fi 10 | 11 | # Get the installed version of transformers 12 | installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) 13 | 14 | # Get the latest version of transformers from PyPI 15 | latest_version=$(yolk -V transformers | cut -d ' ' -f 2) 16 | 17 | # Check if the installed version is not the latest 18 | if [ "$installed_version" != "$latest_version" ]; then 19 | pip install -U transformers 20 | fi 21 | 22 | # Get the installed version of deepspeed 23 | installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) 24 | 25 | # Get the latest version of deepspeed from PyPI 26 | # latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) 27 | 28 | # Check if the installed version is not the pinned 0.12.2 29 | if [ "$installed_version" != "0.12.2" ]; then 30 | pip install deepspeed==0.12.2 31 | fi 32 | 33 | # Install flash-attn if not installed 34 | if !
pip show flash-attn > /dev/null 2>&1; then 35 | pip install flash-attn --no-build-isolation 36 | fi 37 | 38 | ################## MISTRAL ################## 39 | PROMPT_VERSION=mistral_instruct 40 | MODEL_VERSION="Mistral-7B-Instruct-v0.2" 41 | ################## MISTRAL ################## 42 | 43 | 44 | ################## project ################## 45 | PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-mlp2x_gelu-pretrain_blip558k_plain" 46 | 47 | ################## data ################## 48 | DATA_NAME=$dataset_name 49 | 50 | 51 | # wandb configure 52 | export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" 53 | wandb login $WANDB_API_KEY 54 | 55 | export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME 56 | 57 | export WANDB_PROJECT=LLaVA_Mixtral 58 | 59 | export WANDB_MODE=online 60 | 61 | wandb online 62 | 63 | deepspeed --master_port 26000 \ 64 | llava/train/train_mem.py \ 65 | --deepspeed ./scripts/zero2.json \ 66 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 67 | --version $PROMPT_VERSION \ 68 | --data_path ./playground/data/$DATA_NAME.json \ 69 | --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data \ 70 | --vision_tower openai/clip-vit-large-patch14 \ 71 | --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ 72 | --mm_vision_select_layer -2 \ 73 | --mm_projector_type mlp2x_gelu \ 74 | --mm_use_im_start_end False \ 75 | --mm_use_im_patch_token False \ 76 | --bf16 True \ 77 | --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ 78 | --num_train_epochs 1 \ 79 | --per_device_train_batch_size 16 \ 80 | --per_device_eval_batch_size 4 \ 81 | --gradient_accumulation_steps 1 \ 82 | --evaluation_strategy "no" \ 83 | --save_strategy "steps" \ 84 | --save_steps 50000 \ 85 | --save_total_limit 1 \ 86 | --learning_rate 2e-5 \ 87 | --weight_decay 0. \ 88 | --warmup_ratio 0.03 \ 89 | --lr_scheduler_type "cosine" \ 90 | --logging_steps 1 \ 91 | --tf32 True \ 92 | --model_max_length 2048 \ 93 | --gradient_checkpointing True \ 94 | --dataloader_num_workers 16 \ 95 | --lazy_preprocess True 96 | # --report_to wandb 97 | -------------------------------------------------------------------------------- /scripts/archived/finetune_mixtral_1.6_336px_anyres.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dataset_name=$1 3 | 4 | cd /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next 5 | 6 | # Install yolk3k if not installed 7 | if ! pip show yolk3k > /dev/null 2>&1; then 8 | pip install yolk3k 9 | fi 10 | 11 | pip install pydantic 12 | 13 | # Get the installed version of transformers 14 | installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) 15 | 16 | # Get the latest version of transformers from PyPI 17 | latest_version=$(yolk -V transformers | cut -d ' ' -f 2) 18 | 19 | # Check if the installed version is not the latest 20 | if [ "$installed_version" != "4.36.2" ]; then 21 | pip install transformers==4.36.2 22 | fi 23 | 24 | # Get the installed version of deepspeed 25 | installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) 26 | 27 | 28 | # Check if the installed version is not the latest 29 | if [ "$installed_version" != "0.12.2" ]; then 30 | pip install deepspeed==0.12.2 31 | fi 32 | 33 | # Install flash-atten if not installed 34 | if ! 
pip show flash-attn > /dev/null 2>&1; then 35 | pip install flash-attn --no-build-isolation 36 | fi 37 | 38 | ################## MISTRAL ################## 39 | PROMPT_VERSION=mistral_instruct 40 | MODEL_VERSION="Mistral-7B-Instruct-v0.2" 41 | ################## MISTRAL ################## 42 | 43 | 44 | ################## project ################## 45 | PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain" 46 | 47 | ################## data ################## 48 | DATA_NAME=$dataset_name 49 | 50 | 51 | # wandb configure 52 | export WANDB_API_KEY=e464cc107357c7b38e87f239bc3eb2ce5fb73c7c 53 | export WANDB_PROJECT=llava 54 | 55 | export WANDB_NAME=$PROJECT_NAME--$DATA_NAME--336px--anyres--sft 56 | 57 | export WANDB_MODE=online 58 | 59 | wandb online 60 | 61 | deepspeed --master_port 26000 \ 62 | llava/train/train_mem.py \ 63 | --deepspeed ./scripts/zero3.json \ 64 | --model_name_or_path /mnt/bn/vl-research/workspace/project/2023/LLaVA/checkpoints/$MODEL_VERSION \ 65 | --version $PROMPT_VERSION \ 66 | --data_path ./playground/data/$DATA_NAME.json \ 67 | --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data \ 68 | --vision_tower openai/clip-vit-large-patch14-336 \ 69 | --pretrain_mm_mlp_adapter /mnt/bn/vl-research/workspace/project/2023/LLaVA/checkpoints/ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain/mm_projector.bin \ 70 | --mm_projector_type mlp2x_gelu \ 71 | --mm_vision_select_layer -2 \ 72 | --mm_use_im_start_end False \ 73 | --mm_use_im_patch_token False \ 74 | --group_by_modality_length True \ 75 | --unfreeze_mm_vision_tower True \ 76 | --mm_vision_tower_lr 2e-6 \ 77 | --image_aspect_ratio anyres \ 78 | --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ 79 | --mm_patch_merge_type spatial_unpad \ 80 | --bf16 True \ 81 | --output_dir ./checkpoints/$PROJECT_NAME--$DATA_NAME--336px--anyres--sft \ 82 | --num_train_epochs 9 \ 83 | --per_device_train_batch_size 8 \ 84 | --per_device_eval_batch_size 4 \ 85 | --gradient_accumulation_steps 1 \ 86 | --evaluation_strategy "no" \ 87 | --save_strategy "epoch" \ 88 | --save_steps 1500 \ 89 | --learning_rate 5e-6 \ 90 | --weight_decay 0. \ 91 | --warmup_ratio 0.03 \ 92 | --lr_scheduler_type "cosine" \ 93 | --logging_steps 1 \ 94 | --tf32 True \ 95 | --model_max_length 4096 \ 96 | --gradient_checkpointing True \ 97 | --dataloader_num_workers 8 \ 98 | --lazy_preprocess True \ 99 | --report_to wandb 100 | 101 | -------------------------------------------------------------------------------- /scripts/archived/finetune_mixtral_1.6_336px_anyres_freeze_vision.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dataset_name=$1 3 | 4 | cd /mnt/bn/vl-research/workspace/yhzhang/LLaVA 5 | 6 | # Install yolk3k if not installed 7 | if ! 
pip show yolk3k > /dev/null 2>&1; then 8 | pip install yolk3k 9 | fi 10 | 11 | pip install pydantic 12 | 13 | # Get the installed version of transformers 14 | installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) 15 | 16 | # Get the latest version of transformers from PyPI 17 | latest_version=$(yolk -V transformers | cut -d ' ' -f 2) 18 | 19 | # Check if the installed version is not the latest 20 | if [ "$installed_version" != "4.36.2" ]; then 21 | pip install transformers==4.36.2 22 | fi 23 | 24 | # Get the installed version of deepspeed 25 | installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) 26 | 27 | 28 | # Check if the installed version is not the latest 29 | if [ "$installed_version" != "0.12.2" ]; then 30 | pip install deepspeed==0.12.2 31 | fi 32 | 33 | # Install flash-atten if not installed 34 | if ! pip show flash-attn > /dev/null 2>&1; then 35 | pip install flash-attn --no-build-isolation 36 | fi 37 | 38 | ################## MISTRAL ################## 39 | PROMPT_VERSION=mistral_instruct 40 | MODEL_VERSION="Mistral-7B-Instruct-v0.2" 41 | ################## MISTRAL ################## 42 | 43 | 44 | ################## project ################## 45 | PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain" 46 | 47 | ################## data ################## 48 | DATA_NAME=$dataset_name 49 | 50 | 51 | # wandb configure 52 | export WANDB_API_KEY=e464cc107357c7b38e87f239bc3eb2ce5fb73c7c 53 | export WANDB_PROJECT=llava 54 | 55 | export WANDB_NAME=$PROJECT_NAME--$DATA_NAME--336px--unfreeze--anyres--sft 56 | 57 | export WANDB_MODE=online 58 | 59 | wandb online 60 | 61 | deepspeed --master_port 26000 \ 62 | llava/train/train_mem.py \ 63 | --deepspeed ./scripts/zero3.json \ 64 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 65 | --version $PROMPT_VERSION \ 66 | --data_path ./playground/data/$DATA_NAME.json \ 67 | --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data \ 68 | --vision_tower openai/clip-vit-large-patch14-336 \ 69 | --pretrain_mm_mlp_adapter /mnt/bn/vl-research/workspace/project/2023/LLaVA/checkpoints/ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain/mm_projector.bin \ 70 | --mm_vision_select_layer -2 \ 71 | --mm_projector_type mlp2x_gelu \ 72 | --mm_use_im_start_end False \ 73 | --mm_use_im_patch_token False \ 74 | --group_by_modality_length True \ 75 | --image_aspect_ratio anyres \ 76 | --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ 77 | --mm_patch_merge_type spatial_unpad \ 78 | --bf16 True \ 79 | --output_dir ./checkpoints/$PROJECT_NAME--$DATA_NAME--336px--anyres--unfreeze--sft \ 80 | --num_train_epochs 1 \ 81 | --per_device_train_batch_size 16 \ 82 | --per_device_eval_batch_size 4 \ 83 | --gradient_accumulation_steps 1 \ 84 | --evaluation_strategy "no" \ 85 | --save_strategy "steps" \ 86 | --save_steps 50000 \ 87 | --save_total_limit 1 \ 88 | --learning_rate 2e-5 \ 89 | --weight_decay 0. 
\ 90 | --warmup_ratio 0.03 \ 91 | --lr_scheduler_type "cosine" \ 92 | --logging_steps 1 \ 93 | --tf32 True \ 94 | --model_max_length 2048 \ 95 | --gradient_checkpointing True \ 96 | --dataloader_num_workers 16 \ 97 | 98 | -------------------------------------------------------------------------------- /scripts/archived/finetune_mixtral_1.6_336px_anyres_lmms_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # set up wandb 4 | export WANDB_API_KEY=a651c244635bc6f913ab654af3f0eebaecdc9381 5 | export WANDB_ENTITY=llava-vl 6 | export WANDB_PROJECT=llava-next 7 | export PYTHONWARNINGS="ignore" 8 | 9 | cd /mnt/bn/vl-research/workspace/boli01/projects/lmms-eval 10 | 11 | pip install -e . 12 | 13 | # set up llava dev env 14 | cd /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next 15 | 16 | ################## MISTRAL ################## 17 | PROMPT_VERSION=mistral_instruct 18 | MODEL_VERSION="Mistral-7B-Instruct-v0.2" 19 | ################## MISTRAL ################## 20 | 21 | ################## project ################## 22 | PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain" 23 | 24 | ################## data ################## 25 | DATA_NAME='llava_caps20k_chartqa19k' 26 | 27 | export WANDB_NAME=$PROJECT_NAME--$DATA_NAME--336px--anyres--sft 28 | export WANDB_MODE=online 29 | 30 | wandb online 31 | 32 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" deepspeed --master_port 26000 --include localhost:0,1,2,3,4,5,6,7 llava/train/train_mem.py \ 33 | --deepspeed ./scripts/zero3_offload.json \ 34 | --model_name_or_path mistralai/$MODEL_VERSION \ 35 | --version $PROMPT_VERSION \ 36 | --data_path ./playground/data/llava_instruct/$DATA_NAME.json \ 37 | --image_folder /mnt/bn/vl-research/data/llava \ 38 | --vision_tower openai/clip-vit-large-patch14-336 \ 39 | --mm_projector_type mlp2x_gelu \ 40 | --mm_vision_select_layer -2 \ 41 | --mm_use_im_start_end False \ 42 | --mm_use_im_patch_token False \ 43 | --group_by_modality_length True \ 44 | --unfreeze_mm_vision_tower True \ 45 | --mm_vision_tower_lr 2e-6 \ 46 | --image_aspect_ratio anyres \ 47 | --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ 48 | --mm_patch_merge_type spatial_unpad \ 49 | --bf16 True \ 50 | --output_dir ./checkpoints/$PROJECT_NAME--llava1.6--336px--anyres--sft \ 51 | --num_train_epochs 1 \ 52 | --per_device_train_batch_size 8 \ 53 | --per_device_eval_batch_size 4 \ 54 | --gradient_accumulation_steps 1 \ 55 | --evaluation_strategy "no" \ 56 | --save_strategy "steps" \ 57 | --save_steps 1500 \ 58 | --learning_rate 2e-5 \ 59 | --weight_decay 0. 
\ 60 | --warmup_ratio 0.03 \ 61 | --lr_scheduler_type "cosine" \ 62 | --logging_steps 1 \ 63 | --tf32 True \ 64 | --model_max_length 4096 \ 65 | --gradient_checkpointing True \ 66 | --dataloader_num_workers 32 \ 67 | --lazy_preprocess True \ 68 | --report_to wandb \ 69 | --run_name $WANDB_NAME 70 | # starting here is the args for evaluation 71 | --eval_num_processes 4 \ 72 | --task_names mme,docvqa_val \ 73 | --model_args pretrained=./checkpoints/$PROJECT_NAME--$DATA_NAME--336px--anyres--sft \ 74 | --limit 8 \ 75 | --batch_size 1 \ 76 | --log_samples \ 77 | --log_samples_suffix debug \ 78 | --output_path ./logs/ 79 | -------------------------------------------------------------------------------- /scripts/archived/finetune_mixtral_copy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA 4 | 5 | # Install yolk3k if not installed 6 | if ! pip show yolk3k > /dev/null 2>&1; then 7 | pip install yolk3k 8 | fi 9 | 10 | # Get the installed version of transformers 11 | installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) 12 | 13 | # Get the latest version of transformers from PyPI 14 | latest_version=$(yolk -V transformers | cut -d ' ' -f 2) 15 | 16 | # Check if the installed version is not the latest 17 | if [ "$installed_version" != "$latest_version" ]; then 18 | pip install -U transformers 19 | fi 20 | 21 | # Get the installed version of deepspeed 22 | installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) 23 | 24 | # Get the latest version of deepspeed from PyPI 25 | latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) 26 | 27 | # Check if the installed version is not the latest 28 | if [ "$installed_version" != "$latest_version" ]; then 29 | pip install deepspeed==0.12.2 30 | fi 31 | 32 | # Install yolk3k if not installed 33 | if ! 
pip show flash-attn > /dev/null 2>&1; then 34 | pip install flash-attn --no-build-isolation 35 | fi 36 | 37 | 38 | ################## MISTRAL ################## 39 | PROMPT_VERSION=mistral_instruct 40 | MODEL_VERSION="Mistral-7B-Instruct-v0.2" 41 | ################## MISTRAL ################## 42 | 43 | 44 | ################## project ################## 45 | PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-mlp2x_gelu-pretrain_blip558k_plain" 46 | 47 | ################## data ################## 48 | DATA_NAME="llava_instruct_150k" 49 | 50 | # wandb configure 51 | export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" 52 | wandb login $WANDB_API_KEY 53 | 54 | export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME 55 | 56 | export WANDB_PROJECT=LLaVA_Mixtral 57 | 58 | export WANDB_MODE=online 59 | 60 | wandb online 61 | 62 | 63 | deepspeed --master_port 26000 \ 64 | llava/train/train_mem.py \ 65 | --deepspeed ./scripts/zero2.json \ 66 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 67 | --version $PROMPT_VERSION \ 68 | --data_path ./playground/data/$DATA_NAME.json \ 69 | --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data/coco/train2017 \ 70 | --vision_tower openai/clip-vit-large-patch14 \ 71 | --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ 72 | --mm_vision_select_layer -2 \ 73 | --mm_projector_type mlp2x_gelu \ 74 | --mm_use_im_start_end False \ 75 | --mm_use_im_patch_token False \ 76 | --bf16 True \ 77 | --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ 78 | --num_train_epochs 1 \ 79 | --per_device_train_batch_size 16 \ 80 | --per_device_eval_batch_size 4 \ 81 | --gradient_accumulation_steps 1 \ 82 | --evaluation_strategy "no" \ 83 | --save_strategy "steps" \ 84 | --save_steps 50000 \ 85 | --save_total_limit 1 \ 86 | --learning_rate 2e-5 \ 87 | --weight_decay 0.
\ 88 | --warmup_ratio 0.03 \ 89 | --lr_scheduler_type "cosine" \ 90 | --logging_steps 1 \ 91 | --tf32 True \ 92 | --model_max_length 2048 \ 93 | --gradient_checkpointing True \ 94 | --dataloader_num_workers 16 \ 95 | --lazy_preprocess True \ 96 | --report_to wandb 97 | -------------------------------------------------------------------------------- /scripts/archived/finetune_qlora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | # PROMPT_VERSION=v1 7 | # MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | deepspeed llava/train/train_mem.py \ 16 | --deepspeed ./scripts/zero2.json \ 17 | --lora_enable True \ 18 | --bits 4 \ 19 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 20 | --version $PROMPT_VERSION \ 21 | --data_path ./playground/data/llava_instruct_80k.json \ 22 | --image_folder /path/to/coco/train2017 \ 23 | --vision_tower openai/clip-vit-large-patch14 \ 24 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --bf16 True \ 29 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \ 30 | --num_train_epochs 1 \ 31 | --per_device_train_batch_size 16 \ 32 | --per_device_eval_batch_size 4 \ 33 | --gradient_accumulation_steps 1 \ 34 | --evaluation_strategy "no" \ 35 | --save_strategy "steps" \ 36 | --save_steps 50000 \ 37 | --save_total_limit 1 \ 38 | --learning_rate 2e-5 \ 39 | --weight_decay 0. \ 40 | --warmup_ratio 0.03 \ 41 | --lr_scheduler_type "cosine" \ 42 | --logging_steps 1 \ 43 | --tf32 True \ 44 | --model_max_length 2048 \ 45 | --gradient_checkpointing True \ 46 | --lazy_preprocess True \ 47 | --dataloader_num_workers 16 \ 48 | --report_to wandb 49 | -------------------------------------------------------------------------------- /scripts/archived/finetune_sqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --deepspeed ./scripts/zero2.json \ 5 | --model_name_or_path lmsys/vicuna-13b-v1.3 \ 6 | --version $PROMPT_VERSION \ 7 | --data_path /Data/ScienceQA/data/scienceqa/llava_train_QCM-LEA.json \ 8 | --image_folder /Data/ScienceQA/data/scienceqa/images/train \ 9 | --vision_tower openai/clip-vit-large-patch14 \ 10 | --pretrain_mm_mlp_adapter ./checkpoints/huggingface/liuhaotian/llava-pretrain-vicuna-13b-v1.3/mm_projector.bin \ 11 | --mm_vision_select_layer -2 \ 12 | --mm_use_im_start_end False \ 13 | --mm_use_im_patch_token False \ 14 | --bf16 True \ 15 | --output_dir ./checkpoints/llava-vicuna-13b-v1.3-pretrain_lcs558k_plain-ScienceQA_QCM_LEA-12e \ 16 | --num_train_epochs 12 \ 17 | --per_device_train_batch_size 16 \ 18 | --per_device_eval_batch_size 4 \ 19 | --gradient_accumulation_steps 1 \ 20 | --evaluation_strategy "no" \ 21 | --save_strategy "steps" \ 22 | --save_steps 50000 \ 23 | --save_total_limit 1 \ 24 | --learning_rate 2e-5 \ 25 | --weight_decay 0. 
\ 26 | --warmup_ratio 0.03 \ 27 | --lr_scheduler_type "cosine" \ 28 | --logging_steps 1 \ 29 | --tf32 True \ 30 | --model_max_length 2048 \ 31 | --gradient_checkpointing True \ 32 | --dataloader_num_workers 16 \ 33 | --lazy_preprocess True \ 34 | --report_to wandb 35 | -------------------------------------------------------------------------------- /scripts/archived/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from llava.model.builder import load_pretrained_model 3 | from llava.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map="cpu") 9 | 10 | model.save_pretrained(args.save_model_path) 11 | tokenizer.save_pretrained(args.save_model_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-path", type=str, required=True) 17 | parser.add_argument("--model-base", type=str, required=True) 18 | parser.add_argument("--save-model-path", type=str, required=True) 19 | 20 | args = parser.parse_args() 21 | 22 | merge_lora(args) 23 | -------------------------------------------------------------------------------- /scripts/archived/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | # MODEL_VERSION=vicuna-v1-3-7b 6 | # MODEL_VERSION=llama-2-7b-chat 7 | 8 | ########### DO NOT CHANGE ########### 9 | ########### USE THIS FOR BOTH ########### 10 | PROMPT_VERSION=plain 11 | ########### DO NOT CHANGE ########### 12 | 13 | deepspeed llava/train/train_mem.py \ 14 | --deepspeed ./scripts/zero2.json \ 15 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 16 | --version $PROMPT_VERSION \ 17 | --data_path /path/to/pretrain_data.json \ 18 | --image_folder /path/to/images \ 19 | --vision_tower openai/clip-vit-large-patch14 \ 20 | --tune_mm_mlp_adapter True \ 21 | --mm_vision_select_layer -2 \ 22 | --mm_use_im_start_end False \ 23 | --mm_use_im_patch_token False \ 24 | --bf16 True \ 25 | --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \ 26 | --num_train_epochs 1 \ 27 | --per_device_train_batch_size 16 \ 28 | --per_device_eval_batch_size 4 \ 29 | --gradient_accumulation_steps 1 \ 30 | --evaluation_strategy "no" \ 31 | --save_strategy "steps" \ 32 | --save_steps 24000 \ 33 | --learning_rate 2e-3 \ 34 | --weight_decay 0. 
\ 35 | --warmup_ratio 0.03 \ 36 | --lr_scheduler_type "cosine" \ 37 | --logging_steps 1 \ 38 | --tf32 True \ 39 | --model_max_length 2048 \ 40 | --gradient_checkpointing True \ 41 | --dataloader_num_workers 16 \ 42 | --lazy_preprocess True \ 43 | --report_to wandb 44 | -------------------------------------------------------------------------------- /scripts/archived/quick_check.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | from tqdm import tqdm 5 | import yaml 6 | 7 | 8 | def check_missing_images(json_path, images_folder): 9 | data = json.load(open(json_path, "r")) 10 | missing_data = [] 11 | 12 | for i, d in enumerate(tqdm(data)): 13 | image = d["image"] if "image" in d else "" 14 | if image != "": 15 | path = os.path.join(images_folder, image) 16 | if not os.path.exists(path): 17 | print(f"Missing image: {path}") 18 | missing_data.append(d) 19 | 20 | return missing_data 21 | 22 | 23 | def read_yaml_to_llava_data(yaml_path, images_folder): 24 | print(f"Reading YAML file: {yaml_path}") 25 | with open(yaml_path, "r") as f: 26 | data = yaml.safe_load(f) 27 | 28 | llava_json_paths = data["datasets"] 29 | for item in llava_json_paths: 30 | json_path = item["json_path"] 31 | missing_data = check_missing_images(json_path, images_folder) 32 | if len(missing_data) > 0: 33 | print(f"Missing images in {json_path}:") 34 | for d in missing_data: 35 | print(d) 36 | 37 | 38 | def direct_check_llava_data(json_path, images_folder): 39 | missing_data = check_missing_images(json_path, images_folder) 40 | if len(missing_data) > 0: 41 | print(f"Missing images in {json_path}:") 42 | for d in missing_data: 43 | print(d) 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser(description="Check for missing images in dataset.") 48 | parser.add_argument("--yaml_path", type=str, default="", help="Path to the YAML file containing the dataset.") 49 | parser.add_argument("--json_path", type=str, default="", help="Path to the JSON file containing the dataset.") 50 | parser.add_argument("--images_folder", type=str, default="/mnt/bn/vl-research/data/llava_data", help="Path to the folder containing the images.") 51 | 52 | args = parser.parse_args() 53 | 54 | if args.json_path != "": 55 | direct_check_llava_data(args.json_path, args.images_folder) 56 | elif args.yaml_path != "": 57 | read_yaml_to_llava_data(args.yaml_path, args.images_folder) 58 | -------------------------------------------------------------------------------- /scripts/archived/sqa_eval_batch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHUNKS=8 4 | for IDX in {0..7}; do 5 | CUDA_VISIBLE_DEVICES=$IDX python -m llava.eval.model_vqa_science \ 6 | --model-path liuhaotian/llava-lcs558k-scienceqa-vicuna-13b-v1.3 \ 7 | --question-file ~/haotian/datasets/ScienceQA/data/scienceqa/llava_test_QCM-LEA.json \ 8 | --image-folder ~/haotian/datasets/ScienceQA/data/scienceqa/images/test \ 9 | --answers-file ./test_llava-13b-chunk$CHUNKS_$IDX.jsonl \ 10 | --num-chunks $CHUNKS \ 11 | --chunk-idx $IDX \ 12 | --conv-mode llava_v1 & 13 | done 14 | -------------------------------------------------------------------------------- /scripts/archived/sqa_eval_gather.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHUNKS=8 4 | output_file="test_llava-13b.jsonl" 5 | 6 | # Clear out the output file if it exists. 
7 | > "$output_file" 8 | 9 | # Loop through the indices and concatenate each file. 10 | for idx in $(seq 0 $((CHUNKS-1))); do 11 | cat "./test_llava-13b-chunk${idx}.jsonl" >> "$output_file" 12 | done 13 | 14 | python llava/eval/eval_science_qa.py \ 15 | --base-dir ~/haotian/datasets/ScienceQA/data/scienceqa \ 16 | --result-file ./test_llava-13b.jsonl \ 17 | --output-file ./test_llava-13b_output.json \ 18 | --output-result ./test_llava-13b_result.json 19 | -------------------------------------------------------------------------------- /scripts/interleave/eval_all.sh: -------------------------------------------------------------------------------- 1 | 2 | # evaluate 3 | ./scripts/interleave/eval_interleave_3d.sh /path/to/ckpt /path/to/images multi_image_in_domain 4 | ./scripts/interleave/eval_interleave_3d.sh /path/to/ckpt /path/to/images multi_image_out_domain 5 | ./scripts/interleave/eval_interleave_3d.sh /path/to/ckpt /path/to/images multi_view_in_domain -------------------------------------------------------------------------------- /scripts/interleave/eval_interleave_3d.sh: -------------------------------------------------------------------------------- 1 | alias python=python3 2 | CKPT_PATH=$1 3 | NAME=$(echo "$CKPT_PATH" | awk -F'/' '{print $NF}') 4 | echo $NAME 5 | ##### set images path 6 | DATA_PATH=$2 7 | EVAL_TYPE=$3 8 | JSON_PATH=$2/$3.json 9 | ############################### eval multi-image 10 | RESULT_NAME="logs/${NAME}/${EVAL_TYPE}" 11 | echo $RESULT_NAME 12 | 13 | mkdir -p logs/${NAME} 14 | 15 | file_path=${RESULT_NAME}/result.jsonl 16 | 17 | bash scripts/interleave/eval_multiprocess.sh \ 18 | ${CKPT_PATH} \ 19 | ${JSON_PATH} \ 20 | ${RESULT_NAME} \ 21 | ${DATA_PATH} \ 22 | "" \ 23 | 8 0 24 | 25 | python3 llava/eval/evaluate_interleave.py --result-dir ${RESULT_NAME} 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /scripts/interleave/eval_multiprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if three arguments are passed 4 | if [ "$#" -ne 7 ]; then 5 | echo "Usage: $0 " 6 | exit 1 7 | fi 8 | 9 | # Assign the command line arguments to variables 10 | model_path=$1 11 | question_path=$2 12 | base_answer_path=$3 13 | image_folder=$4 14 | extra_prompt=$5 15 | N=$6 16 | temperature=$7 17 | 18 | # Loop over each chunk/process 19 | for (( chunk_id=0; chunk_id "${base_answer_path}.jsonl" 42 | for ((i=0; i> "${base_answer_path}/result.jsonl" 45 | done 46 | # remove the unmerged files 47 | for (( chunk_id=0; chunk_id" \ 34 | --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ 35 | --unfreeze_mm_vision_tower True \ 36 | --vision_tower ${VISION_MODEL_VERSION} \ 37 | --mm_projector_type mlp2x_gelu \ 38 | --mm_vision_select_layer -2 \ 39 | --mm_use_im_start_end False \ 40 | --mm_use_im_patch_token False \ 41 | --group_by_modality_length True \ 42 | --image_aspect_ratio anyres_max_9 \ 43 | --image_grid_pinpoints "(1x1),...,(6x6)" \ 44 | --mm_patch_merge_type spatial_unpad \ 45 | --bf16 True \ 46 | --run_name $DPO_CLEAN_NAME \ 47 | --output_dir $OUTPUT_DIR \ 48 | --num_train_epochs $EPOCH \ 49 | --per_device_train_batch_size 1 \ 50 | --per_device_eval_batch_size 1 \ 51 | --gradient_accumulation_steps 8 \ 52 | --evaluation_strategy "no" \ 53 | --save_strategy "steps" \ 54 | --save_steps 1000 \ 55 | --save_total_limit 1 \ 56 | --learning_rate 5e-7 \ 57 | --weight_decay 0. 
\ 58 | --warmup_ratio 0.1 \ 59 | --lr_scheduler_type "cosine" \ 60 | --logging_steps 1 \ 61 | --tf32 True \ 62 | --model_max_length 32768 \ 63 | --gradient_checkpointing True \ 64 | --dataloader_num_workers 4 \ 65 | --lazy_preprocess True \ 66 | --report_to wandb \ 67 | --dataloader_drop_last True 68 | 69 | 70 | -------------------------------------------------------------------------------- /scripts/train/finetune_ov.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_SOCKET_IFNAME=eth0 5 | export NCCL_DEBUG=INFO 6 | 7 | LLM_VERSION="Qwen/Qwen2-7B-Instruct" 8 | # for 7b model we recommend bs=1, accum=2, 16 nodes, 128 gpus, lr=1e-5, warmup=0.03 9 | # for 72b model we recommend bs=1, accum=1, 32 nodes, 256 gpus, lr=1e-5, warmup=0.03 10 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 11 | VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" 12 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 13 | 14 | ############### Pretrain ################ 15 | 16 | BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" 17 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 18 | 19 | ############### Finetune ################ 20 | 21 | # Stage 2 22 | PROMPT_VERSION="qwen_1_5" 23 | RUN_NAME="llava-onevision-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_stage_am9" 24 | PREV_STAGE_CHECKPOINT="/mnt/bn/vl-research/checkpoints/onevision/llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mid_to_final_next_3m_am9_july14" # replace it with your last checkpoint training from single image collection 25 | echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" 26 | echo "MID_RUN_NAME: ${RUN_NAME}" 27 | 28 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ 29 | llava/train/train_mem.py \ 30 | --deepspeed scripts/zero3.json \ 31 | --model_name_or_path $PREV_STAGE_CHECKPOINT \ 32 | --version $PROMPT_VERSION \ 33 | --data_path /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/scripts/i18n/scale_llms/next_ov_stage_july21.yaml \ 34 | --image_folder /mnt/bn/vl-research/data/llava_data \ 35 | --video_folder /mnt/bn/vl-research/data/llava_video \ 36 | --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ 37 | --mm_vision_tower_lr=2e-6 \ 38 | --vision_tower ${VISION_MODEL_VERSION} \ 39 | --mm_projector_type mlp2x_gelu \ 40 | --mm_vision_select_layer -2 \ 41 | --mm_use_im_start_end False \ 42 | --mm_use_im_patch_token False \ 43 | --group_by_modality_length True \ 44 | --image_aspect_ratio anyres_max_9 \ 45 | --image_grid_pinpoints "(1x1),...,(6x6)" \ 46 | --mm_patch_merge_type spatial_unpad \ 47 | --bf16 True \ 48 | --run_name $RUN_NAME \ 49 | --output_dir /mnt/bn/vl-research/checkpoints/onevision/$RUN_NAME \ 50 | --num_train_epochs 1 \ 51 | --per_device_train_batch_size 1 \ 52 | --per_device_eval_batch_size 4 \ 53 | --gradient_accumulation_steps 2 \ 54 | --evaluation_strategy "no" \ 55 | --save_strategy "steps" \ 56 | --save_steps 1000 \ 57 | --save_total_limit 1 \ 58 | --learning_rate 1e-5 \ 59 | --weight_decay 0. 
\ 60 | --warmup_ratio 0.03 \ 61 | --lr_scheduler_type "cosine" \ 62 | --logging_steps 1 \ 63 | --tf32 True \ 64 | --model_max_length 32768 \ 65 | --gradient_checkpointing True \ 66 | --dataloader_num_workers 4 \ 67 | --lazy_preprocess True \ 68 | --report_to wandb \ 69 | --torch_compile True \ 70 | --torch_compile_backend "inductor" \ 71 | --dataloader_drop_last True \ 72 | --frames_upbound 32 73 | exit 0; 74 | 75 | # You can delete the sdpa attn_implementation if you want to use flash attn 76 | -------------------------------------------------------------------------------- /scripts/train/finetune_si.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_SOCKET_IFNAME=eth0 5 | export NCCL_DEBUG=INFO 6 | 7 | LLM_VERSION="Qwen/Qwen2-7B-Instruct" 8 | # for 7b model we recommend bs=1, accum=2, 16 nodes, 128 gpus, lr=1e-5, warmup=0.03 9 | # for 72b model we recommend bs=1, accum=1, 32 nodes, 256 gpus, lr=1e-5, warmup=0.03 10 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 11 | VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" 12 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 13 | 14 | ############### Pretrain ################ 15 | 16 | BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" 17 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 18 | 19 | ############### Finetune ################ 20 | 21 | # Stage 2 22 | PROMPT_VERSION="qwen_1_5" 23 | RUN_NAME="llava-onevision-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-si_stage_am9" 24 | PREV_STAGE_CHECKPOINT="/mnt/bn/vl-research/checkpoints/onevision/xxxxxxxxxxxxxxxx" # replace it with your last checkpoint training from mid stage 25 | echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" 26 | echo "MID_RUN_NAME: ${RUN_NAME}" 27 | 28 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ 29 | llava/train/train_mem.py \ 30 | --deepspeed scripts/zero3.json \ 31 | --model_name_or_path $PREV_STAGE_CHECKPOINT \ 32 | --version $PROMPT_VERSION \ 33 | --data_path /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/scripts/i18n/scale_llms/next_3p2m_single_image.yaml \ 34 | --image_folder /mnt/bn/vl-research/data/llava_data \ 35 | --video_folder /mnt/bn/vl-research/data/llava_video \ 36 | --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ 37 | --mm_vision_tower_lr=2e-6 \ 38 | --vision_tower ${VISION_MODEL_VERSION} \ 39 | --mm_projector_type mlp2x_gelu \ 40 | --mm_vision_select_layer -2 \ 41 | --mm_use_im_start_end False \ 42 | --mm_use_im_patch_token False \ 43 | --group_by_modality_length True \ 44 | --image_aspect_ratio anyres_max_9 \ 45 | --image_grid_pinpoints "(1x1),...,(6x6)" \ 46 | --mm_patch_merge_type spatial_unpad \ 47 | --bf16 True \ 48 | --run_name $RUN_NAME \ 49 | --output_dir /mnt/bn/vl-research/checkpoints/onevision/$RUN_NAME \ 50 | --num_train_epochs 1 \ 51 | --per_device_train_batch_size 1 \ 52 | --per_device_eval_batch_size 4 \ 53 | --gradient_accumulation_steps 2 \ 54 | --evaluation_strategy "no" \ 55 | --save_strategy "steps" \ 56 | --save_steps 1000 \ 57 | --save_total_limit 1 \ 58 | --learning_rate 1e-5 \ 59 | --weight_decay 0. 
\ 60 | --warmup_ratio 0.03 \ 61 | --lr_scheduler_type "cosine" \ 62 | --logging_steps 1 \ 63 | --tf32 True \ 64 | --model_max_length 32768 \ 65 | --gradient_checkpointing True \ 66 | --dataloader_num_workers 4 \ 67 | --lazy_preprocess True \ 68 | --report_to wandb \ 69 | --torch_compile True \ 70 | --torch_compile_backend "inductor" \ 71 | --dataloader_drop_last True \ 72 | --frames_upbound 32 73 | exit 0; 74 | -------------------------------------------------------------------------------- /scripts/train/mid_stage.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | - json_path: /mnt/bn/vl-research/data/llava_instruct/blip558k_stage1.5_finetune_w_prompt.json # released in lmms-lab/LLaVA-ReCap-* 3 | sampling_strategy: all 4 | - json_path: /mnt/bn/vl-research/data/llava_instruct/coco118k_stage1.5_finetune_w_prompt.json # released in lmms-lab/LLaVA-ReCap-* 5 | sampling_strategy: all 6 | - json_path: /mnt/bn/vl-research/data/llava_instruct/cc3m_recap_data_prompt_v2.json # released in lmms-lab/LLaVA-ReCap-* 7 | sampling_strategy: all 8 | - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_tr_sft.json # released in lmms-lab/LLaVA-OneVision-Mid-Data 9 | sampling_strategy: all 10 | - json_path: /mnt/bn/vl-research/data/llava_instruct/instruct_azure_dc_zh_92K.json # not released, explained at https://github.com/LLaVA-VL/LLaVA-NeXT/tree/main/scripts/train 11 | sampling_strategy: all 12 | - json_path: /mnt/bn/vl-research/data/llava_instruct/Evol-Instruct-GPT4-Turbo-143K.json # released in lmms-lab/LLaVA-OneVision-Mid-Data 13 | sampling_strategy: all 14 | - json_path: /mnt/bn/vl-research/data/llava_instruct/synthdog_zh/synthdog_zh_100k.json # released in lmms-lab/LLaVA-OneVision-Mid-Data 15 | sampling_strategy: all 16 | - json_path: /mnt/bn/vl-research/data/llava_instruct/synthdog_en/synthdog_en_100k.json # released in lmms-lab/LLaVA-OneVision-Mid-Data 17 | sampling_strategy: all -------------------------------------------------------------------------------- /scripts/train/pretrain_clip.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_SOCKET_IFNAME=eth0 5 | export NCCL_DEBUG=INFO 6 | 7 | LLM_VERSION="Qwen/Qwen2-7B-Instruct" 8 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 9 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 10 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 11 | 12 | ############### Pretrain ################ 13 | 14 | PROMPT_VERSION=plain 15 | 16 | BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain" 17 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 18 | 19 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ 20 | llava/train/train_mem.py \ 21 | --deepspeed scripts/zero3.json \ 22 | --model_name_or_path ${LLM_VERSION} \ 23 | --version ${PROMPT_VERSION} \ 24 | --data_path /blip_558k/blip_558k_plain.json \ 25 | --image_folder /blip_558k/images \ 26 | --vision_tower ${VISION_MODEL_VERSION} \ 27 | --mm_tunable_parts="mm_mlp_adapter" \ 28 | --mm_vision_select_layer -2 \ 29 | --mm_projector_type mlp2x_gelu \ 30 | --mm_use_im_start_end False \ 31 | --mm_use_im_patch_token False \ 32 | --bf16 True \ 33 | --output_dir /checkpoints/projectors/${BASE_RUN_NAME} \ 34 | --num_train_epochs 1 \ 35 | 
--per_device_train_batch_size 16 \ 36 | --per_device_eval_batch_size 4 \ 37 | --gradient_accumulation_steps 1 \ 38 | --evaluation_strategy "no" \ 39 | --save_strategy "no" \ 40 | --save_steps 50000 \ 41 | --learning_rate 1e-3 \ 42 | --weight_decay 0. \ 43 | --warmup_ratio 0.03 \ 44 | --lr_scheduler_type "cosine" \ 45 | --logging_steps 1 \ 46 | --tf32 True \ 47 | --model_max_length 8192 \ 48 | --gradient_checkpointing True \ 49 | --dataloader_num_workers 16 \ 50 | --lazy_preprocess True \ 51 | --report_to wandb \ 52 | --run_name $BASE_RUN_NAME \ 53 | --attn_implementation sdpa 54 | 55 | # You can delete the sdpa attn_implementation if you want to use flash attn -------------------------------------------------------------------------------- /scripts/train/pretrain_siglip.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_SOCKET_IFNAME=eth0 5 | export NCCL_DEBUG=INFO 6 | 7 | LLM_VERSION="Qwen/Qwen2-7B-Instruct" 8 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 9 | VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" 10 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 11 | 12 | ############### Pretrain ################ 13 | 14 | PROMPT_VERSION=plain 15 | 16 | BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain" 17 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 18 | 19 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ 20 | llava/train/train_mem.py \ 21 | --deepspeed scripts/zero3.json \ 22 | --model_name_or_path ${LLM_VERSION} \ 23 | --version ${PROMPT_VERSION} \ 24 | --data_path /blip_558k/blip_558k_plain.json \ 25 | --image_folder /blip_558k/images \ 26 | --vision_tower ${VISION_MODEL_VERSION} \ 27 | --mm_tunable_parts="mm_mlp_adapter" \ 28 | --mm_vision_select_layer -2 \ 29 | --mm_projector_type mlp2x_gelu \ 30 | --mm_use_im_start_end False \ 31 | --mm_use_im_patch_token False \ 32 | --bf16 True \ 33 | --output_dir /checkpoints/projectors/${BASE_RUN_NAME} \ 34 | --num_train_epochs 1 \ 35 | --per_device_train_batch_size 16 \ 36 | --per_device_eval_batch_size 4 \ 37 | --gradient_accumulation_steps 1 \ 38 | --evaluation_strategy "no" \ 39 | --save_strategy "no" \ 40 | --save_steps 50000 \ 41 | --learning_rate 1e-3 \ 42 | --weight_decay 0. \ 43 | --warmup_ratio 0.03 \ 44 | --lr_scheduler_type "cosine" \ 45 | --logging_steps 1 \ 46 | --tf32 True \ 47 | --model_max_length 8192 \ 48 | --gradient_checkpointing True \ 49 | --dataloader_num_workers 16 \ 50 | --lazy_preprocess True \ 51 | --report_to wandb \ 52 | --run_name $BASE_RUN_NAME \ 53 | --attn_implementation sdpa 54 | 55 | # You can delete the sdpa attn_implementation if you want to use flash attn -------------------------------------------------------------------------------- /scripts/video/demo/video_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT_DIR="/mnt/bn/vl-research/workspace/yhzhang/LLaVA-NeXT" 3 | 4 | if [ ! -e $ROOT_DIR ]; then 5 | echo "The root dir does not exist. Exiting the script." 
6 | exit 1 7 | fi 8 | 9 | cd $ROOT_DIR 10 | 11 | export PYTHONWARNINGS=ignore 12 | export TOKENIZERS_PARALLELISM=false 13 | 14 | CKPT=$1 15 | CONV_MODE=$2 16 | FRAMES=$3 17 | POOL_STRIDE=$4 18 | POOL_MODE=$5 19 | NEWLINE_POSITION=$6 20 | OVERWRITE=$7 21 | VIDEO_PATH=$8 22 | 23 | 24 | if [ "$OVERWRITE" = False ]; then 25 | SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE}_overwrite_${OVERWRITE} 26 | 27 | else 28 | SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE} 29 | fi 30 | 31 | python3 playground/demo/video_demo.py \ 32 | --model-path $CKPT \ 33 | --video_path ${VIDEO_PATH} \ 34 | --output_dir ./work_dirs/video_demo/$SAVE_DIR \ 35 | --output_name pred \ 36 | --chunk-idx $(($IDX - 1)) \ 37 | --overwrite ${OVERWRITE} \ 38 | --mm_spatial_pool_stride ${POOL_STRIDE:-4} \ 39 | --for_get_frames_num $FRAMES \ 40 | --conv-mode $CONV_MODE \ 41 | --mm_spatial_pool_mode ${POOL_MODE:-average} \ 42 | --mm_newline_position ${NEWLINE_POSITION:-grid} \ 43 | --prompt "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes." -------------------------------------------------------------------------------- /scripts/video/eval/activitynet_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT_DIR="root to LLaVA-NeXT-Video" 3 | 4 | if [ ! -e $ROOT_DIR ]; then 5 | echo "The root dir does not exist. Exiting the script." 6 | exit 1 7 | fi 8 | 9 | cd $ROOT_DIR 10 | 11 | export PYTHONWARNINGS=ignore 12 | export TOKENIZERS_PARALLELISM=false 13 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' 14 | gpu_list="${CUDA_VISIBLE_DEVICES}" 15 | GPULIST=(${(s:,:)gpu_list}) 16 | 17 | CHUNKS=${#GPULIST[@]} 18 | echo "Using $CHUNKS GPUs" 19 | 20 | CKPT=$1 21 | CONV_MODE=$2 22 | FRAMES=$3 23 | OVERWRITE=$4 24 | PREDEFINED_CONFIGURE=$5 25 | mm_spatial_pool_stride=$6 26 | MODEL_MAX_LENGTH=${7:-0} 27 | 28 | CKPT=$1 29 | CONV_MODE=$2 30 | FRAMES=$3 31 | POOL_STRIDE=$4 32 | OVERWRITE=$5 33 | CHUNKS=${6:-1} 34 | 35 | PATCHIFY=False 36 | 37 | 38 | OPENAIKEY="INPUT YOUR OPENAI API" 39 | 40 | 41 | if [ "$OVERWRITE" = False ]; then 42 | SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE}_overwrite_${OVERWRITE} 43 | 44 | else 45 | SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE} 46 | fi 47 | 48 | echo $SAVE_DIR 49 | 50 | # for IDX in {1..$CHUNKS}; do 51 | # GPU_ID=${GPULIST[$IDX]} # Note: Zsh arrays are 1-indexed by default 52 | 53 | # # GPU_FREE=0 54 | # # while [ $GPU_FREE -eq 0 ]; do 55 | # # # Using nvidia-smi to get the memory usage of the GPU with ID $GPU_ID 56 | # # # Parsing the output to extract the memory usage, and checking if it is "0" 57 | # # MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]') 58 | 59 | # # if [ "$MEM_USAGE" -eq 0 ]; then 60 | # # GPU_FREE=1 61 | # # echo "GPU $GPU_ID is free." 62 | # # else 63 | # # echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB. Checking again in 100 seconds..." 
64 | # # sleep 100 65 | # # fi 66 | # # done 67 | 68 | # echo "Running on GPU $GPU_ID" 69 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 llavavid/eval/model_activitynet_qa.py \ 70 | # --model-path $CKPT \ 71 | # --video_dir ./data/llava_video/ActivityNet-QA/all_test \ 72 | # --gt_file_question ./data/llava_video/ActivityNet-QA/test_q.json \ 73 | # --gt_file_answers ./data/llava_videoActivityNet-QA/test_a.json \ 74 | # --output_dir ./work_dirs/eval_activitynet/$SAVE_DIR \ 75 | # --output_name pred \ 76 | # --num-chunks $CHUNKS \ 77 | # --chunk-idx $(($IDX - 1)) \ 78 | # --overwrite ${OVERWRITE} \ 79 | # --patchify_video_feature ${PATCHIFY} \ 80 | # --predefined_configure ${PREDEFINED_CONFIGURE} \ 81 | # --mm_spatial_pool_stride ${mm_spatial_pool_stride:-4} \ 82 | # --for_get_frames_num $FRAMES \ 83 | # --model-max-length ${MODEL_MAX_LENGTH:-0} \ 84 | # --conv-mode $CONV_MODE & 85 | 86 | # done 87 | 88 | # wait 89 | 90 | python3 llava/eval/eval_activitynet_qa.py \ 91 | --pred_path ./work_dirs/eval_activitynet/$SAVE_DIR \ 92 | --output_dir ./work_dirs/eval_activitynet/$SAVE_DIR/results \ 93 | --output_json ./work_dirs/eval_activitynet/$SAVE_DIR/results.json \ 94 | --num_chunks $CHUNKS \ 95 | --api_key $OPENAIKEY \ 96 | # --num_tasks 16 \ -------------------------------------------------------------------------------- /scripts/video/eval/video_description_from_t2v.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT_DIR="/mnt/bn/vl-research/workspace/yhzhang/llava-next-video" 3 | 4 | if [ ! -e $ROOT_DIR ]; then 5 | echo "The root dir does not exist. Exiting the script." 6 | exit 1 7 | fi 8 | 9 | cd $ROOT_DIR 10 | 11 | export PYTHONWARNINGS=ignore 12 | export TOKENIZERS_PARALLELISM=false 13 | 14 | CKPT=$1 15 | CONV_MODE=$2 16 | FRAMES=$3 17 | POOL_STRIDE=$4 18 | OVERWRITE=$5 19 | CHUNKS=${6:-1} 20 | DO_CENTER_CROP=${7:-False} 21 | 22 | echo "Using $CHUNKS GPUs" 23 | 24 | LOAD_8BIT=False 25 | 26 | 27 | if [ "$OVERWRITE" = False ]; then 28 | if [ "$MODEL_MAX_LENGTH" = 0 ]; then 29 | SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_overwrite_${OVERWRITE} 30 | else 31 | SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_overwrite_${OVERWRITE} 32 | fi 33 | else 34 | SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE} 35 | fi 36 | 37 | SAVE_DIR=${SAVE_DIR}_do_center_crop_${DO_CENTER_CROP} 38 | # Assuming GPULIST is a bash array containing your GPUs 39 | GPULIST=(0 1 2 3 4 5 6 7) 40 | # GPULIST=(0) 41 | 42 | # Get the number of GPUs 43 | NUM_GPUS=${#GPULIST[@]} 44 | 45 | # Calculate GPUs per chunk 46 | GPUS_PER_CHUNK=$((NUM_GPUS / CHUNKS)) 47 | 48 | 49 | for IDX in $(seq 1 $CHUNKS); do 50 | START=$(((IDX-1) * GPUS_PER_CHUNK)) 51 | LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index 52 | 53 | CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH}) 54 | 55 | # Convert the chunk GPUs array to a comma-separated string 56 | CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}") 57 | 58 | # ALL_GPUS_FREE=0 59 | # while [ $ALL_GPUS_FREE -eq 0 ]; do 60 | # ALL_GPUS_FREE=1 # Assume all GPUs are free initially 61 | 62 | # for GPU_ID in $CHUNK_GPUS; do 63 | # MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]') 64 | 65 | # # Assuming a GPU is considered free if its memory usage is less than 100 MiB 66 | # if [ "$MEM_USAGE" -ge 100 ]; then 67 | # ALL_GPUS_FREE=0 68 | # echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB." 
69 | # break # Exit the loop early as we found a GPU that is not free 70 | # fi 71 | # done 72 | 73 | # if [ $ALL_GPUS_FREE -eq 0 ]; then 74 | # echo "Not all GPUs in chunk are free. Checking again in 100 seconds..." 75 | # sleep 100 76 | # fi 77 | # done 78 | 79 | echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR" 80 | CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 llava/eval/model_video_description_from_t2v.py \ 81 | --model-path $CKPT \ 82 | --gt_file /mnt/bn/vl-research-1t/tuyen/webvid_hdvg_movie_pond5_for_captioning_evaluation/webvid_hdvg_movie_pond5_for_captioning_evaluation.processed.csv \ 83 | --output_dir ./work_dirs/eval_video_description_from_t2v/$SAVE_DIR \ 84 | --output_name pred \ 85 | --num-chunks $CHUNKS \ 86 | --chunk-idx $(($IDX - 1)) \ 87 | --overwrite ${OVERWRITE} \ 88 | --mm_spatial_pool_stride ${POOL_STRIDE:-4} \ 89 | --for_get_frames_num $FRAMES \ 90 | --load_8bit $LOAD_8BIT \ 91 | --do_center_crop $DO_CENTER_CROP \ 92 | --conv-mode $CONV_MODE & 93 | done 94 | 95 | wait 96 | 97 | cat ${ROOT_DIR}/work_dirs/eval_video_description_from_t2v/$SAVE_DIR/${CHUNKS}* > ${ROOT_DIR}/work_dirs/eval_video_description_from_t2v/$SAVE_DIR/pred.json 98 | 99 | -------------------------------------------------------------------------------- /scripts/video/eval/video_detail_description_eval_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT_DIR="root to LLaVA-NeXT-Video" 3 | 4 | if [ ! -e $ROOT_DIR ]; then 5 | echo "The root dir does not exist. Exiting the script." 6 | exit 1 7 | fi 8 | 9 | cd $ROOT_DIR 10 | 11 | export PYTHONWARNINGS=ignore 12 | export TOKENIZERS_PARALLELISM=false 13 | 14 | OPENAIKEY="INPUT YOUR OPENAI API" 15 | 16 | SAVE_DIR=$1 17 | 18 | python3 llava/eval/evaluate_benchmark_video_detail_description.py \ 19 | --pred_path ./work_dirs/eval_video_detail_description/$SAVE_DIR/pred.json \ 20 | --output_dir ./work_dirs/eval_video_detail_description/$SAVE_DIR/detail_results \ 21 | --output_json ./work_dirs/eval_video_detail_description/$SAVE_DIR/detail_results.json \ 22 | --num_chunks 1 \ 23 | --num_tasks 16 \ 24 | --api_key $OPENAIKEY \ -------------------------------------------------------------------------------- /scripts/video/eval/video_detail_description_eval_shard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT_DIR="/mnt/bn/vl-research/workspace/yhzhang/llava-next-video" 3 | 4 | if [ ! -e $ROOT_DIR ]; then 5 | echo "The root dir does not exist. Exiting the script." 
6 | exit 1 7 | fi 8 | 9 | cd $ROOT_DIR 10 | 11 | export PYTHONWARNINGS=ignore 12 | export TOKENIZERS_PARALLELISM=false 13 | 14 | OPENAIKEY="INPUT YOUR OPENAI API" 15 | 16 | CKPT=$1 17 | CONV_MODE=$2 18 | FRAMES=$3 19 | POOL_STRIDE=$4 20 | OVERWRITE=$5 21 | CHUNKS=${6:-1} 22 | 23 | echo "Using $CHUNKS GPUs" 24 | 25 | if [ "$OVERWRITE" = False ]; then 26 | SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE}_overwrite_${OVERWRITE} 27 | 28 | else 29 | SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE} 30 | fi 31 | 32 | # Assuming GPULIST is a bash array containing your GPUs 33 | GPULIST=(0 1 2 3 4 5 6 7) 34 | 35 | # Get the number of GPUs 36 | NUM_GPUS=${#GPULIST[@]} 37 | 38 | # Calculate GPUs per chunk 39 | GPUS_PER_CHUNK=$((NUM_GPUS / CHUNKS)) 40 | 41 | 42 | for IDX in $(seq 1 $CHUNKS); do 43 | START=$(((IDX-1) * GPUS_PER_CHUNK)) 44 | LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index 45 | 46 | CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH}) 47 | 48 | # Convert the chunk GPUs array to a comma-separated string 49 | CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}") 50 | 51 | # ALL_GPUS_FREE=0 52 | # while [ $ALL_GPUS_FREE -eq 0 ]; do 53 | # ALL_GPUS_FREE=1 # Assume all GPUs are free initially 54 | 55 | # for GPU_ID in $CHUNK_GPUS; do 56 | # MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]') 57 | 58 | # # Assuming a GPU is considered free if its memory usage is less than 100 MiB 59 | # if [ "$MEM_USAGE" -ge 100 ]; then 60 | # ALL_GPUS_FREE=0 61 | # echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB." 62 | # break # Exit the loop early as we found a GPU that is not free 63 | # fi 64 | # done 65 | 66 | # if [ $ALL_GPUS_FREE -eq 0 ]; then 67 | # echo "Not all GPUs in chunk are free. Checking again in 100 seconds..." 
68 | # sleep 100 69 | # fi 70 | # done 71 | 72 | echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR" 73 | CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 llava/eval/model_video_detail_description.py \ 74 | --model-path $CKPT \ 75 | --video_dir ./data/llava_video/video-chatgpt/evaluation/Test_Videos/ \ 76 | --output_dir ./work_dirs/eval_video_detail_description/$SAVE_DIR \ 77 | --output_name pred \ 78 | --num-chunks $CHUNKS \ 79 | --chunk-idx $(($IDX - 1)) \ 80 | --overwrite ${OVERWRITE} \ 81 | --mm_spatial_pool_stride ${POOL_STRIDE:-4} \ 82 | --for_get_frames_num $FRAMES \ 83 | --conv-mode $CONV_MODE & 84 | done 85 | 86 | wait 87 | 88 | python3 llava/eval/evaluate_benchmark_video_detail_description.py \ 89 | --pred_path ./work_dirs/eval_video_detail_description/$SAVE_DIR \ 90 | --output_dir ./work_dirs/eval_video_detail_description/$SAVE_DIR/detail_results \ 91 | --output_json ./work_dirs/eval_video_detail_description/$SAVE_DIR/detail_results.json \ 92 | --num_chunks $CHUNKS \ 93 | --num_tasks 16 \ 94 | --api_key $OPENAIKEY \ 95 | 96 | -------------------------------------------------------------------------------- /scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set up the data folder 4 | IMAGE_FOLDER="XXX" 5 | VIDEO_FOLDER="XXX" 6 | DATA_YAML="XXX" # e.g exp.yaml 7 | 8 | ############### Prepare Envs ################# 9 | python3 -m pip install flash-attn --no-build-isolation 10 | alias python=python3 11 | ############### Show Envs #################### 12 | 13 | nvidia-smi 14 | 15 | ################ Arnold Jobs ################ 16 | 17 | LLM_VERSION="Qwen/Qwen2-72B-Instruct" 18 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 19 | VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" 20 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 21 | 22 | 23 | BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-72B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" 24 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 25 | 26 | # Stage 2 27 | PROMPT_VERSION="qwen_1_5" 28 | MID_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video_am9" 29 | PREV_STAGE_CHECKPOINT="lmms-lab/llava-onevision-qwen2-72b-ov-si" 30 | echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" 31 | echo "MID_RUN_NAME: ${MID_RUN_NAME}" 32 | 33 | 34 | # ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ 35 | deepspeed --master_port 30000 \ 36 | llava/train/train_mem.py \ 37 | --deepspeed scripts/zero3.json \ 38 | --model_name_or_path $PREV_STAGE_CHECKPOINT \ 39 | --version $PROMPT_VERSION \ 40 | --data_path $DATA_YAML \ 41 | --image_folder $IMAGE_FOLDER \ 42 | --video_folder $VIDEO_FOLDER \ 43 | --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ 44 | --mm_vision_tower_lr=2e-6 \ 45 | --vision_tower ${VISION_MODEL_VERSION} \ 46 | --mm_projector_type mlp2x_gelu \ 47 | --mm_vision_select_layer -2 \ 48 | --mm_use_im_start_end False \ 49 | --mm_use_im_patch_token False \ 50 | --group_by_modality_length True \ 51 | --image_aspect_ratio anyres_max_9 \ 52 | --image_grid_pinpoints "(1x1),...,(6x6)" \ 53 | --mm_patch_merge_type spatial_unpad \ 54 | --bf16 True \ 55 | --run_name $MID_RUN_NAME \ 56 | --output_dir ./work_dirs/$MID_RUN_NAME \ 57 | --num_train_epochs 1 \ 58 | --per_device_train_batch_size 1 \ 59 | 
--per_device_eval_batch_size 4 \ 60 | --gradient_accumulation_steps 2 \ 61 | --evaluation_strategy "no" \ 62 | --save_strategy "steps" \ 63 | --save_steps 500 \ 64 | --save_total_limit 1 \ 65 | --learning_rate 1e-5 \ 66 | --weight_decay 0. \ 67 | --warmup_ratio 0.03 \ 68 | --lr_scheduler_type "cosine" \ 69 | --logging_steps 1 \ 70 | --tf32 True \ 71 | --model_max_length 32768 \ 72 | --gradient_checkpointing True \ 73 | --dataloader_num_workers 2 \ 74 | --lazy_preprocess True \ 75 | --report_to wandb \ 76 | --torch_compile True \ 77 | --torch_compile_backend "inductor" \ 78 | --dataloader_drop_last True \ 79 | --frames_upbound 32 \ 80 | --mm_newline_position grid \ 81 | --add_time_instruction True \ 82 | --force_sample True \ 83 | --mm_spatial_pool_stride 2 84 | exit 0; -------------------------------------------------------------------------------- /scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set up the data folder 4 | IMAGE_FOLDER="XXX" 5 | VIDEO_FOLDER="XXX" 6 | DATA_YAML="XXX" # e.g exp.yaml 7 | 8 | ############### Prepare Envs ################# 9 | python3 -m pip install flash-attn --no-build-isolation 10 | alias python=python3 11 | ############### Show Envs #################### 12 | 13 | nvidia-smi 14 | 15 | ################ Arnold Jobs ################ 16 | 17 | LLM_VERSION="Qwen/Qwen2-7B-Instruct" 18 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 19 | VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" 20 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 21 | # 22 | 23 | BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" 24 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 25 | 26 | # Stage 2 27 | PROMPT_VERSION="qwen_1_5" 28 | MID_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video_am9" 29 | PREV_STAGE_CHECKPOINT="lmms-lab/llava-onevision-qwen2-7b-ov-si" 30 | echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" 31 | echo "MID_RUN_NAME: ${MID_RUN_NAME}" 32 | 33 | 34 | # ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ 35 | deepspeed --master_port 30000 \ 36 | llava/train/train_mem.py \ 37 | --deepspeed scripts/zero3.json \ 38 | --model_name_or_path $PREV_STAGE_CHECKPOINT \ 39 | --version $PROMPT_VERSION \ 40 | --data_path $DATA_YAML \ 41 | --image_folder $IMAGE_FOLDER \ 42 | --video_folder $VIDEO_FOLDER \ 43 | --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ 44 | --mm_vision_tower_lr=2e-6 \ 45 | --vision_tower ${VISION_MODEL_VERSION} \ 46 | --mm_projector_type mlp2x_gelu \ 47 | --mm_vision_select_layer -2 \ 48 | --mm_use_im_start_end False \ 49 | --mm_use_im_patch_token False \ 50 | --group_by_modality_length True \ 51 | --image_aspect_ratio anyres_max_9 \ 52 | --image_grid_pinpoints "(1x1),...,(6x6)" \ 53 | --mm_patch_merge_type spatial_unpad \ 54 | --bf16 True \ 55 | --run_name $MID_RUN_NAME \ 56 | --output_dir ./work_dirs/$MID_RUN_NAME \ 57 | --num_train_epochs 1 \ 58 | --per_device_train_batch_size 1 \ 59 | --per_device_eval_batch_size 4 \ 60 | --gradient_accumulation_steps 2 \ 61 | --evaluation_strategy "no" \ 62 | --save_strategy "steps" \ 63 | --save_steps 500 \ 64 | --save_total_limit 1 \ 65 | --learning_rate 1e-5 \ 66 | --weight_decay 0. 
\ 67 | --warmup_ratio 0.03 \ 68 | --lr_scheduler_type "cosine" \ 69 | --logging_steps 1 \ 70 | --tf32 True \ 71 | --model_max_length 32768 \ 72 | --gradient_checkpointing True \ 73 | --dataloader_num_workers 2 \ 74 | --lazy_preprocess True \ 75 | --report_to wandb \ 76 | --torch_compile True \ 77 | --torch_compile_backend "inductor" \ 78 | --dataloader_drop_last True \ 79 | --frames_upbound 64 \ 80 | --mm_newline_position grid \ 81 | --add_time_instruction True \ 82 | --force_sample True \ 83 | --mm_spatial_pool_stride 2 84 | exit 0; -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero2_fused_adamw.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": true, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "offload_optimizer": { 19 | "device": "cpu", 20 | "pin_memory": true 21 | }, 22 | "offload_param": { 23 | "device": "cpu", 24 | "pin_memory": true 25 | }, 26 
| "overlap_comm": true, 27 | "contiguous_gradients": true, 28 | "sub_group_size": 1e9, 29 | "reduce_bucket_size": "auto" 30 | } 31 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "steps_per_print": 1e5, 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /scripts/zero3pp.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "none", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "none", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | 
"contiguous_gradients": true, 35 | "zero_quantized_weights": true, 36 | "zero_hpz_partition_size": 16, 37 | "zero_quantized_gradients": true, 38 | "sub_group_size": 1e9, 39 | "reduce_bucket_size": "auto", 40 | "stage3_prefetch_bucket_size": "auto", 41 | "stage3_param_persistence_threshold": "auto", 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | "stage3_gather_16bit_weights_on_model_save": true 45 | }, 46 | 47 | "gradient_accumulation_steps": "auto", 48 | "gradient_clipping": "auto", 49 | "steps_per_print": 100, 50 | "train_batch_size": "auto", 51 | "train_micro_batch_size_per_gpu": "auto", 52 | "wall_clock_breakdown": false 53 | } -------------------------------------------------------------------------------- /trl/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | __version__ = "0.7.11.dev0" 4 | 5 | from .core import set_seed 6 | from .environment import TextEnvironment, TextHistory 7 | from .extras import BestOfNSampler 8 | from .import_utils import ( 9 | is_bitsandbytes_available, 10 | is_diffusers_available, 11 | is_npu_available, 12 | is_peft_available, 13 | is_wandb_available, 14 | is_xpu_available, 15 | ) 16 | from .models import ( 17 | AutoModelForCausalLMWithValueHead, 18 | AutoModelForSeq2SeqLMWithValueHead, 19 | PreTrainedModelWrapper, 20 | create_reference_model, 21 | setup_chat_format, 22 | ) 23 | from .trainer import ( 24 | DataCollatorForCompletionOnlyLM, 25 | DPOTrainer, 26 | IterativeSFTTrainer, 27 | ModelConfig, 28 | PPOConfig, 29 | PPOTrainer, 30 | RewardConfig, 31 | RewardTrainer, 32 | SFTTrainer, 33 | ) 34 | from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config 35 | 36 | 37 | if is_diffusers_available(): 38 | from .models import ( 39 | DDPOPipelineOutput, 40 | DDPOSchedulerOutput, 41 | DDPOStableDiffusionPipeline, 42 | DefaultDDPOStableDiffusionPipeline, 43 | ) 44 | from .trainer import DDPOConfig, DDPOTrainer 45 | -------------------------------------------------------------------------------- /trl/environment/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from .base_environment import TextEnvironment, TextHistory 4 | -------------------------------------------------------------------------------- /trl/extras/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | from .best_of_n_sampler import BestOfNSampler 17 | -------------------------------------------------------------------------------- /trl/extras/dataset_formatting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable, Literal, Optional, Union 3 | 4 | from datasets import Dataset, Value 5 | from transformers import AutoTokenizer 6 | 7 | from ..trainer.utils import ConstantLengthDataset 8 | 9 | 10 | FORMAT_MAPPING = { 11 | "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}], 12 | "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, 13 | } 14 | 15 | 16 | def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]): 17 | r""" 18 | return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer 19 | apply chat template to the dataset 20 | """ 21 | 22 | def format_dataset(examples): 23 | if isinstance(examples[messages_field][0], list): 24 | output_texts = [] 25 | for i in range(len(examples[messages_field])): 26 | output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False)) 27 | return output_texts 28 | else: 29 | return tokenizer.apply_chat_template(examples[messages_field], tokenize=False) 30 | 31 | return format_dataset 32 | 33 | 34 | def instructions_formatting_function(tokenizer: AutoTokenizer): 35 | r""" 36 | return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer 37 | apply chat template to the dataset 38 | """ 39 | 40 | def format_dataset(examples): 41 | if isinstance(examples["prompt"], list): 42 | output_texts = [] 43 | for i in range(len(examples["prompt"])): 44 | converted_sample = [ 45 | {"role": "user", "content": examples["prompt"][i]}, 46 | {"role": "assistant", "content": examples["completion"][i]}, 47 | ] 48 | output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False)) 49 | return output_texts 50 | else: 51 | converted_sample = [ 52 | {"role": "user", "content": examples["prompt"]}, 53 | {"role": "assistant", "content": examples["completion"]}, 54 | ] 55 | return tokenizer.apply_chat_template(converted_sample, tokenize=False) 56 | 57 | return format_dataset 58 | 59 | 60 | def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]: 61 | r""" 62 | Finds the correct formatting function based on the dataset structure. 
Currently supported datasets are: 63 | - `ChatML` with [{"role": str, "content": str}] 64 | - `instruction` with [{"prompt": str, "completion": str}] 65 | 66 | Args: 67 | dataset (Dataset): User dataset 68 | tokenizer (AutoTokenizer): Tokenizer used for formatting 69 | 70 | Returns: 71 | Callable: Formatting function if the dataset format is supported else None 72 | """ 73 | if isinstance(dataset, Dataset): 74 | if "messages" in dataset.features: 75 | if dataset.features["messages"] == FORMAT_MAPPING["chatml"]: 76 | logging.info("Formatting dataset with chatml format") 77 | return conversations_formatting_function(tokenizer, "messages") 78 | if "conversations" in dataset.features: 79 | if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]: 80 | logging.info("Formatting dataset with chatml format") 81 | return conversations_formatting_function(tokenizer, "conversations") 82 | elif dataset.features == FORMAT_MAPPING["instruction"]: 83 | logging.info("Formatting dataset with instruction format") 84 | return instructions_formatting_function(tokenizer) 85 | 86 | return None 87 | -------------------------------------------------------------------------------- /trl/import_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import importlib 15 | import sys 16 | 17 | 18 | if sys.version_info < (3, 8): 19 | _is_python_greater_3_8 = False 20 | else: 21 | _is_python_greater_3_8 = True 22 | 23 | 24 | def is_peft_available() -> bool: 25 | return importlib.util.find_spec("peft") is not None 26 | 27 | 28 | def is_unsloth_available() -> bool: 29 | return importlib.util.find_spec("unsloth") is not None 30 | 31 | 32 | def is_accelerate_greater_20_0() -> bool: 33 | if _is_python_greater_3_8: 34 | from importlib.metadata import version 35 | 36 | accelerate_version = version("accelerate") 37 | else: 38 | import pkg_resources 39 | 40 | accelerate_version = pkg_resources.get_distribution("accelerate").version 41 | return accelerate_version >= "0.20.0" 42 | 43 | 44 | def is_transformers_greater_than(version: str) -> bool: 45 | _transformers_version = importlib.metadata.version("transformers") 46 | return _transformers_version > version 47 | 48 | 49 | def is_torch_greater_2_0() -> bool: 50 | if _is_python_greater_3_8: 51 | from importlib.metadata import version 52 | 53 | torch_version = version("torch") 54 | else: 55 | import pkg_resources 56 | 57 | torch_version = pkg_resources.get_distribution("torch").version 58 | return torch_version >= "2.0" 59 | 60 | 61 | def is_diffusers_available() -> bool: 62 | return importlib.util.find_spec("diffusers") is not None 63 | 64 | 65 | def is_bitsandbytes_available() -> bool: 66 | import torch 67 | 68 | # bnb can be imported without GPU but is not usable. 
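    # i.e. report False on CPU-only machines even if the bitsandbytes package itself is installed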
69 | return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available() 70 | 71 | 72 | def is_torchvision_available() -> bool: 73 | return importlib.util.find_spec("torchvision") is not None 74 | 75 | 76 | def is_rich_available() -> bool: 77 | return importlib.util.find_spec("rich") is not None 78 | 79 | 80 | def is_wandb_available() -> bool: 81 | return importlib.util.find_spec("wandb") is not None 82 | 83 | 84 | def is_xpu_available() -> bool: 85 | if is_accelerate_greater_20_0(): 86 | import accelerate 87 | 88 | return accelerate.utils.is_xpu_available() 89 | else: 90 | if importlib.util.find_spec("intel_extension_for_pytorch") is None: 91 | return False 92 | try: 93 | import torch 94 | 95 | return hasattr(torch, "xpu") and torch.xpu.is_available() 96 | except RuntimeError: 97 | return False 98 | 99 | 100 | def is_npu_available() -> bool: 101 | """Checks if `torch_npu` is installed and potentially if a NPU is in the environment""" 102 | if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: 103 | return False 104 | 105 | import torch 106 | import torch_npu # noqa: F401 107 | 108 | return hasattr(torch, "npu") and torch.npu.is_available() 109 | -------------------------------------------------------------------------------- /trl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
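# Base wrappers and value-head models are always exported; the diffusers-backed DDPO pipeline classes are only imported further down when diffusers is available.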
16 | from .modeling_base import PreTrainedModelWrapper, create_reference_model 17 | from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead 18 | from .utils import setup_chat_format 19 | 20 | 21 | SUPPORTED_ARCHITECTURES = ( 22 | AutoModelForCausalLMWithValueHead, 23 | AutoModelForSeq2SeqLMWithValueHead, 24 | ) 25 | 26 | from ..import_utils import is_diffusers_available 27 | 28 | 29 | if is_diffusers_available(): 30 | from .modeling_sd_base import ( 31 | DDPOPipelineOutput, 32 | DDPOSchedulerOutput, 33 | DDPOStableDiffusionPipeline, 34 | DefaultDDPOStableDiffusionPipeline, 35 | ) 36 | -------------------------------------------------------------------------------- /trl/models/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, Optional, Tuple 3 | 4 | from transformers import PreTrainedModel, PreTrainedTokenizer 5 | 6 | 7 | # TODO: Add Abstract Base Class if more formats are added 8 | @dataclass 9 | class ChatMlSpecialTokens: 10 | """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens.""" 11 | 12 | bos_token: str = "<|im_start|>" 13 | eos_token: str = "<|im_end|>" 14 | pad_token: str = "<|im_end|>" 15 | 16 | @property 17 | def system(self): 18 | return f"{self.bos_token}system" 19 | 20 | @property 21 | def user(self): 22 | return f"{self.bos_token}user" 23 | 24 | @property 25 | def assistant(self): 26 | return f"{self.bos_token}assistant" 27 | 28 | @property 29 | def chat_template(self): 30 | return ( 31 | "{% for message in messages %}" 32 | f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}" 33 | "{% endfor %}" 34 | "{% if add_generation_prompt %}" 35 | f"{{{{ '{self.assistant}\n' }}}}" 36 | "{% endif %}" 37 | ) 38 | 39 | 40 | FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens} 41 | 42 | 43 | def setup_chat_format( 44 | model: PreTrainedModel, 45 | tokenizer: PreTrainedTokenizer, 46 | format: Optional[Literal["chatml"]] = "chatml", 47 | resize_to_multiple_of: Optional[int] = None, 48 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: 49 | """ 50 | Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. 51 | 52 | Args: 53 | model (`~transformers.PreTrainedModel`): The model to be modified. 54 | tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified. 55 | format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml". 56 | resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None. 57 | Returns: 58 | model (`~transformers.PreTrainedModel`): The modified model. 59 | tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer. 60 | """ 61 | # check if format available and retrieve 62 | if format not in FORMAT_MAPPING: 63 | raise ValueError(f"Format {format} not available. 
Please use one of {FORMAT_MAPPING.keys()}") 64 | 65 | chat_format = FORMAT_MAPPING[format]() 66 | 67 | # set special tokens and them 68 | tokenizer.eos_token = chat_format.eos_token 69 | tokenizer.pad_token = chat_format.pad_token 70 | tokenizer.bos_token = chat_format.bos_token 71 | tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]}) 72 | # set chat format for tokenizer 73 | tokenizer.chat_template = chat_format.chat_template 74 | 75 | # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377 76 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None) 77 | # Make sure to update the generation config to use the new eos & bos token 78 | if getattr(model, "generation_config", None) is not None: 79 | model.generation_config.bos_token_id = tokenizer.bos_token_id 80 | model.generation_config.eos_token_id = tokenizer.eos_token_id 81 | model.generation_config.pad_token_id = tokenizer.pad_token_id 82 | 83 | return model, tokenizer 84 | -------------------------------------------------------------------------------- /trl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # There is a circular import in the PPOTrainer if we let isort sort these 18 | # isort: off 19 | from .utils import ( 20 | AdaptiveKLController, 21 | FixedKLController, 22 | ConstantLengthDataset, 23 | DataCollatorForCompletionOnlyLM, 24 | RunningMoments, 25 | disable_dropout_in_model, 26 | peft_module_casting_to_bf16, 27 | ) 28 | 29 | # isort: on 30 | 31 | from ..import_utils import is_diffusers_available 32 | from .base import BaseTrainer 33 | from .ddpo_config import DDPOConfig 34 | 35 | 36 | if is_diffusers_available(): 37 | from .ddpo_trainer import DDPOTrainer 38 | 39 | from .dpo_trainer import DPOTrainer 40 | from .iterative_sft_trainer import IterativeSFTTrainer 41 | from .model_config import ModelConfig 42 | from .ppo_config import PPOConfig 43 | from .ppo_trainer import PPOTrainer 44 | from .reward_config import RewardConfig 45 | from .reward_trainer import RewardTrainer, compute_accuracy 46 | from .sft_trainer import SFTTrainer 47 | -------------------------------------------------------------------------------- /trl/trainer/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from huggingface_hub import PyTorchModelHubMixin 16 | 17 | 18 | class BaseTrainer(PyTorchModelHubMixin): 19 | r""" 20 | Base class for all trainers - this base class implements the basic functions that we 21 | need for a trainer. 22 | 23 | The trainer needs to have the following functions: 24 | - step: takes in a batch of data and performs a step of training 25 | - loss: takes in a batch of data and returns the loss 26 | - compute_rewards: takes in a batch of data and returns the rewards 27 | - _build_models_and_tokenizer: builds the models and tokenizer 28 | - _build_dataset: builds the dataset 29 | Each user is expected to implement their own trainer class that inherits from this base 30 | if they want to use a new training algorithm. 31 | """ 32 | 33 | def __init__(self, config): 34 | self.config = config 35 | 36 | def step(self, *args): 37 | raise NotImplementedError("Not implemented") 38 | 39 | def loss(self, *args): 40 | raise NotImplementedError("Not implemented") 41 | 42 | def compute_rewards(self, *args): 43 | raise NotImplementedError("Not implemented") 44 | 45 | def _save_pretrained(self, save_directory): 46 | raise NotImplementedError("Not implemented") 47 | -------------------------------------------------------------------------------- /trl/trainer/model_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Optional 3 | 4 | from ..core import flatten_dict 5 | 6 | 7 | @dataclass 8 | class ModelConfig: 9 | """ 10 | Arguments which define the model and tokenizer to load. 11 | """ 12 | 13 | model_name_or_path: Optional[str] = field( 14 | default=None, 15 | metadata={"help": ("The model checkpoint for weights initialization.")}, 16 | ) 17 | model_revision: str = field( 18 | default="main", 19 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 20 | ) 21 | torch_dtype: Optional[str] = field( 22 | default=None, 23 | metadata={ 24 | "help": ("Override the default `torch.dtype` and load the model under this dtype. 
If `auto` is passed, the " "dtype will be automatically derived from the model's weights."), 25 | "choices": ["auto", "bfloat16", "float16", "float32"], 26 | }, 27 | ) 28 | trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) 29 | attn_implementation: Optional[str] = field( 30 | default=None, 31 | metadata={"help": ("Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`")}, 32 | ) 33 | use_peft: bool = field( 34 | default=False, 35 | metadata={"help": ("Whether to use PEFT or not for training.")}, 36 | ) 37 | lora_r: Optional[int] = field( 38 | default=16, 39 | metadata={"help": ("LoRA R value.")}, 40 | ) 41 | lora_alpha: Optional[int] = field( 42 | default=32, 43 | metadata={"help": ("LoRA alpha.")}, 44 | ) 45 | lora_dropout: Optional[float] = field( 46 | default=0.05, 47 | metadata={"help": ("LoRA dropout.")}, 48 | ) 49 | lora_target_modules: Optional[List[str]] = field( 50 | default=None, 51 | metadata={"help": ("LoRA target modules.")}, 52 | ) 53 | lora_modules_to_save: Optional[List[str]] = field( 54 | default=None, 55 | metadata={"help": ("Model layers to unfreeze & train")}, 56 | ) 57 | load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}) 58 | load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}) 59 | 60 | bnb_4bit_quant_type: Optional[str] = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}) 61 | use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) 62 | 63 | def to_dict(self): 64 | output_dict = {} 65 | for key, value in self.__dict__.items(): 66 | output_dict[key] = value 67 | return flatten_dict(output_dict) 68 | 69 | def __post_init__(self): 70 | if self.load_in_8bit and self.load_in_4bit: 71 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time") 72 | -------------------------------------------------------------------------------- /trl/trainer/reward_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class RewardConfig(TrainingArguments): 23 | """ 24 | RewardConfig collects all training arguments related to the [`RewardTrainer`] class. 25 | 26 | Using [`HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 
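    A minimal, illustrative instantiation (the values below are placeholders, not library defaults):

        config = RewardConfig(output_dir="./reward_model", max_length=512)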
29 | 30 | Parameters: 31 | max_length (`int`, *optional*, defaults to `None`): 32 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. 33 | gradient_checkpointing (`bool`, *optional*, defaults to `True`): 34 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 35 | """ 36 | 37 | max_length: Optional[int] = None 38 | """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.""" 39 | --------------------------------------------------------------------------------