├── .gitignore ├── LICENSE ├── README.md ├── data_filtering ├── baselines.py ├── baselines │ ├── __init__.py │ ├── additional_text_filter.py │ ├── apply_filter.py │ ├── image_based_clustering.md │ ├── image_based_clustering.py │ └── utils.py ├── requirements.txt └── resharder.py ├── data_prepare └── split_mammoth_10m.py ├── docs ├── DATA_Filter.md └── Eval.md ├── mm_sequence_packing ├── multiprocess_sequence_packing_image_to_json.sh ├── multiprocess_sequence_packing_image_to_pil.sh ├── sequence_packing_image_to_json.py └── sequence_packing_image_to_pil.py ├── prismatic-vlms ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── fine_tune.sh ├── fine_tune_mammoth.sh ├── prismatic │ ├── __init__.py │ ├── conf │ │ ├── __init__.py │ │ ├── datasets.py │ │ └── models.py │ ├── models │ │ ├── __init__.py │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── llm │ │ │ │ ├── __init__.py │ │ │ │ ├── base_llm.py │ │ │ │ ├── llama2.py │ │ │ │ ├── llama3.py │ │ │ │ ├── mistral.py │ │ │ │ ├── phi3.py │ │ │ │ ├── prompting │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base_prompter.py │ │ │ │ │ ├── llama2_chat_prompter.py │ │ │ │ │ ├── llama3_chat_prompter.py │ │ │ │ │ ├── phi_3_prompter.py │ │ │ │ │ ├── qwen2_prompter.py │ │ │ │ │ └── vicuna_v15_prompter.py │ │ │ │ └── qwen2.py │ │ │ └── vision │ │ │ │ ├── __init__.py │ │ │ │ ├── base_vision.py │ │ │ │ ├── clip_vit.py │ │ │ │ ├── dinoclip_vit.py │ │ │ │ ├── dinosiglip_vit.py │ │ │ │ ├── dinov2_vit.py │ │ │ │ ├── in1k_vit.py │ │ │ │ └── siglip_vit.py │ │ ├── load.py │ │ ├── materialize.py │ │ ├── registry.py │ │ └── vlms │ │ │ ├── __init__.py │ │ │ ├── base_vlm.py │ │ │ └── prismatic.py │ ├── overwatch │ │ ├── __init__.py │ │ └── overwatch.py │ ├── preprocessing │ │ ├── __init__.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ └── datasets.py │ │ ├── download.py │ │ └── materialize.py │ ├── py.typed │ ├── training │ │ ├── __init__.py │ │ ├── materialize.py │ │ ├── metrics.py │ │ └── strategies │ │ │ ├── __init__.py │ │ │ ├── base_strategy.py │ │ │ ├── ddp.py │ │ │ └── fsdp.py │ └── util │ │ ├── __init__.py │ │ ├── batching_utils.py │ │ ├── data_utils.py │ │ ├── nn_utils.py │ │ └── torch_utils.py ├── pyproject.toml ├── scripts │ ├── additional-datasets │ │ ├── lrv_instruct.py │ │ └── lvis_instruct_4v.py │ ├── generate.py │ ├── preprocess.py │ └── pretrain.py └── train.sh ├── test.py └── vlm-evaluation ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── eval.sh ├── images └── 03-evaluation-suite-med-res.png ├── pyproject.toml ├── scripts ├── coco_score.py ├── datasets │ └── prepare.py ├── evaluate.py ├── interactive_demo.py └── score.py └── vlm_eval ├── __init__.py ├── conf ├── __init__.py └── datasets.py ├── models ├── __init__.py ├── instructblip.py ├── llava.py ├── prismatic.py └── qwen2vl.py ├── overwatch ├── __init__.py └── overwatch.py ├── serve ├── __init__.py ├── controller.py ├── examples │ ├── cows_in_pasture.png │ └── monkey_knives.png └── gradio_web_server.py ├── tasks ├── __init__.py ├── builders.py ├── download.py ├── harnesses │ ├── __init__.py │ ├── ai2d.py │ ├── gqa.py │ ├── mantis.py │ ├── mathvista.py │ ├── mmbench.py │ ├── mmlu.py │ ├── mmmu.py │ ├── mmstar.py │ ├── mscoco_karpathy.py │ ├── ocidref.py │ ├── okvqa.py │ ├── pope.py │ ├── refcoco.py │ ├── seedbench.py │ ├── tallyqa.py │ ├── textvqa.py │ ├── vizwiz.py │ ├── vqav2.py │ └── vsr.py └── registry.py └── util ├── __init__.py ├── evaluation ├── __init__.py ├── gqa │ ├── __init__.py │ └── eval.py ├── mmmu │ ├── __init__.py │ └── eval.py ├── nocaps │ └── metrics.py ├── 
textvqa │ ├── __init__.py │ └── m4c_evaluators.py ├── vizwiz │ ├── __init__.py │ └── eval.py └── vqav2 │ ├── __init__.py │ └── eval.py ├── interfaces.py ├── loading ├── __init__.py └── refer.py └── preprocessing.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | *.pt 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # Ruff 133 | .ruff_cache/ 134 | 135 | # Auth Tokens / Hidden Files 136 | .hf_token 137 | .wandb_api_key 138 | .*_token 139 | .*api_key 140 | 141 | # IDE Caches 142 | .idea/ 143 | .vscode/ 144 | 145 | # Mac OS 146 | .DS_Store 147 | 148 | # Caches and Datasets 149 | cache/ 150 | data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Weizhi Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Open-Qwen2VL 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2504.00595-df2a2a.svg?style=for-the-badge)](https://arxiv.org/abs/2504.00595) 4 | 5 | Official code repo for our work [Open-Qwen2VL: Compute-Efficient Pre-Training of Fully-Open Multimodal LLMs on Academic Resources](https://victorwz.github.io/Open-Qwen2VL/). 
6 | 
7 | ## Introduction
8 | This repo supports:
9 | - Data quality score generation with DFN/CLIP and [MLM-Filter](https://github.com/Victorwz/MLM_Filter)
10 | - High-quality data selection based on the quality scores and resharding into webdataset format
11 | - Multimodal sequence packing for large-scale image-text datasets in webdataset format (supporting both caption data and interleaved data)
12 | - Pre-training with packed multimodal sequences
13 | - Supervised fine-tuning on both small-scale SFT data like [LLaVA-665k](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json) and large-scale SFT data like [MAmmoTH-VL-10M](https://huggingface.co/datasets/MAmmoTH-VL/MAmmoTH-VL-Instruct-12M)
14 | - Evaluation on a series of multimodal benchmarks
15 | 
16 | 
17 | ## Release
18 | - [3/31/2025] 🔥 We released all pre-trained and instruction-tuned model checkpoints at [Open-Qwen2VL](https://huggingface.co/weizhiwang/Open-Qwen2VL) and [Open-Qwen2VL-Base](https://huggingface.co/weizhiwang/Open-Qwen2VL-Base).
19 | - [3/31/2025] 🔥 We released all pre-training data in webdataset format at [Open-Qwen2VL-Data](https://huggingface.co/datasets/weizhiwang/Open-Qwen2VL-Data).
20 | - [3/31/2025] 🔥 We released the technical report for [**Open-Qwen2VL**](https://arxiv.org/abs/2504.00595).
21 | 
22 | ## Install
23 | 
24 | ```Shell
25 | conda create -n openqwen2vl python=3.10
26 | conda activate openqwen2vl
27 | pip install -e prismatic-vlms
28 | ```
29 | 
30 | If you need to pre-train or SFT the MLLM, install flash-attention:
31 | ```
32 | pip install flash-attn --no-build-isolation
33 | ```
34 | 
35 | ## Directly Use or Trial
36 | ```python
37 | import requests
38 | import torch
39 | from PIL import Image
40 | from prismatic import load
41 | 
42 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
43 | 
44 | # Load a pretrained VLM (either a local path or an ID to auto-download from the HF Hub)
45 | vlm = load("Open-Qwen2VL")
46 | vlm.to(device, dtype=torch.bfloat16)
47 | 
48 | # Download an image and specify a prompt
49 | image_url = "https://huggingface.co/adept/fuyu-8b/resolve/main/bus.png"
50 | image = [vlm.vision_backbone.image_transform(Image.open(requests.get(image_url, stream=True).raw).convert("RGB")).unsqueeze(0)]
51 | user_prompt = "<image>\nDescribe the image."
52 | 
53 | # Generate!
54 | generated_text = vlm.generate_batch(
55 |     image,
56 |     [user_prompt],
57 |     do_sample=False,
58 |     max_new_tokens=512,
59 |     min_length=1,
60 | )
61 | print(generated_text[0])
62 | ```
63 | 
64 | ## Multimodal Sequence Packing
65 | We have released all our pre-training image-text caption data in webdataset format at [Open-Qwen2VL-Data](https://huggingface.co/datasets/weizhiwang/Open-Qwen2VL-Data). Please download it with `huggingface-cli download` or directly `git clone`.
66 | 
67 | Then please run
68 | ```shell
69 | bash mm_sequence_packing/multiprocess_sequence_packing_image_to_pil.sh 0 4 504 datacomp
70 | bash mm_sequence_packing/multiprocess_sequence_packing_image_to_pil.sh 0 4 326 ccs
71 | ```
72 | 
73 | [Update on 5/16/2025]
74 | We also support packing sequences into a single json file which contains base64-encoded images and texts; storing the packed data in json files can save about 95% of the disk space. A hedged sketch of reading one such packed record back is shown below.
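For reference, here is a minimal sketch of how one packed json record might be read back. The file name and the `images`/`texts` keys are illustrative assumptions, not the exact schema written by `sequence_packing_image_to_json.py`; the point is only the base64-to-PIL decoding step.

```python
import base64
import io
import json

from PIL import Image

# Assumed layout: each packed file is a json list of records, where a record stores
# base64-encoded images alongside the corresponding texts (key names are hypothetical).
with open("/path/to/datacomp_packed_files/packed_000000.json") as f:
    records = json.load(f)

record = records[0]
images = [Image.open(io.BytesIO(base64.b64decode(s))).convert("RGB") for s in record["images"]]
texts = record["texts"]
print(f"decoded {len(images)} images; first text: {texts[0][:80]}")
```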
To generate the packed json files, run
75 | ```shell
76 | bash mm_sequence_packing/multiprocess_sequence_packing_image_to_json.sh 0 4 504 /path/to/datacomp_wds /path/to/datacomp_packed_files
77 | ```
78 | 
79 | ## Compute-Efficient MLLM Pre-Training
80 | Prior to training, please follow the sequence packing instructions above to prepare the packed pickle files for each subdataset.
81 | 
82 | Then please run the training script
83 | ```Shell
84 | bash prismatic-vlms/train.sh ${CKPTID} ${STAGE} ${BSZ} ${PER_GPU_BSZ}
85 | ```
86 | Here are the parameters for training:
87 | - `CKPTID`: id for the saved checkpoint;
88 | - `STAGE`: choose between `pretrain` and `full-pretrain`, where `full-pretrain` also makes the vision encoder trainable;
89 | - `BSZ`: global batch size;
90 | - `PER_GPU_BSZ`: the batch size on each GPU. If `BSZ != num_gpus * PER_GPU_BSZ`, gradient accumulation is applied (e.g., with 8 GPUs, `BSZ=128`, and `PER_GPU_BSZ=2`, gradients are accumulated over 8 micro-batches per optimizer step).
91 | 
92 | ## Visual SFT
93 | ### Large-Scale SFT on MAmmoTH-VL-10M
94 | Please first download and unzip the images of [MAmmoTH-VL-10M](https://huggingface.co/datasets/MAmmoTH-VL/MAmmoTH-VL-Instruct-12M). Then run `python data_prepare/split_mammoth_10m.py` to split each instruction example into a single json file.
95 | 
96 | Please run the training script
97 | ```Shell
98 | bash prismatic-vlms/fine_tune_mammoth.sh ${CKPT_PATH} ${CKPTID}
99 | ```
100 | Here are the parameters for training:
101 | - `CKPT_PATH`: the path to the pre-trained MLLM checkpoint after the pre-training stage;
102 | - `CKPTID`: id for the saved checkpoint
103 | 
104 | ### Normal SFT Scripts
105 | 
106 | For small-scale SFT data such as LLaVA-665k, please run the training script
107 | ```Shell
108 | bash prismatic-vlms/fine_tune.sh ${CKPT_PATH} ${CKPTID} ${DATAPATH}
109 | ```
110 | Here are the parameters for training:
111 | - `CKPT_PATH`: the path to the pre-trained MLLM checkpoint after the pre-training stage;
112 | - `CKPTID`: id for the saved checkpoint;
113 | - `DATAPATH`: the path to the SFT dataset json file.
114 | 
115 | ## Evaluations
116 | Please follow the doc [docs/Eval.md](docs/Eval.md) for evaluation instructions.
117 | 
118 | ## Data Quality Score Generation and Data Filtering
119 | Please follow the doc [docs/DATA_Filter.md](docs/DATA_Filter.md) for data filtering instructions.
120 | 
121 | ## Citation
122 | 
123 | Please cite our paper if you find this repository interesting or helpful:
124 | ```bibtex
125 | @article{Open-Qwen2VL,
126 |   title={Open-Qwen2VL: Compute-Efficient Pre-Training of Fully-Open Multimodal LLMs on Academic Resources},
127 |   author={Wang, Weizhi and Tian, Yu and Yang, Linjie and Wang, Heng and Yan, Xifeng},
128 |   journal={arXiv preprint arXiv:2504.00595},
129 |   year={2025}
130 | }
131 | ```
132 | 
133 | 
134 | ## Credits
135 | Our codebase is developed based on [prismatic-vlms](https://github.com/TRI-ML/prismatic-vlms) and [vlm-evaluation](https://github.com/TRI-ML/vlm-evaluation).
136 | -------------------------------------------------------------------------------- /data_filtering/baselines.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import torch 6 | 7 | from baselines.apply_filter import apply_filter 8 | 9 | BASELINES = { 10 | "no_filter", 11 | "basic_filter", 12 | "text_based", 13 | "image_based", 14 | "image_based_intersect_clip_score", 15 | "clip_score", 16 | "laion2b", 17 | "llava_image_text_matching_score", 18 | "llava_object_detail_fulfillment_score", 19 | "llava_caption_text_quality_score", 20 | "llava_semantic_understanding_score", 21 | "reward_quality_score", 22 | "dfn_clip_score" 23 | } 24 | 25 | ARCH = { 26 | "b32", 27 | "l14", 28 | } 29 | 30 | CLUSTER_CENTROID_SCALES = [ 31 | "small", 32 | "medium", 33 | "large", 34 | "xlarge", 35 | ] 36 | 37 | 38 | def check_args(args): 39 | if args.name not in BASELINES: 40 | raise ValueError(f"--name must be in: {BASELINES}") 41 | 42 | if args.name == "laion2b": 43 | if ( 44 | args.fraction is not None 45 | or args.threshold is not None 46 | or args.arch is not None 47 | or args.image_based_scale is not None 48 | ): 49 | raise ValueError("laion2b does not support clip_score or image_based flags") 50 | 51 | # clip_score checks 52 | if "clip_score" in args.name and "dfn" not in args.name: 53 | if args.fraction is None and args.threshold is None: 54 | raise ValueError( 55 | "--fraction or --threshold must be specified for clip_score baselines" 56 | ) 57 | if args.fraction is not None and args.threshold is not None: 58 | raise ValueError( 59 | "specify either --fraction or --threshold for clip_score baselines but not both" 60 | ) 61 | if args.arch is None: 62 | raise ValueError(f"specify architecture {ARCH}, for clip_score baselines") 63 | if args.fraction is not None and not ("score" in args.name): 64 | raise ValueError("--fraction value only used for clip_score baselines") 65 | if args.threshold is not None and not ("score" in args.name): 66 | raise ValueError("--threshold value only used for clip_score baselines") 67 | if args.arch is not None and not ("clip_score" in args.name): 68 | raise ValueError("--arch value only used for clip_score baselines") 69 | 70 | # image_based checks 71 | if args.image_based_scale is None and "image_based" in args.name: 72 | raise ValueError( 73 | "--image_based_scale value must be passed for image_based and image_based_intersect_clip_score_* baselines (for clustering)" 74 | ) 75 | if args.image_based_scale is not None and not ("image_based" in args.name): 76 | raise ValueError( 77 | "--image_based_scale should only be passed for image_based and image_based_intersect_clip_score_* baselines (for clustering)" 78 | ) 79 | if "image_based" in args.name and not torch.cuda.is_available(): 80 | raise ValueError( 81 | "gpus needed for image_based baselines, torch.cuda.is_available() must return true" 82 | ) 83 | 84 | npy_parent = Path(args.save_path).parent 85 | if not os.path.exists(npy_parent): 86 | print(f"creating: {npy_parent}") 87 | os.mkdir(npy_parent) 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser( 92 | description="This is a command line script for reproducing the main DataComp filtering baselines. The output of the script is a numpy file (.npy) containing the uids in the filtered subsets in sorted binary format. 
Please see README.md for additional information" 93 | ) 94 | 95 | parser.add_argument( 96 | "--name", 97 | type=str, 98 | required=True, 99 | choices=list(BASELINES), 100 | help="name of the baseline", 101 | ) 102 | 103 | parser.add_argument( 104 | "--metadata_dir", 105 | type=str, 106 | required=True, 107 | help="directory (local or cloud) containing parquet, npz metadata", 108 | ) 109 | 110 | parser.add_argument( 111 | "--save_path", 112 | type=str, 113 | required=True, 114 | help="path to output .npy, note: cloudpaths are not supported for this arg", 115 | ) 116 | 117 | parser.add_argument( 118 | "--threshold", 119 | type=float, 120 | required=False, 121 | default=None, 122 | help="A threshold to apply on a CLIP score (e.g., a value of 0.25 will only keep examples with CLIP score over 0.25)", 123 | ) 124 | 125 | parser.add_argument( 126 | "--fraction", 127 | type=float, 128 | required=False, 129 | default=None, 130 | help="a fraction of metadata to keep according to CLIP score (e.g., a value of 0.25 will keep the top 25 percent of examples in the pool by CLIP score)", 131 | ) 132 | 133 | parser.add_argument( 134 | "--arch", 135 | type=str, 136 | required=False, 137 | choices=list(ARCH), 138 | help="kinds of features (b32 or l14) on which to run the CLIP score filter", 139 | ) 140 | 141 | parser.add_argument( 142 | "--num_workers", 143 | type=int, 144 | required=False, 145 | default=os.cpu_count(), 146 | help="number of workers, generally set to number of cpu cores. workers read their metadata files and filter them in parallel).", 147 | ) 148 | 149 | parser.add_argument( 150 | "--num_gpus", 151 | type=int, 152 | required=False, 153 | default=torch.cuda.device_count(), 154 | help="number of gpus for the image_based gpu implementation. num_gpus metadata files are processed in parallel on each gpu worker. NOTE: this parameter is ignored for non-image_basesd baselines", 155 | ) 156 | 157 | parser.add_argument( 158 | "--batch_size", 159 | type=int, 160 | required=False, 161 | default=1024, 162 | help="batch size for the image_based gpu implementation. NOTE: this parameter is ignored for non-image_basesd baselines", 163 | ) 164 | 165 | parser.add_argument( 166 | "--image_based_scale", 167 | type=str, 168 | required=False, 169 | choices=CLUSTER_CENTROID_SCALES, 170 | help="datacomp scale, used for the clutering baselines", 171 | default=None, 172 | ) 173 | 174 | args = parser.parse_args() 175 | 176 | # all error checking happens here and apply_filter assumes correct input 177 | check_args(args) 178 | 179 | # route the args to the correct baseline 180 | apply_filter(args) 181 | -------------------------------------------------------------------------------- /data_filtering/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/data_filtering/baselines/__init__.py -------------------------------------------------------------------------------- /data_filtering/baselines/image_based_clustering.md: -------------------------------------------------------------------------------- 1 | # Clustering 2 | 3 | Generates cluster centroids from the `image-based` baselines using k-means clustering. 4 | 5 | 6 | ## Installing dependencies 7 | 8 | We use `faiss-gpu` for k-means clustering. 
To install, run the following commands: 9 | 10 | ```bash 11 | conda install -c conda-forge faiss-gpu 12 | ``` 13 | 14 | Or check out [faiss-gpu](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md) 15 | 16 | ## Run clustering 17 | 18 | To run clustering for the `small` pool, run the following command: 19 | 20 | 21 | ``` 22 | python image_based_clustering.py \ 23 | --metadata_dir path/to/metadata \ 24 | --save_path path/to/output/centroids \ 25 | --num_clusters 100000 \ 26 | --sample_ratio -1.0 \ 27 | --num_gpus 8 \ 28 | --num_workers 26 \ 29 | ``` 30 | 31 | Explanation to several arguments: 32 | 33 | - `sample_ratio`: the ratio of samples to use in clustering. In particular, we sample `sample_ratio` percent embeddings to do clustering due to memory constraint. We use 0.3 for `large` and 0.03 for `xlarge`. Default is -1.0 (no sampling) 34 | - `disable_caption_filtering`: whether to disable caption filtering to the dataset. Default is `False` 35 | 36 | On a machine with 8 GPUs and 26 CPUs (there are 26 parquet files for the `small` pool), the clustering process takes about 10 minutes. 37 | -------------------------------------------------------------------------------- /data_filtering/baselines/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import random 4 | import urllib 5 | import warnings 6 | from multiprocessing import Pool 7 | from typing import Any, List 8 | 9 | import numpy as np 10 | import torch 11 | from tqdm import tqdm 12 | 13 | 14 | def random_seed(seed: int = 0) -> None: 15 | """set seed 16 | 17 | Args: 18 | seed (int, optional): seed value. Defaults to 0. 19 | """ 20 | torch.manual_seed(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | 24 | 25 | def download(name: str, root: str = None) -> str: 26 | """dowload assets necessary for running baselines (e.g., model weights) 27 | 28 | Args: 29 | name (str): name key of the asset to download 30 | root (str, optional): cache location to download the asset. Defaults to None. 
31 | 32 | Raises: 33 | ValueError: unsupported name 34 | RuntimeError: file exists but it is not a normal file 35 | RuntimeError: file exists but has the incorrect sha256 checksum 36 | 37 | Returns: 38 | str: _description_ 39 | """ 40 | # modified from oai _download clip function 41 | 42 | if root is None: 43 | root = os.path.expanduser("~/.cache/datacomp") 44 | else: 45 | root = os.path.expanduser(root) 46 | 47 | # paths for checkpoints we may need to download for baselines along with their sha256 hashes 48 | cloud_checkpoints = { 49 | "imagenet21k_wordnet_ids": { 50 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/imagenet21k_wordnet_ids.txt", 51 | "sha256": "66362bdedf36d933382edca5493fc562dcc17128ce36403c9e730a75f48cb2f2", 52 | }, 53 | "in1k_clip_vit_l14_0": { 54 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/in1k_clip_vit_l14_0.pt", 55 | "sha256": "304990fd492f40ba90072b80af78d6e8edcab9d042476ef513e03441cf53743b", 56 | }, 57 | "in1k_clip_vit_l14_1": { 58 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/in1k_clip_vit_l14_1.pt", 59 | "sha256": "72f008e6aa3bfb54174fa322963215337f843fc90fd642a9dccc815dd9e7ee76", 60 | }, 61 | "in1k_clip_vit_l14_2": { 62 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/in1k_clip_vit_l14_2.pt", 63 | "sha256": "83c3cc6503b8a50a22cac7023ca6cd61b130ab9725db2a7519f1bec5cabf5493", 64 | }, 65 | "in1k_clip_vit_l14_3": { 66 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/in1k_clip_vit_l14_3.pt", 67 | "sha256": "83faa289f847799ddcd438e5b74df1cb5973c345674896d10ba31ea48bc63804", 68 | }, 69 | "in1k_clip_vit_l14_4": { 70 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/in1k_clip_vit_l14_4.pt", 71 | "sha256": "067a45f4a140a93371876510a2fd9198a89ac1d194ae4063a51b45a6da69cadf", 72 | }, 73 | "large_centroids": { 74 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/large_centroids_l14.pt", 75 | "sha256": "04eeab1069d3540c246cf7ce69323147351a474017fc472e4c50a018ca32240b", 76 | }, 77 | "medium_centroids": { 78 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/medium_centroids_l14.pt", 79 | "sha256": "028b1721b1d0f139c565b6b0ac99f8a1756f4bae89c36b0ec6d1c6ea9b6f112d", 80 | }, 81 | "small_centroids": { 82 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/small_centroids_l14.pt", 83 | "sha256": "23c66a05e49ad77283c1e2b33355c7eb088ac332a944c97ff85d5dfd48a5b251", 84 | }, 85 | "xlarge_centroids": { 86 | "url": "https://github.com/sagadre/datacomp_baseline_assets/releases/download/v0.1.0-alpha/xlarge_centroids_l14.pt", 87 | "sha256": "3f62e5f8ae3a715ce84e422846fcfce1536d184455ea234790c4a6465c4c6726", 88 | }, 89 | "fasttext": { 90 | "url": "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin", 91 | "sha256": "7e69ec5451bc261cc7844e49e4792a85d7f09c06789ec800fc4a44aec362764e", 92 | }, 93 | "mmmu": { 94 | "url": "https://huggingface.co/datasets/weizhiwang/datacomp-hq/resolve/main/mmmu_embeddings.pt", 95 | "sha256": "392a4023d6eb739dbba537bb0230a33c4f75f4eaee7ece7f6406554fbb8530f6", 96 | }, 97 | "mscoco": { 98 | "url": "https://huggingface.co/datasets/weizhiwang/datacomp-hq/resolve/main/mscoco_embeddings.pt", 99 | "sha256": 
"e4a235d15ee3fac84c9d686b5f985de2948e26c95f4b73e91523d417d81b93e9", 100 | } 101 | } 102 | 103 | if name not in cloud_checkpoints: 104 | raise ValueError( 105 | f"unsupported cloud checkpoint: {name}. currently we only support: {list(cloud_checkpoints.keys())}" 106 | ) 107 | 108 | os.makedirs(root, exist_ok=True) 109 | 110 | expected_sha256 = cloud_checkpoints[name]["sha256"] 111 | download_target = None 112 | if name == "fasttext": 113 | download_target = os.path.join(root, "lid.176.bin") 114 | else: 115 | download_target = os.path.join(root, f"{name}.pt") 116 | url = cloud_checkpoints[name]["url"] 117 | 118 | if os.path.exists(download_target) and not os.path.isfile(download_target): 119 | raise RuntimeError(f"{download_target} exists and is not a regular file") 120 | 121 | if os.path.isfile(download_target): 122 | if ( 123 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 124 | == expected_sha256 125 | ): 126 | return download_target 127 | else: 128 | warnings.warn( 129 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 130 | ) 131 | 132 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 133 | print(f"downloading {url} to {download_target}") 134 | with tqdm( 135 | total=int(source.info().get("Content-Length")), 136 | ncols=80, 137 | unit="iB", 138 | unit_scale=True, 139 | unit_divisor=1024, 140 | ) as loop: 141 | while True: 142 | buffer = source.read(8192) 143 | if not buffer: 144 | break 145 | 146 | output.write(buffer) 147 | loop.update(len(buffer)) 148 | 149 | if ( 150 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 151 | != expected_sha256 152 | ): 153 | raise RuntimeError( 154 | "Model has been downloaded but the SHA256 checksum does not not match" 155 | ) 156 | 157 | return download_target 158 | 159 | 160 | def worker_threadpool( 161 | worker_fn: Any, concat_fn: Any, paths: List[str], n_workers: int 162 | ) -> np.ndarray: 163 | """get filtered uids 164 | 165 | Args: 166 | worker_fn (Any): function to map over the pool 167 | concat_fn (Any): function to use to collate the results 168 | paths (List[str]): metadata paths to process 169 | n_workers (int): number of cpu workers 170 | 171 | Returns: 172 | np.ndarray: filtered uids 173 | """ 174 | print("creating thread pool for processing") 175 | with Pool(n_workers) as pool: 176 | uids = [] 177 | for u in tqdm( 178 | pool.imap_unordered(worker_fn, paths), 179 | total=len(paths), 180 | ): 181 | uids.append(u) 182 | 183 | return concat_fn(uids) 184 | -------------------------------------------------------------------------------- /data_filtering/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.3.0 2 | aiohttp==3.8.3 3 | aioitertools==0.11.0 4 | aiosignal==1.3.1 5 | albumentations==1.3.0 6 | astunparse==1.6.3 7 | async-timeout==4.0.2 8 | attrs==22.1.0 9 | azure-core==1.26.1 10 | azure-storage-blob==12.14.1 11 | beautifulsoup4==4.11.1 12 | braceexpand==0.1.7 13 | cachetools==5.2.0 14 | cffi==1.15.1 15 | charset-normalizer==2.1.1 16 | click==8.1.3 17 | clip-benchmark==1.4.0 18 | cloudpathlib==0.13.0 19 | colorama==0.4.4 20 | contourpy==1.0.7 21 | cryptography==39.0.2 22 | cycler==0.11.0 23 | dill==0.3.6 24 | dm-tree==0.1.7 25 | docker-pycreds==0.4.0 26 | docutils==0.16 27 | etils==0.9.0 28 | exifread-nocycle==3.0.1 29 | fasttext==0.9.3 30 | fasttext-langdetect==1.0.3 31 | filelock==3.8.0 32 | fire==0.4.0 33 | flatbuffers==22.10.26 34 | fonttools==4.39.2 35 | 
frozenlist==1.3.3 36 | fsspec==2024.10.0 37 | ftfy==6.1.1 38 | gast==0.4.0 39 | # gcld3==3.0.13 40 | gdown==4.5.3 41 | gitdb==4.0.9 42 | gitpython==3.1.29 43 | google-api-core==2.11.0 44 | google-auth==2.14.1 45 | google-auth-oauthlib==0.4.6 46 | google-cloud-core==2.3.2 47 | google-cloud-storage==2.7.0 48 | google-crc32c==1.5.0 49 | google-pasta==0.2.0 50 | google-resumable-media==2.4.0 51 | googleapis-common-protos==1.57.0 52 | grpcio==1.50.0 53 | h5py==3.7.0 54 | idna==3.4 55 | imageio==2.22.4 56 | img2dataset 57 | importlib-resources==5.10.0 58 | isodate==0.6.1 59 | jmespath==1.0.1 60 | joblib==1.2.0 61 | kaleido==0.2.1 62 | keras==2.11.0 63 | keras-preprocessing==1.1.2 64 | kiwisolver==1.4.4 65 | libclang==14.0.6 66 | littleutils==0.2.2 67 | markdown==3.4.1 68 | markupsafe==2.1.1 69 | matplotlib==3.7.1 70 | mock==4.0.3 71 | msrest==0.7.1 72 | multidict==6.0.3 73 | multiprocess==0.70.14 74 | networkx==2.8.8 75 | nltk==3.8.1 76 | numpy==1.26.4 77 | nvidia-cublas-cu11==11.10.3.66 78 | nvidia-cuda-nvrtc-cu11==11.7.99 79 | nvidia-cuda-runtime-cu11==11.7.99 80 | nvidia-cudnn-cu11==8.5.0.96 81 | oauthlib==3.2.2 82 | ogb==1.3.5 83 | open-clip-torch==2.16.1 84 | opencv-python==4.6.0.66 85 | opencv-python-headless==4.6.0.66 86 | opt-einsum==3.3.0 87 | outdated==0.2.2 88 | packaging==21.3 89 | pandas==1.5.1 90 | pathtools==0.1.2 91 | patsy==0.5.3 92 | pillow==9.3.0 93 | plotly==5.13.1 94 | promise==2.3 95 | protobuf==3.20.2 96 | psutil==5.9.4 97 | pyarrow==7.0.0 98 | pyasn1==0.4.8 99 | pyasn1-modules==0.2.8 100 | pybind11==2.10.1 101 | # pycld3==0.22 102 | pycparser==2.21 103 | pyparsing==3.0.9 104 | pysimdjson==5.0.2 105 | pysocks==1.7.1 106 | python-dateutil==2.8.2 107 | pytz==2022.6 108 | pywavelets==1.4.1 109 | pyyaml 110 | qudida==0.0.4 111 | regex==2022.10.31 112 | requests==2.28.1 113 | requests-oauthlib==1.3.1 114 | responses==0.18.0 115 | rsa==4.7.2 116 | scikit-image==0.19.3 117 | scikit-learn==1.1.3 118 | scipy==1.9.3 119 | sentencepiece==0.1.97 120 | sentry-sdk==1.11.1 121 | setproctitle==1.3.2 122 | shortuuid==1.0.11 123 | six==1.16.0 124 | smmap==5.0.0 125 | soupsieve==2.3.2.post1 126 | statsmodels==0.13.5 127 | tenacity==8.2.2 128 | termcolor==2.1.0 129 | threadpoolctl==3.1.0 130 | tifffile==2022.10.10 131 | timm==0.6.11 132 | toml==0.10.2 133 | typeguard==2.13.3 134 | typing-extensions==4.4.0 135 | urllib3==1.26.12 136 | wandb==0.12.21 137 | wcwidth==0.2.5 138 | webdataset==0.2.31 139 | werkzeug==2.2.2 140 | wget==3.2 141 | wilds==2.0.0 142 | wrapt==1.14.1 143 | xxhash==3.2.0 144 | yarl==1.8.2 145 | zipp==3.10.0 146 | -------------------------------------------------------------------------------- /data_prepare/split_mammoth_10m.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from transformers import AutoTokenizer 4 | import jsonlines 5 | import os 6 | 7 | n_patches = 729 8 | 9 | data = json.load(open("data/MAmmoTH-VL-Instruct-12M/mammoth_si_10M.json")) 10 | 11 | with jsonlines.open("data/MAmmoTH-VL-Instruct-12M/mammoth_si_10M_simple.jsonl", 'w') as jsonlw: 12 | for sample in tqdm(data): 13 | simple_sample = {} 14 | if "image" in sample: 15 | path = sample["image"].replace("/", "_").replace(".", "_")[:200] + ".json" 16 | else: 17 | path = (str(sample["source"]) + str(sample["id"])).replace(".", "_") + ".json" 18 | if os.path.exists(f"data/MAmmoTH-VL-Instruct-12M/instruction_si_split/{path}"): 19 | # print(path) 20 | continue 21 | if not path.endswith(".json"): 22 | print(path) 23 | continue 24 | 
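# Rough sequence-length estimate for this sample: n_patches (729) visual tokens for the image
# (matching the `full-align+729-avgpool` projector in the fine-tuning scripts) plus a whitespace
# word count over all conversation turns with the image placeholder stripped. The value is stored
# in the jsonl index below, presumably for length-aware batching during SFT.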
n_words = n_patches + sum([len(turn["value"].replace("<image>", "").split()) for turn in sample["conversations"]])
25 | 
26 |         simple_sample["path"] = path
27 |         simple_sample["length"] = n_words
28 | 
29 |         jsonlw.write(simple_sample)
30 | 
31 |         with open(f"data/MAmmoTH-VL-Instruct-12M/instruction_si_split/{path}", "w") as f:
32 |             f.write(json.dumps(sample))
-------------------------------------------------------------------------------- /docs/DATA_Filter.md: --------------------------------------------------------------------------------
1 | # Data Filtering for Open-Qwen2VL Pre-Training Data
2 | ## Install
3 | We develop our data filtering code based on the DataComp repo. Please first install the required packages:
4 | ```shell
5 | cd data_filtering
6 | pip install -r requirements.txt
7 | ```
8 | 
9 | ## Data Downloading
10 | ### CC3M-CC12M-SBU
11 | ```shell
12 | wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_filtered.json
13 | img2dataset --url_list ccs_filtered.json --input_format "json" \
14 |     --url_col "url" --caption_col "caption" --output_format webdataset \
15 |     --output_folder ccs_webdataset --processes_count 32 --thread_count 128 --image_size 512 \
16 |     --resize_only_if_bigger=True --resize_mode="keep_ratio" --skip_reencode=True \
17 |     --enable_wandb False
18 | ```
19 | 
20 | ### DataComp-Medium-128M
21 | Please follow the [official scripts](https://github.com/mlfoundations/datacomp) to download it into webdataset format.
22 | 
23 | 
24 | ## Data Quality Score Generation
25 | ### DFN-CLIP
26 | Since we do not have access to the DFN model checkpoint, we directly use [the released uids of the high-quality subset selected by DFN](https://huggingface.co/datasets/apf1/datafilteringnetworks_2b/resolve/main/datacomp_medium_dfn_20m_inds.npy).
27 | 
28 | 
29 | ### MLM-Filter
30 | MLM-Filter adopts an efficient MLLM to generate four distinct and comprehensive metrics for assessing the quality of each image-text caption sample. Please follow the official repo to perform large-scale quality score generation using [mlm-filter-qwen2.5-1.5b-gpt4o](https://huggingface.co/weizhiwang/mlm-filter-qwen2.5-1.5b-gpt4o).
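Both the released DFN uid file above and the `.npy` subset files produced by `baselines.py` in the next section store the uids of the retained samples as a numpy array (in the DataComp convention, a 128-bit uid is typically stored as two uint64 halves, though the exact dtype is worth checking). A quick sanity check, assuming the file has been downloaded locally, might look like:

```python
import numpy as np

# Load a uid subset file, e.g. the released DFN-selected uids or an output of baselines.py.
uids = np.load("datacomp_medium_dfn_20m_inds.npy")

# Inspect how many samples are kept and how the uids are encoded.
print(uids.shape, uids.dtype)
print(uids[:5])
```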
31 | 32 | ## Data Filtering for MLM-Filter 33 | ```shell 34 | INPUT_DATADIR="path/to/datacomp/shards" 35 | python baselines.py --metadata_dir $INPUT_DATADIR --save_path medium_filter_results/medium_semantic_understanding_th_85.npy --name llava_semantic_understanding_score --threshold 85 36 | mkdir "path/to/datacomp/medium_semantic_understanding_th_85" 37 | python resharder.py -i $DOWNLOAD_DIR -o "path/to/datacomp/medium_semantic_understanding_th_85" -s medium_filter_results/medium_semantic_understanding_th_85.npy 38 | ``` -------------------------------------------------------------------------------- /docs/Eval.md: -------------------------------------------------------------------------------- 1 | # MLLM Evaluation 2 | 3 | ## Install 4 | Please continue the installation within the previously installed [openqwen2vl](../README.md##Install) environment and ensure the modified `prismatic-vlms` repo has been installed: 5 | ```Shell 6 | conda activate openqwen2vl 7 | pip install -e vlm-evaluation 8 | ``` 9 | 10 | ## Benchmark Evaluation for Pre-Trained or Fine-Tuned MLLMs 11 | Please run 12 | ```Shell 13 | bash eval.sh ${model_dir} 14 | ``` 15 | - `model_dir`: the directory of the pre-trained or fine-tuned MLLM -------------------------------------------------------------------------------- /mm_sequence_packing/multiprocess_sequence_packing_image_to_json.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | num_gpu=$2 3 | gpu_start_id=$1 4 | 5 | gpu_name=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader) 6 | 7 | start=`date +"%s"` 8 | 9 | mkdir $5 10 | 11 | i=0 12 | while [ $i -lt $num_gpu ]; do 13 | { 14 | echo $i 15 | 16 | python mm_sequence_packing/sequence_packing_image_to_json.py --tar-file-path $4 --save-path $5 --gpu-id $(($i + $gpu_start_id)) --tars-per-gpu $3 17 | } & 18 | i=$(($i + 1)) 19 | done 20 | 21 | wait 22 | end=`date +"%s"` 23 | echo "time: " `expr $end - $start` 24 | -------------------------------------------------------------------------------- /mm_sequence_packing/multiprocess_sequence_packing_image_to_pil.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | start_id=$1 3 | num_proc=$2 # on our server we use 4; you can enlarge this based on you cpu cores 4 | tars_per_gpu=$3 # the number of tar files processed by each process 5 | 6 | 7 | gpu_name=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader) 8 | 9 | start=`date +"%s"` 10 | 11 | i=0 12 | while [ $i -lt $num_proc ]; do 13 | { 14 | echo $i 15 | if [ "$4" == "obelics" ]; then 16 | mkdir Open-Qwen2VL-Data/obelics/obelics_single_pkl_pil 17 | python mm_sequence_packing/sequence_packing_image_to_pil.py --tar-file-path Open-Qwen2VL-Data/obelics_webdataset --save-path Open-Qwen2VL-Data/obelics_single_pkl_pil --gpu-id $(($i + $start_id)) --tars-per-gpu ${tars_per_gpu} 18 | elif [ "$4" == "synthdog" ]; then 19 | mkdir Open-Qwen2VL-Data/synthdog-en/synthdog_single_pkl_pil 20 | python mm_sequence_packing/sequence_packing_image_to_pil.py --tar-file-path Open-Qwen2VL-Data/synthdog_webdataset --save-path Open-Qwen2VL-Data/synthdog_single_pkl_pil --gpu-id $(($i + $start_id)) --tars-per-gpu ${tars_per_gpu} 21 | elif [ "$4" == "ccs" ]; then 22 | python mm_sequence_packing/sequence_packing_image_to_pil.py --tar-file-path Open-Qwen2VL-Data/ccs_webdataset --save-path Open-Qwen2VL-Data/ccs_single_pkl_pil --gpu-id $(($i + $start_id)) --tars-per-gpu ${tars_per_gpu} 23 | elif [ "$4" == "laion" ]; then 24 | python 
mm_sequence_packing/sequence_packing_image_to_pil.py --tar-file-path Open-Qwen2VL-Data/laion_webdataset --save-path Open-Qwen2VL-Data/laion_single_pkl_pil --gpu-id $(($i + $start_id)) --tars-per-gpu ${tars_per_gpu} 25 | elif [ "$4" == "datacomp-dfn" ]; then 26 | python mm_sequence_packing/sequence_packing_image_to_pil.py --tar-file-path Open-Qwen2VL-Data/datacomp_medium_dfn_webdataset --save-path Open-Qwen2VL-Data/datacomp_dfn_single_pkl_pil --gpu-id $(($i + $start_id)) --tars-per-gpu ${tars_per_gpu} 27 | else 28 | python mm_sequence_packing/sequence_packing_image_to_pil.py --tar-file-path Open-Qwen2VL-Data/datacomp_medium_mlm_filter_su_85_union_dfn_webdataset/ --save-path Open-Qwen2VL-Data/datacomp_hq_single_pkl_pil --gpu-id $(($i + $start_id)) --tars-per-gpu ${tars_per_gpu} 29 | fi 30 | } & 31 | i=$(($i + 1)) 32 | done 33 | 34 | wait 35 | end=`date +"%s"` 36 | echo "time: " `expr $end - $start` 37 | -------------------------------------------------------------------------------- /prismatic-vlms/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Ruff 132 | .ruff_cache/ 133 | 134 | # Auth Tokens / Hidden Files 135 | .hf_token 136 | .wandb_api_key 137 | .*_token 138 | .*api_key 139 | 140 | # IDE Caches 141 | .idea/ 142 | .vscode/ 143 | 144 | # Mac OS 145 | .DS_Store 146 | 147 | # Caches and Datasets 148 | cache/ 149 | data/ -------------------------------------------------------------------------------- /prismatic-vlms/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | exclude: ".git" 4 | 5 | repos: 6 | - repo: https://github.com/astral-sh/ruff-pre-commit 7 | rev: v0.2.2 8 | hooks: 9 | - id: ruff 10 | args: [ --fix, --exit-non-zero-on-fix ] 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 24.2.0 14 | hooks: 15 | - id: black 16 | 17 | - repo: https://github.com/pre-commit/pre-commit-hooks 18 | rev: v4.5.0 19 | hooks: 20 | - id: check-added-large-files 21 | - id: check-ast 22 | - id: check-case-conflict 23 | - id: check-merge-conflict 24 | - id: check-toml 25 | - id: check-yaml 26 | - id: end-of-file-fixer 27 | - id: trailing-whitespace -------------------------------------------------------------------------------- /prismatic-vlms/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Siddharth Karamcheti, Suraj Nair, Ashwin Balakrishna and Toyota Research Institute. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /prismatic-vlms/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help clean check autoformat 2 | .DEFAULT: help 3 | 4 | # Generates a useful overview/help message for various make features - add to this as necessary! 5 | help: 6 | @echo "make clean" 7 | @echo " Remove all temporary pyc/pycache files" 8 | @echo "make check" 9 | @echo " Run code style and linting (black, ruff) *without* changing files!" 10 | @echo "make autoformat" 11 | @echo " Run code styling (black, ruff) and update in place - committing with pre-commit also does this." 12 | 13 | clean: 14 | find . -name "*.pyc" | xargs rm -f && \ 15 | find . -name "__pycache__" | xargs rm -rf 16 | 17 | check: 18 | black --check . 19 | ruff check --show-source . 20 | 21 | autoformat: 22 | black . 23 | ruff check --fix --show-fixes . 24 | -------------------------------------------------------------------------------- /prismatic-vlms/fine_tune.sh: -------------------------------------------------------------------------------- 1 | CKPT_PATH=$1 2 | CKPTID=$2 3 | DATAPATH=$3 4 | 5 | # mount_path: you can change it to the path where you pre-download the pre-trained LLMs, if you will to download the model directly from HF hub, please use the holder name of the model, i.e. Qwen for Qwen-series models. Then the mount_path will be concatenated with model_id 6 | 7 | # trackers: you can change to ["jsonl",] ["wandb",] ["jsonl","wandb"] to decide whether to visualize your training on wandb 8 | 9 | 10 | # Run from the root of the repository 11 | torchrun --standalone --nnodes 1 --nproc-per-node 8 scripts/pretrain.py \ 12 | --stage "finetune" \ 13 | --model.type "one-stage+7b" \ 14 | --model.model_id qwen2.5-1.5b-instruct-continue-training-${CKPTID} \ 15 | --model.arch_specifier full-align+729-avgpool \ 16 | --model.vision_backbone_id "siglip-vit-so400m-384px" \ 17 | --model.image_resize_strategy "resize-naive" \ 18 | --model.llm_backbone_id qwen2.5-1.5b-instruct \ 19 | --model.finetune_global_batch_size 128 \ 20 | --model.finetune_per_device_batch_size 2 \ 21 | --mount_path Qwen \ 22 | --run_root_dir checkpoints \ 23 | --dataset.type "llava-v15" \ 24 | --pretrained_checkpoint ${CKPT_PATH}/checkpoints/latest-checkpoint.pt \ 25 | --dataset.finetune_stage_components=["${DATAPATH}","data/llava/images/"] -------------------------------------------------------------------------------- /prismatic-vlms/fine_tune_mammoth.sh: -------------------------------------------------------------------------------- 1 | CKPT_PATH=$1 2 | CKPTID=$2 3 | DATAPATH="data/MAmmoTH-VL-Instruct-12M/mammoth_si_10M_simple.jsonl" 4 | 5 | # mount_path: you can change it to the path where you pre-download the pre-trained LLMs, if you will to download the model directly from HF hub, please use the holder name of the model, i.e. Qwen for Qwen-series models. 
Then the mount_path will be concatenated with model_id 6 | 7 | # trackers: you can change to ["jsonl",] ["wandb",] ["jsonl","wandb"] to decide whether to visualize your training on wandb 8 | 9 | # Run from the root of the repository 10 | torchrun --standalone --nnodes 1 --nproc-per-node 8 scripts/pretrain.py \ 11 | --stage "large-finetune" \ 12 | --model.type "one-stage+7b" \ 13 | --model.model_id qwen2.5-1.5b-instruct-continue-training-${CKPTID} \ 14 | --model.arch_specifier full-align+729-avgpool \ 15 | --model.vision_backbone_id "siglip-vit-so400m-384px" \ 16 | --model.image_resize_strategy "resize-naive" \ 17 | --model.llm_backbone_id qwen2.5-1.5b-instruct \ 18 | --model.finetune_global_batch_size 128 \ 19 | --model.finetune_per_device_batch_size 2 \ 20 | --model.finetune_max_steps 150000 \ 21 | --mount_path Qwen \ 22 | --run_root_dir checkpoints \ 23 | --dataset.type "llava-v15" \ 24 | --pretrained_checkpoint ${CKPT_PATH}/checkpoints/latest-checkpoint.pt \ 25 | --dataset.finetune_stage_components=["${DATAPATH}","data/MAmmoTH-VL-Instruct-12M/single_image_data"] -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import available_model_ids, available_model_ids_and_names, get_model_description, load -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DatasetConfig, DatasetRegistry 2 | from .models import ModelConfig, ModelRegistry 3 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/conf/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | datasets.py 3 | 4 | Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant 5 | and processing scheme. 
A given dataset variant (e.g., `llava-lightning`) configures the following attributes: 6 | - Dataset Variant (Identifier) --> e.g., "llava-v15" 7 | - Align Stage Dataset Components (annotations, images) 8 | - Finetune Stage Dataset Components (annotations, images) 9 | - Dataset Root Directory (Path) 10 | """ 11 | from dataclasses import dataclass 12 | from enum import Enum, unique 13 | from pathlib import Path 14 | from typing import Tuple 15 | 16 | from draccus import ChoiceRegistry 17 | 18 | 19 | @dataclass 20 | class DatasetConfig(ChoiceRegistry): 21 | # fmt: off 22 | dataset_id: str # Unique ID that fully specifies a dataset variant 23 | 24 | # Dataset Components for each Stage in < align | finetune > 25 | align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage 26 | finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage 27 | 28 | dataset_root_dir: str # Path to dataset root directory; others paths are relative to root 29 | # fmt: on 30 | 31 | 32 | # [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models) 33 | @dataclass 34 | class LLaVa_V15_Config(DatasetConfig): 35 | dataset_id: str = "llava-v15" 36 | 37 | align_stage_components: Tuple[Path, Path] = ( 38 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 39 | Path("download/llava-laion-cc-sbu-558k/"), 40 | ) 41 | finetune_stage_components: Tuple[Path, Path] = ( 42 | Path("data/llava/llava_phi_3_non_test_sft_data_516k.json"), 43 | Path("wzwang/data/llava"), 44 | ) 45 | dataset_root_dir: str = "data" 46 | 47 | train_num_samples: int = 200000 48 | dataset_resampled: bool = True 49 | min_num_images: int = 1 50 | max_num_images: int = 6 51 | workers: int = 4 52 | 53 | # [PreTrain] 54 | @dataclass 55 | class OBELICS_PreTrain_Config(DatasetConfig): 56 | dataset_id: str = "pretrain" 57 | 58 | align_stage_components: Tuple[Path, Path] = ( 59 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 60 | Path("download/llava-laion-cc-sbu-558k/"), 61 | ) 62 | finetune_stage_components: Tuple[Path, Path] = ( 63 | Path("data/llava/llava_v1_5_mix665k.json"), 64 | Path("data/llava/data"), 65 | ) 66 | dataset_root_dir: str = "data" 67 | 68 | train_num_samples: int = 3000000 69 | dataset_resampled: bool = True 70 | min_num_images: int = 1 71 | max_num_images: int = 6 72 | workers: int = 4 73 | 74 | # [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training) 75 | @dataclass 76 | class LLaVa_Multimodal_Only_Config(DatasetConfig): 77 | dataset_id: str = "llava-multimodal" 78 | 79 | align_stage_components: Tuple[Path, Path] = ( 80 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 81 | Path("download/llava-laion-cc-sbu-558k/"), 82 | ) 83 | finetune_stage_components: Tuple[Path, Path] = ( 84 | Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"), 85 | Path("download/llava-v1.5-instruct/"), 86 | ) 87 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 88 | 89 | 90 | # LLaVa-v15 + LVIS-Instruct-4V 91 | @dataclass 92 | class LLaVa_LVIS4V_Config(DatasetConfig): 93 | dataset_id: str = "llava-lvis4v" 94 | 95 | align_stage_components: Tuple[Path, Path] = ( 96 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 97 | Path("download/llava-laion-cc-sbu-558k/"), 98 | ) 99 | finetune_stage_components: Tuple[Path, Path] = ( 100 | Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"), 101 | Path("download/llava-v1.5-instruct/"), 102 | ) 103 | 
dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 104 | 105 | 106 | # LLaVa-v15 + LRV-Instruct 107 | @dataclass 108 | class LLaVa_LRV_Config(DatasetConfig): 109 | dataset_id: str = "llava-lrv" 110 | 111 | align_stage_components: Tuple[Path, Path] = ( 112 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 113 | Path("download/llava-laion-cc-sbu-558k/"), 114 | ) 115 | finetune_stage_components: Tuple[Path, Path] = ( 116 | Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"), 117 | Path("download/llava-v1.5-instruct/"), 118 | ) 119 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 120 | 121 | 122 | # LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct 123 | @dataclass 124 | class LLaVa_LVIS4V_LRV_Config(DatasetConfig): 125 | dataset_id: str = "llava-lvis4v-lrv" 126 | 127 | align_stage_components: Tuple[Path, Path] = ( 128 | Path("download/llava-laion-cc-sbu-558k/chat.json"), 129 | Path("download/llava-laion-cc-sbu-558k/"), 130 | ) 131 | finetune_stage_components: Tuple[Path, Path] = ( 132 | Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"), 133 | Path("download/llava-v1.5-instruct/"), 134 | ) 135 | dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") 136 | 137 | 138 | # === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! === 139 | @unique 140 | class DatasetRegistry(Enum): 141 | # === LLaVa v1.5 === 142 | OBELICS_PreTrain = OBELICS_PreTrain_Config 143 | LLAVA_V15 = LLaVa_V15_Config 144 | 145 | LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config 146 | 147 | LLAVA_LVIS4V = LLaVa_LVIS4V_Config 148 | LLAVA_LRV = LLaVa_LRV_Config 149 | 150 | LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config 151 | 152 | @property 153 | def dataset_id(self) -> str: 154 | return self.value.dataset_id 155 | 156 | 157 | # Register Datasets in Choice Registry 158 | for dataset_variant in DatasetRegistry: 159 | DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value) 160 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .load import available_model_ids, available_model_ids_and_names, get_model_description, load 2 | from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm 3 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/prismatic-vlms/prismatic/models/backbones/__init__.py -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_llm import LLMBackbone 2 | from .llama2 import LLaMa2LLMBackbone 3 | from .mistral import MistralLLMBackbone 4 | from .phi3 import Phi3LLMBackbone 5 | from .qwen2 import Qwen2LLMBackbone 6 | from .llama3 import LLaMa3LLMBackbone -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/llama2.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
llama2.py 3 | 4 | Class definition for all LLMs derived from LlamaForCausalLM. 5 | """ 6 | from typing import Optional, Type 7 | 8 | import torch 9 | from torch import nn as nn 10 | from transformers import LlamaForCausalLM 11 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 12 | 13 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 14 | from prismatic.models.backbones.llm.prompting import ( 15 | LLaMa2ChatPromptBuilder, 16 | PromptBuilder, 17 | PurePromptBuilder, 18 | VicunaV15ChatPromptBuilder, 19 | ) 20 | 21 | # Registry =>> Support LLaMa-2 Models (from HF Transformers) 22 | # fmt: off 23 | LLAMA2_MODELS = { 24 | # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models === 25 | "llama2-7b-pure": { 26 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "Llama-2-7b-hf" 27 | }, 28 | 29 | "llama2-13b-pure": { 30 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "Llama-2-13b-hf" 31 | }, 32 | 33 | "llama3-8b-pure": { 34 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "Meta-Llama-3-8B" 35 | }, 36 | 37 | # === Meta LLaMa-2 Chat Models === 38 | "llama2-7b-chat": { 39 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf" 40 | }, 41 | 42 | "llama2-13b-chat": { 43 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf" 44 | }, 45 | 46 | # === Vicuna v1.5 Chat Models === 47 | "vicuna-v15-7b": { 48 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "vicuna-7b-v1.5" 49 | }, 50 | 51 | "vicuna-v15-13b": { 52 | "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "vicuna-13b-v1.5" 53 | }, 54 | } 55 | # fmt: on 56 | 57 | 58 | class LLaMa2LLMBackbone(HFCausalLLMBackbone): 59 | def __init__( 60 | self, 61 | llm_backbone_id: str, 62 | llm_max_length: int = 4096, 63 | mount_path: Optional[str] = None, 64 | inference_mode: bool = False, 65 | use_flash_attention_2: bool = True, 66 | ) -> None: 67 | super().__init__( 68 | llm_backbone_id, 69 | llm_max_length=llm_max_length, 70 | mount_path=mount_path, 71 | inference_mode=inference_mode, 72 | use_flash_attention_2=use_flash_attention_2, 73 | **LLAMA2_MODELS[llm_backbone_id], 74 | ) 75 | 76 | # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize) 77 | # Weizhi for new project: we did not need the added padding token 78 | # self.tokenizer.add_special_tokens({"pad_token": ""}) 79 | # self.llm.config.pad_token_id = self.tokenizer.pad_token_id 80 | # self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 81 | 82 | @property 83 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 84 | if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"): 85 | return PurePromptBuilder 86 | 87 | elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"): 88 | return LLaMa2ChatPromptBuilder 89 | 90 | elif self.identifier.startswith("llama3-") and self.identifier.endswith("-pure"): 91 | return PurePromptBuilder 92 | 93 | elif self.identifier.startswith("vicuna"): 94 | return VicunaV15ChatPromptBuilder 95 | 96 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 97 | 98 | @property 99 | def transformer_layer_cls(self) -> Type[nn.Module]: 100 | return LlamaDecoderLayer 101 | 102 | @property 103 | def half_precision_dtype(self) -> torch.dtype: 104 | """LLaMa-2 was trained in BF16; see 
https://huggingface.co/docs/transformers/main/model_doc/llama2.""" 105 | return torch.bfloat16 106 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/llama3.py: -------------------------------------------------------------------------------- 1 | """ 2 | llama3.py 3 | 4 | Class definition for all LLMs derived from LlamaForCausalLM. 5 | """ 6 | from typing import Optional, Type 7 | 8 | import torch 9 | from torch import nn as nn 10 | from transformers import LlamaForCausalLM 11 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 12 | 13 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 14 | from prismatic.models.backbones.llm.prompting import PromptBuilder, LLaMa3ChatPromptBuilder 15 | 16 | # Registry =>> Support LLaMa-3 Models (from HF Transformers) 17 | # fmt: off 18 | LLAMA3_MODELS = {
 19 | # === Meta LLaMa-3.1 / LLaMa-3.2 Base & Instruct Models === 20 | "llama3.1-8b-pure": { 21 | "llm_family": "llama3", "llm_cls": LlamaForCausalLM, "hf_hub_path": "Llama-3.1-8B" 22 | }, 23 | 24 | "llama3.1-8b-instruct": { 25 | "llm_family": "llama3", "llm_cls": LlamaForCausalLM, "hf_hub_path": "Llama-3.1-8B-Instruct" 26 | }, 27 | 28 | "llama3.2-3b-pure": { 29 | "llm_family": "llama3", "llm_cls": LlamaForCausalLM, "hf_hub_path": "Llama-3.2-3B" 30 | }, 31 | 32 | "llama3.2-3b-instruct": { 33 | "llm_family": "llama3", "llm_cls": LlamaForCausalLM, "hf_hub_path": "Llama-3.2-3B-Instruct" 34 | }, 35 | } 36 | # fmt: on 37 | 38 | 39 | class LLaMa3LLMBackbone(HFCausalLLMBackbone): 40 | def __init__( 41 | self, 42 | llm_backbone_id: str, 43 | llm_max_length: int = 8192, 44 | mount_path: Optional[str] = None, 45 | inference_mode: bool = False, 46 | use_flash_attention_2: bool = True, 47 | ) -> None: 48 | super().__init__( 49 | llm_backbone_id, 50 | llm_max_length=llm_max_length, 51 | mount_path=mount_path, 52 | inference_mode=inference_mode, 53 | use_flash_attention_2=use_flash_attention_2, 54 | **LLAMA3_MODELS[llm_backbone_id], 55 | ) 56 | 57 | # [Special Case] LLaMa-3 PAD Token Handling --> for clarity, we add an extra token (and resize) 58 | # Weizhi for new project: unlike LLaMa-2, we do add a dedicated padding token here 59 | self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<|pad|>"]}) 60 | self.tokenizer.pad_token = "<|pad|>" 61 | self.llm.resize_token_embeddings(len(self.tokenizer)) 62 | 63 | @property 64 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 65 | return LLaMa3ChatPromptBuilder 66 | 67 | @property 68 | def transformer_layer_cls(self) -> Type[nn.Module]: 69 | return LlamaDecoderLayer 70 | 71 | @property 72 | def half_precision_dtype(self) -> torch.dtype: 73 | """LLaMa-3 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama3.""" 74 | return torch.bfloat16 75 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/mistral.py: -------------------------------------------------------------------------------- 1 | """ 2 | mistral.py 3 | 4 | Class definition for all LLMs derived from MistralForCausalLM.
5 | """ 6 | from typing import Optional, Type 7 | 8 | import torch 9 | from torch import nn as nn 10 | from transformers import MistralForCausalLM 11 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer 12 | 13 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 14 | from prismatic.models.backbones.llm.prompting import ( 15 | LLaMa2ChatPromptBuilder, 16 | PromptBuilder, 17 | PurePromptBuilder, 18 | VicunaV15ChatPromptBuilder, 19 | ) 20 | 21 | # Registry =>> Support LLaMa-2 Models (from HF Transformers) 22 | # fmt: off 23 | MISTAL_MODELS = { 24 | # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models === 25 | "mistral-7b": { 26 | "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "Mistral-7B-v0.1" 27 | }, 28 | } 29 | # fmt: on 30 | 31 | 32 | class MistralLLMBackbone(HFCausalLLMBackbone): 33 | def __init__( 34 | self, 35 | llm_backbone_id: str, 36 | llm_max_length: int = 4096, 37 | mount_path: Optional[str] = None, 38 | inference_mode: bool = False, 39 | use_flash_attention_2: bool = True, 40 | ) -> None: 41 | super().__init__( 42 | llm_backbone_id, 43 | llm_max_length=llm_max_length, 44 | mount_path=mount_path, 45 | inference_mode=inference_mode, 46 | use_flash_attention_2=use_flash_attention_2, 47 | **MISTAL_MODELS[llm_backbone_id], 48 | ) 49 | 50 | # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize) 51 | # Weizhi for new project: we did not need the added padding token 52 | # self.tokenizer.add_special_tokens({"pad_token": ""}) 53 | # self.llm.config.pad_token_id = self.tokenizer.pad_token_id 54 | # self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) 55 | 56 | @property 57 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 58 | return PurePromptBuilder 59 | if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"): 60 | return PurePromptBuilder 61 | 62 | elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"): 63 | return LLaMa2ChatPromptBuilder 64 | 65 | elif self.identifier.startswith("vicuna"): 66 | return VicunaV15ChatPromptBuilder 67 | 68 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 69 | 70 | @property 71 | def transformer_layer_cls(self) -> Type[nn.Module]: 72 | return MistralDecoderLayer 73 | 74 | @property 75 | def half_precision_dtype(self) -> torch.dtype: 76 | """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2.""" 77 | return torch.bfloat16 78 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/phi3.py: -------------------------------------------------------------------------------- 1 | """ 2 | phi3.py 3 | 4 | Class definition for all LLMs derived from Phi3ForCausalLM. 
5 | """ 6 | from typing import Optional, Type 7 | 8 | import torch 9 | from torch import nn as nn 10 | from transformers import Phi3ForCausalLM 11 | from transformers.models.phi3.modeling_phi3 import Phi3DecoderLayer 12 | 13 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 14 | from prismatic.models.backbones.llm.prompting import ( 15 | LLaMa2ChatPromptBuilder, 16 | PromptBuilder, 17 | PurePromptBuilder, 18 | VicunaV15ChatPromptBuilder, 19 | Phi3PromptBuilder, 20 | ) 21 | 22 | # Registry =>> Support LLaMa-2 Models (from HF Transformers) 23 | # fmt: off 24 | PHI3_MODELS = { 25 | # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models === 26 | "phi3-3b": { 27 | "llm_family": "phi3", "llm_cls": Phi3ForCausalLM, "hf_hub_path": "Phi-3-mini-4k-instruct" 28 | }, 29 | "phi3.5-3b": { 30 | "llm_family": "phi3", "llm_cls": Phi3ForCausalLM, "hf_hub_path": "Phi-3.5-mini-instruct" 31 | }, 32 | } 33 | # fmt: on 34 | 35 | 36 | class Phi3LLMBackbone(HFCausalLLMBackbone): 37 | def __init__( 38 | self, 39 | llm_backbone_id: str, 40 | llm_max_length: int = 4096, 41 | mount_path: Optional[str] = None, 42 | inference_mode: bool = False, 43 | use_flash_attention_2: bool = True, 44 | ) -> None: 45 | super().__init__( 46 | llm_backbone_id, 47 | llm_max_length=llm_max_length, 48 | mount_path=mount_path, 49 | inference_mode=inference_mode, 50 | use_flash_attention_2=use_flash_attention_2, 51 | **PHI3_MODELS[llm_backbone_id], 52 | ) 53 | 54 | # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an extra token (and resize) 55 | self.tokenizer.pad_token = self.tokenizer.unk_token 56 | self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>"]}) 57 | 58 | @property 59 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 60 | return Phi3PromptBuilder 61 | # if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"): 62 | # return PurePromptBuilder 63 | 64 | # elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"): 65 | # return LLaMa2ChatPromptBuilder 66 | 67 | # elif self.identifier.startswith("vicuna"): 68 | # return VicunaV15ChatPromptBuilder 69 | 70 | raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") 71 | 72 | @property 73 | def transformer_layer_cls(self) -> Type[nn.Module]: 74 | return Phi3DecoderLayer 75 | 76 | @property 77 | def half_precision_dtype(self) -> torch.dtype: 78 | """Phi-3 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/phi3.""" 79 | return torch.bfloat16 80 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/prompting/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_prompter import PromptBuilder, PurePromptBuilder, QAPromptBuilder 2 | from .llama2_chat_prompter import LLaMa2ChatPromptBuilder 3 | from .llama3_chat_prompter import LLaMa3ChatPromptBuilder 4 | from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder 5 | from .phi_3_prompter import Phi3PromptBuilder 6 | from .qwen2_prompter import Qwen2PromptBuilder 7 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/prompting/base_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_prompter.py 3 | 4 | Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for 
chat-based LLMs. 5 | """ 6 | from abc import ABC, abstractmethod 7 | from typing import Optional 8 | 9 | 10 | class PromptBuilder(ABC): 11 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 12 | self.model_family = model_family 13 | 14 | # Only some models define a system prompt => let subclasses handle this logic! 15 | self.system_prompt = system_prompt 16 | 17 | @abstractmethod 18 | def add_turn(self, role: str, message: str) -> str: ... 19 | 20 | @abstractmethod 21 | def get_potential_prompt(self, user_msg: str) -> None: ... 22 | 23 | @abstractmethod 24 | def get_prompt(self) -> str: ... 25 | 26 | 27 | class PurePromptBuilder(PromptBuilder): 28 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 29 | super().__init__(model_family, system_prompt) 30 | 31 | # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME! 32 | self.bos, self.eos = "", "" 33 | 34 | # Get role-specific "wrap" functions 35 | self.wrap_human = lambda msg: f"In: {msg}\nOut: " 36 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 37 | 38 | # === `self.prompt` gets built up over multiple turns === 39 | self.prompt, self.turn_count = "", 0 40 | 41 | def add_turn(self, role: str, message: str) -> str: 42 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 43 | message = message.strip() #.replace("", "").strip() 44 | 45 | if (self.turn_count % 2) == 0: 46 | human_message = self.wrap_human(message) 47 | wrapped_message = human_message 48 | else: 49 | gpt_message = self.wrap_gpt(message) 50 | wrapped_message = gpt_message 51 | 52 | # Update Prompt 53 | self.prompt += wrapped_message 54 | 55 | # Bump Turn Counter 56 | self.turn_count += 1 57 | 58 | # Return "wrapped_message" (effective string added to context) 59 | return wrapped_message 60 | 61 | def get_potential_prompt(self, message: str) -> None: 62 | # Assumes that it's always the user's (human's) turn! 63 | prompt_copy = str(self.prompt) 64 | 65 | human_message = self.wrap_human(message) 66 | prompt_copy += human_message 67 | 68 | return prompt_copy.removeprefix(self.bos).rstrip() 69 | 70 | def get_prompt(self) -> str: 71 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer! 72 | return self.prompt.removeprefix(self.bos).rstrip() 73 | 74 | 75 | class QAPromptBuilder(PromptBuilder): 76 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 77 | super().__init__(model_family, system_prompt) 78 | 79 | # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME! 
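# NOTE =>> With the wrap functions defined just below, a human/gpt turn pair renders as
#          "Question: <question>\nAnswer: <answer>"; `get_prompt()` then strips the BOS prefix (if any)
#          so the tokenizer can re-insert it at encoding time.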
80 | self.bos, self.eos = "", "" 81 | 82 | # Get role-specific "wrap" functions 83 | self.wrap_human = lambda msg: f"Question: {msg}\nAnswer: " 84 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}" 85 | 86 | # === `self.prompt` gets built up over multiple turns === 87 | self.prompt, self.turn_count = "", 0 88 | 89 | def add_turn(self, role: str, message: str) -> str: 90 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 91 | message = message.strip() #.replace("", "").strip() 92 | 93 | if (self.turn_count % 2) == 0: 94 | human_message = self.wrap_human(message) 95 | wrapped_message = human_message 96 | else: 97 | gpt_message = self.wrap_gpt(message) 98 | wrapped_message = gpt_message 99 | 100 | # Update Prompt 101 | self.prompt += wrapped_message 102 | 103 | # Bump Turn Counter 104 | self.turn_count += 1 105 | 106 | # Return "wrapped_message" (effective string added to context) 107 | return wrapped_message 108 | 109 | def get_potential_prompt(self, message: str) -> None: 110 | # Assumes that it's always the user's (human's) turn! 111 | prompt_copy = str(self.prompt) 112 | 113 | human_message = self.wrap_human(message) 114 | prompt_copy += human_message 115 | 116 | return prompt_copy.removeprefix(self.bos).rstrip() 117 | 118 | def get_prompt(self) -> str: 119 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer! 120 | return self.prompt.removeprefix(self.bos).rstrip() -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | llama2_prompter.py 3 | 4 | Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern 5 | that's used by HF and other online tutorials. 6 | 7 | Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 8 | """ 9 | from typing import Optional 10 | 11 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 12 | 13 | # Default System Prompt for Prismatic Models 14 | SYS_PROMPTS = { 15 | "prismatic": ( 16 | "You are a helpful language and vision assistant. " 17 | "You are able to understand the visual content that the user provides, " 18 | "and assist the user with a variety of tasks using natural language." 
19 | ), 20 | } 21 | 22 | 23 | def format_system_prompt(system_prompt: str) -> str: 24 | return f"<<SYS>>\n{system_prompt.strip()}\n<</SYS>>\n\n" 25 | 26 | 27 | class LLaMa2ChatPromptBuilder(PromptBuilder): 28 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 29 | super().__init__(model_family, system_prompt) 30 | self.system_prompt = format_system_prompt( 31 | SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt 32 | ) 33 | 34 | # LLaMa-2 Specific 35 | self.bos, self.eos = "<s>", "</s>" 36 | 37 | # Get role-specific "wrap" functions 38 | self.wrap_human = lambda msg: f"{self.bos}[INST] {msg} [/INST] " 39 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 40 | 41 | # === `self.prompt` gets built up over multiple turns === 42 | self.prompt, self.turn_count = "", 0 43 | 44 | def add_turn(self, role: str, message: str) -> str: 45 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 46 | message = message.strip() #.replace("<image>", "").strip() 47 | 48 | # Special Handling for "system" prompt (turn_count == 0) 49 | if self.turn_count == 0: 50 | sys_message = self.wrap_human(self.system_prompt + message) 51 | wrapped_message = sys_message 52 | elif (self.turn_count % 2) == 0: 53 | human_message = self.wrap_human(message) 54 | wrapped_message = human_message 55 | else: 56 | gpt_message = self.wrap_gpt(message) 57 | wrapped_message = gpt_message 58 | 59 | # Update Prompt 60 | self.prompt += wrapped_message 61 | 62 | # Bump Turn Counter 63 | self.turn_count += 1 64 | 65 | # Return "wrapped_message" (effective string added to context) 66 | return wrapped_message 67 | 68 | def get_potential_prompt(self, message: str) -> None: 69 | # Assumes that it's always the user's (human's) turn! 70 | prompt_copy = str(self.prompt) 71 | 72 | # Special Handling for "system" prompt (turn_count == 0) 73 | if self.turn_count == 0: 74 | sys_message = self.wrap_human(self.system_prompt + message) 75 | prompt_copy += sys_message 76 | 77 | else: 78 | human_message = self.wrap_human(message) 79 | prompt_copy += human_message 80 | 81 | return prompt_copy.removeprefix(self.bos).rstrip() 82 | 83 | def get_prompt(self) -> str: 84 | # Remove prefix because it gets auto-inserted by tokenizer! 85 | return self.prompt.removeprefix(self.bos).rstrip() 86 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/prompting/llama3_chat_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | llama3_prompter.py 3 | 4 | Defines a PromptBuilder for building LLaMa-3 Chat Prompts --> not sure if this is "optimal", but this is the pattern 5 | that's used by HF and other online tutorials. 6 | 7 | <|begin_of_text|><|start_header_id|>system<|end_header_id|> 8 | 9 | You are a helpful, respectful, and honest assistant.<|eot_id|><|start_header_id|>user<|end_header_id|> 10 | 11 | Hi! I am a human.<|eot_id|><|start_header_id|>assistant<|end_header_id|> 12 | 13 | Hello there! Nice to meet you! I'm Meta AI, your friendly AI assistant<|eot_id|> 14 | """ 15 | from typing import Optional 16 | 17 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 18 | 19 | # Default System Prompt for Prismatic Models 20 | SYS_PROMPTS = { 21 | "prismatic": ( 22 | "You are a helpful language and vision assistant. 
" 23 | "You are able to understand the visual content that the user provides, " 24 | "and assist the user with a variety of tasks using natural language." 25 | ), 26 | } 27 | 28 | 29 | def format_system_prompt(system_prompt: str) -> str: 30 | return f"<|begin_of_text|><|start_header_id|>{system_prompt.strip()}<|end_header_id|>" 31 | 32 | 33 | class LLaMa3ChatPromptBuilder(PromptBuilder): 34 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 35 | super().__init__(model_family, system_prompt) 36 | self.system_prompt = """<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful, and honest assistant.<|eot_id|>""" 37 | 38 | # LLaMa-3 Specific 39 | self.bos, self.end_token = "<|begin_of_text|>", "<|eot_id|>" 40 | 41 | # Get role-specific "wrap" functions 42 | self.wrap_human = lambda msg: f"<|start_header_id|>user<|end_header_id|>\n\n{msg}{self.end_token}<|start_header_id|>assistant<|end_header_id|>\n\n" 43 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.end_token}" 44 | 45 | # === `self.prompt` gets built up over multiple turns === 46 | self.prompt, self.turn_count = "", 0 47 | 48 | def add_turn(self, role: str, message: str) -> str: 49 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 50 | message = message.strip() #.replace("", "").strip() 51 | 52 | # Special Handling for "system" prompt (turn_count == 0) 53 | if self.turn_count == 0: 54 | sys_message = self.system_prompt + self.wrap_human(message) 55 | wrapped_message = sys_message 56 | elif (self.turn_count % 2) == 0: 57 | human_message = self.wrap_human(message) 58 | wrapped_message = human_message 59 | else: 60 | gpt_message = self.wrap_gpt(message) 61 | wrapped_message = gpt_message 62 | 63 | # Update Prompt 64 | self.prompt += wrapped_message 65 | 66 | # Bump Turn Counter 67 | self.turn_count += 1 68 | 69 | # Return "wrapped_message" (effective string added to context) 70 | return wrapped_message 71 | 72 | def get_potential_prompt(self, message: str) -> None: 73 | # Assumes that it's always the user's (human's) turn! 74 | prompt_copy = str(self.prompt) 75 | 76 | # Special Handling for "system" prompt (turn_count == 0) 77 | if self.turn_count == 0: 78 | sys_message = self.wrap_human(self.system_prompt + message) 79 | prompt_copy += sys_message 80 | 81 | else: 82 | human_message = self.wrap_human(message) 83 | prompt_copy += human_message 84 | 85 | return prompt_copy.removeprefix(self.bos).rstrip() 86 | 87 | def get_prompt(self) -> str: 88 | # Remove prefix because it gets auto-inserted by tokenizer! 89 | return self.prompt.removeprefix(self.bos).rstrip() 90 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/prompting/phi_3_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | vicuna_v15_prompter.py 3 | 4 | Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts. 5 | 6 | Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5 7 | """ 8 | from typing import Optional 9 | 10 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 11 | 12 | # Default System Prompt for LLaVa Models 13 | SYS_PROMPTS = { 14 | "prismatic": ( 15 | "A chat between a curious user and an artificial intelligence assistant. " 16 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 
17 | ), 18 | } 19 | 20 | 21 | class Phi3PromptBuilder(PromptBuilder): 22 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 23 | super().__init__(model_family, system_prompt) 24 | self.system_prompt = "" # (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " " 25 | 26 | # LLaMa-2 Specific 27 | self.bos, self.eos = "", "<|endoftext|>" 28 | self.end_token = "<|end|>" 29 | 30 | # Get role-specific "wrap" functions 31 | self.wrap_human = lambda msg: f"<|user|>\n{msg}{self.end_token}\n<|assistant|>\n" 32 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.end_token}\n" 33 | 34 | # === `self.prompt` gets built up over multiple turns === 35 | self.prompt, self.turn_count = "", 0 36 | 37 | def add_turn(self, role: str, message: str) -> str: 38 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 39 | message = message.strip() #.replace("", "").strip() 40 | 41 | # Special Handling for "system" prompt (turn_count == 0) 42 | if self.turn_count == 0: 43 | sys_message = self.system_prompt + self.wrap_human(message) 44 | wrapped_message = sys_message 45 | elif (self.turn_count % 2) == 0: 46 | human_message = self.wrap_human(message) 47 | wrapped_message = human_message 48 | else: 49 | gpt_message = self.wrap_gpt(message) 50 | wrapped_message = gpt_message 51 | 52 | # Update Prompt 53 | self.prompt += wrapped_message 54 | 55 | # Bump Turn Counter 56 | self.turn_count += 1 57 | 58 | # Return "wrapped_message" (effective string added to context) 59 | return wrapped_message 60 | 61 | def get_potential_prompt(self, message: str) -> None: 62 | # Assumes that it's always the user's (human's) turn! 63 | prompt_copy = str(self.prompt) 64 | 65 | # Special Handling for "system" prompt (turn_count == 0) 66 | if self.turn_count == 0: 67 | sys_message = self.system_prompt + self.wrap_human(message) 68 | prompt_copy += sys_message 69 | 70 | else: 71 | human_message = self.wrap_human(message) 72 | prompt_copy += human_message 73 | 74 | return prompt_copy.removeprefix(self.bos).rstrip() 75 | 76 | def get_prompt(self) -> str: 77 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer! 78 | return self.prompt.removeprefix(self.bos).rstrip() 79 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/prompting/qwen2_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | vicuna_v15_prompter.py 3 | 4 | Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts. 5 | 6 | Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5 7 | """ 8 | from typing import Optional 9 | 10 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 11 | 12 | # Default System Prompt for LLaVa Models 13 | SYS_PROMPTS = { 14 | "prismatic": ( 15 | "A chat between a curious user and an artificial intelligence assistant. " 16 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 
17 | ), 18 | } 19 | 20 | 21 | class Qwen2PromptBuilder(PromptBuilder): 22 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 23 | super().__init__(model_family, system_prompt) 24 | self.system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 25 | # LLaMa-2 Specific 26 | self.bos, self.eos = "", "<|im_end|>" 27 | 28 | # Get role-specific "wrap" functions 29 | self.wrap_human = lambda msg: f"<|im_start|>user\n{msg}<|im_end|>assistant\n" 30 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}\n" 31 | 32 | # === `self.prompt` gets built up over multiple turns === 33 | self.prompt, self.turn_count = "", 0 34 | 35 | def add_turn(self, role: str, message: str) -> str: 36 | # assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 37 | message = message.strip() #.replace("", "").strip() 38 | 39 | # Special Handling for "system" prompt (turn_count == 0) 40 | if self.turn_count == 0: 41 | sys_message = self.system_prompt + self.wrap_human(message) 42 | wrapped_message = sys_message 43 | elif (self.turn_count % 2) == 0: 44 | human_message = self.wrap_human(message) 45 | wrapped_message = human_message 46 | else: 47 | gpt_message = self.wrap_gpt(message) 48 | wrapped_message = gpt_message 49 | 50 | # Update Prompt 51 | self.prompt += wrapped_message 52 | 53 | # Bump Turn Counter 54 | self.turn_count += 1 55 | 56 | # Return "wrapped_message" (effective string added to context) 57 | return wrapped_message 58 | 59 | def get_potential_prompt(self, message: str) -> None: 60 | # Assumes that it's always the user's (human's) turn! 61 | prompt_copy = str(self.prompt) 62 | 63 | # Special Handling for "system" prompt (turn_count == 0) 64 | if self.turn_count == 0: 65 | sys_message = self.system_prompt + self.wrap_human(message) 66 | prompt_copy += sys_message 67 | 68 | else: 69 | human_message = self.wrap_human(message) 70 | prompt_copy += human_message 71 | 72 | # return prompt_copy.removeprefix(self.bos).rstrip() 73 | return prompt_copy.rstrip() 74 | 75 | def get_prompt(self) -> str: 76 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer! 77 | # return self.prompt.removeprefix(self.bos).rstrip() 78 | return self.prompt.rstrip() 79 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | vicuna_v15_prompter.py 3 | 4 | Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts. 5 | 6 | Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5 7 | """ 8 | from typing import Optional 9 | 10 | from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder 11 | 12 | # Default System Prompt for LLaVa Models 13 | SYS_PROMPTS = { 14 | "prismatic": ( 15 | "A chat between a curious user and an artificial intelligence assistant. " 16 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 
17 | ), 18 | } 19 | 20 | 21 | class VicunaV15ChatPromptBuilder(PromptBuilder): 22 | def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: 23 | super().__init__(model_family, system_prompt) 24 | self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " " 25 | 26 | # LLaMa-2 Specific 27 | self.bos, self.eos = "", "" 28 | 29 | # Get role-specific "wrap" functions 30 | self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: " 31 | self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" 32 | 33 | # === `self.prompt` gets built up over multiple turns === 34 | self.prompt, self.turn_count = "", 0 35 | 36 | def add_turn(self, role: str, message: str) -> str: 37 | assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") 38 | message = message.strip() #.replace("", "").strip() 39 | 40 | # Special Handling for "system" prompt (turn_count == 0) 41 | if self.turn_count == 0: 42 | sys_message = self.system_prompt + self.wrap_human(message) 43 | wrapped_message = sys_message 44 | elif (self.turn_count % 2) == 0: 45 | human_message = self.wrap_human(message) 46 | wrapped_message = human_message 47 | else: 48 | gpt_message = self.wrap_gpt(message) 49 | wrapped_message = gpt_message 50 | 51 | # Update Prompt 52 | self.prompt += wrapped_message 53 | 54 | # Bump Turn Counter 55 | self.turn_count += 1 56 | 57 | # Return "wrapped_message" (effective string added to context) 58 | return wrapped_message 59 | 60 | def get_potential_prompt(self, message: str) -> None: 61 | # Assumes that it's always the user's (human's) turn! 62 | prompt_copy = str(self.prompt) 63 | 64 | # Special Handling for "system" prompt (turn_count == 0) 65 | if self.turn_count == 0: 66 | sys_message = self.system_prompt + self.wrap_human(message) 67 | prompt_copy += sys_message 68 | 69 | else: 70 | human_message = self.wrap_human(message) 71 | prompt_copy += human_message 72 | 73 | return prompt_copy.removeprefix(self.bos).rstrip() 74 | 75 | def get_prompt(self) -> str: 76 | # Remove prefix (if exists) because it gets auto-inserted by tokenizer! 77 | return self.prompt.removeprefix(self.bos).rstrip() 78 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/llm/qwen2.py: -------------------------------------------------------------------------------- 1 | """ 2 | qwen2.py 3 | 4 | Class definition for all LLMs derived from Qwen2ForCausalLM. 
5 | """ 6 | from typing import Optional, Type 7 | 8 | import torch 9 | from torch import nn as nn 10 | from transformers import Qwen2ForCausalLM 11 | from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer 12 | 13 | from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone 14 | from prismatic.models.backbones.llm.prompting import ( 15 | PromptBuilder, 16 | Qwen2PromptBuilder, 17 | ) 18 | 19 | # Registry =>> Support LLaMa-2 Models (from HF Transformers) 20 | # fmt: off 21 | QWEN2_MODELS = { 22 | # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models === 23 | "qwen2.5-1.5b": { 24 | "llm_family": "qwen2.5", "llm_cls": Qwen2ForCausalLM, "hf_hub_path": "Qwen2.5-1.5B" 25 | }, 26 | "qwen2.5-1.5b-instruct": { 27 | "llm_family": "qwen2.5", "llm_cls": Qwen2ForCausalLM, "hf_hub_path": "Qwen2.5-1.5B-Instruct" 28 | }, 29 | "qwen2.5-3b": { 30 | "llm_family": "qwen2.5", "llm_cls": Qwen2ForCausalLM, "hf_hub_path": "Qwen2.5-3B" 31 | }, 32 | "qwen2.5-7b-instruct": { 33 | "llm_family": "qwen2.5", "llm_cls": Qwen2ForCausalLM, "hf_hub_path": "Qwen2.5-7B-Instruct" 34 | }, 35 | } 36 | # fmt: on 37 | 38 | 39 | class Qwen2LLMBackbone(HFCausalLLMBackbone): 40 | def __init__( 41 | self, 42 | llm_backbone_id: str, 43 | llm_max_length: int = 4096, 44 | mount_path: Optional[str] = None, 45 | inference_mode: bool = False, 46 | use_flash_attention_2: bool = True, 47 | ) -> None: 48 | super().__init__( 49 | llm_backbone_id, 50 | llm_max_length=llm_max_length, 51 | mount_path=mount_path, 52 | inference_mode=inference_mode, 53 | use_flash_attention_2=use_flash_attention_2, 54 | **QWEN2_MODELS[llm_backbone_id], 55 | ) 56 | 57 | # [Special Case] Qwen-2.5 PAD Token Handling --> for clarity, we add an extra token, no need to resize the model embedding layer 58 | self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", ""]}) 59 | self.tokenizer.bos_token = "" 60 | 61 | @property 62 | def prompt_builder_fn(self) -> Type[PromptBuilder]: 63 | return Qwen2PromptBuilder 64 | 65 | @property 66 | def transformer_layer_cls(self) -> Type[nn.Module]: 67 | return Qwen2DecoderLayer 68 | 69 | @property 70 | def half_precision_dtype(self) -> torch.dtype: 71 | return torch.bfloat16 72 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_vision import ImageTransform, VisionBackbone 2 | from .clip_vit import CLIPViTBackbone 3 | from .dinoclip_vit import DinoCLIPViTBackbone 4 | from .dinosiglip_vit import DinoSigLIPViTBackbone 5 | from .dinov2_vit import DinoV2ViTBackbone 6 | from .in1k_vit import IN1KViTBackbone 7 | from .siglip_vit import SigLIPViTBackbone 8 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/vision/clip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | clip_vit.py 3 | """ 4 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 5 | 6 | # Registry =>> Supported CLIP Vision Backbones (from TIMM) 7 | CLIP_VISION_BACKBONES = { 8 | "clip-vit-b": "vit_base_patch16_clip_224.openai", 9 | "clip-vit-l": "vit_large_patch14_clip_224.openai", 10 | "clip-vit-l-336px": "vit_large_patch14_clip_336.openai", 11 | } 12 | 13 | 14 | # [IMPORTANT] By Default, TIMM initialized OpenAI CLIP models with the standard GELU activation from PyTorch. 
15 | # HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's 16 | # a decent approximation, the resulting features are *worse*; this was a super tricky bug 17 | # to identify, but luckily there's an easy fix (`override_act_layer`) 18 | class CLIPViTBackbone(TimmViTBackbone): 19 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 20 | super().__init__( 21 | vision_backbone_id, 22 | CLIP_VISION_BACKBONES[vision_backbone_id], 23 | image_resize_strategy, 24 | default_image_size=default_image_size, 25 | override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None, 26 | ) 27 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/vision/dinoclip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinoclip_vit.py 3 | 4 | Vision backbone that returns concatenated features from both DINOv2 and CLIP. 5 | """ 6 | 7 | from dataclasses import dataclass 8 | from functools import partial 9 | from typing import Callable, Dict, Tuple 10 | 11 | import timm 12 | import torch 13 | from PIL import Image 14 | from timm.models.vision_transformer import Block, VisionTransformer 15 | from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy 16 | from torchvision.transforms import Compose, Resize 17 | 18 | from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple 19 | 20 | # Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers) 21 | DINOCLIP_VISION_BACKBONES = { 22 | "dinoclip-vit-l-336px": { 23 | "dino": "vit_large_patch14_reg4_dinov2.lvd142m", 24 | "clip": "vit_large_patch14_clip_336.openai", 25 | }, 26 | } 27 | 28 | 29 | @dataclass 30 | class DinoCLIPImageTransform: 31 | dino_image_transform: ImageTransform 32 | clip_image_transform: ImageTransform 33 | is_prismatic: bool = True 34 | 35 | def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: 36 | return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)} 37 | 38 | 39 | class DinoCLIPViTBackbone(VisionBackbone): 40 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 41 | super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) 42 | self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"] 43 | self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"] 44 | 45 | # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary 46 | self.dino_featurizer: VisionTransformer = timm.create_model( 47 | self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 48 | ) 49 | self.dino_featurizer.eval() 50 | 51 | self.clip_featurizer: VisionTransformer = timm.create_model( 52 | self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size 53 | ) 54 | self.clip_featurizer.eval() 55 | 56 | # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility 57 | # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! 
58 | # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 59 | self.dino_featurizer.forward = unpack_tuple( 60 | partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) 61 | ) 62 | self.clip_featurizer.forward = unpack_tuple( 63 | partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2}) 64 | ) 65 | 66 | # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models 67 | self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) 68 | self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 69 | 70 | self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer) 71 | self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) 72 | 73 | # Initialize *both* Transforms 74 | default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) 75 | default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False) 76 | if self.image_resize_strategy == "resize-naive": 77 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" 78 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!" 79 | assert isinstance(default_dino_transform.transforms[0], Resize) 80 | assert isinstance(default_clip_transform.transforms[0], Resize) 81 | 82 | target_size = (self.default_image_size, self.default_image_size) 83 | dino_transform = Compose( 84 | [ 85 | Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), 86 | *default_dino_transform.transforms[1:], 87 | ] 88 | ) 89 | clip_transform = Compose( 90 | [ 91 | Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation), 92 | *default_clip_transform.transforms[1:], 93 | ] 94 | ) 95 | 96 | self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform) 97 | 98 | elif self.image_resize_strategy == "resize-crop": 99 | self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform) 100 | 101 | elif self.image_resize_strategy == "letterbox": 102 | assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" 103 | assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!" 104 | assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!" 
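# NOTE =>> Letterboxing pads each image to a square *before* the standard transform; filling with the per-channel
#          normalization mean (rescaled from [0, 1] to 0-255 below) means the padded border normalizes to ~0 and
#          contributes little extra signal to the ViT patches.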
105 | 106 | # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) 107 | dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) 108 | clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]]) 109 | 110 | # Build New Transform 111 | self.image_transform = DinoCLIPImageTransform( 112 | Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), 113 | Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]), 114 | ) 115 | 116 | else: 117 | raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") 118 | 119 | def get_fsdp_wrapping_policy(self) -> Callable: 120 | """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" 121 | vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) 122 | transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) 123 | return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) 124 | 125 | def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: 126 | """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" 127 | dino_patches = self.dino_featurizer(pixel_values["dino"]) 128 | clip_patches = self.clip_featurizer(pixel_values["clip"]) 129 | 130 | return torch.cat([dino_patches, clip_patches], dim=2) 131 | 132 | @property 133 | def default_image_resolution(self) -> Tuple[int, int, int]: 134 | return self.dino_data_cfg["input_size"] 135 | 136 | @property 137 | def embed_dim(self) -> int: 138 | return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim 139 | 140 | @property 141 | def num_patches(self) -> int: 142 | assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches 143 | return self.dino_featurizer.patch_embed.num_patches 144 | 145 | @property 146 | def half_precision_dtype(self) -> torch.dtype: 147 | return torch.bfloat16 -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/vision/dinov2_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | dinov2_vit.py 3 | """ 4 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 5 | 6 | # Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers! 
7 | # => Reference: https://arxiv.org/abs/2309.16588 8 | DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"} 9 | 10 | 11 | class DinoV2ViTBackbone(TimmViTBackbone): 12 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 13 | super().__init__( 14 | vision_backbone_id, 15 | DINOv2_VISION_BACKBONES[vision_backbone_id], 16 | image_resize_strategy, 17 | default_image_size=default_image_size, 18 | ) 19 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/vision/in1k_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | in1k_vit.py 3 | 4 | Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K) 5 | """ 6 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 7 | 8 | # Registry =>> Supported Vision Backbones (from TIMM) 9 | IN1K_VISION_BACKBONES = { 10 | "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k", 11 | } 12 | 13 | 14 | class IN1KViTBackbone(TimmViTBackbone): 15 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 16 | super().__init__( 17 | vision_backbone_id, 18 | IN1K_VISION_BACKBONES[vision_backbone_id], 19 | image_resize_strategy, 20 | default_image_size=default_image_size, 21 | ) 22 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/backbones/vision/siglip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | siglip_vit.py 3 | """ 4 | from prismatic.models.backbones.vision.base_vision import TimmViTBackbone 5 | 6 | # Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch) 7 | SIGLIP_VISION_BACKBONES = { 8 | "siglip-vit-b16-224px": "vit_base_patch16_siglip_224", 9 | "siglip-vit-b16-256px": "vit_base_patch16_siglip_256", 10 | "siglip-vit-b16-384px": "vit_base_patch16_siglip_384", 11 | "siglip-vit-so400m": "vit_so400m_patch14_siglip_224", 12 | "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384", 13 | } 14 | 15 | 16 | class SigLIPViTBackbone(TimmViTBackbone): 17 | def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: 18 | super().__init__( 19 | vision_backbone_id, 20 | SIGLIP_VISION_BACKBONES[vision_backbone_id], 21 | image_resize_strategy, 22 | default_image_size=default_image_size, 23 | ) 24 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/load.py: -------------------------------------------------------------------------------- 1 | """ 2 | load.py 3 | 4 | Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical 5 | IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub). 
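Example (illustrative sketch -- assumes either a registered model ID such as `Open-Qwen2VL` from
`prismatic/models/registry.py`, or a local run directory containing `config.json` and `checkpoints/latest-checkpoint.pt`):

    vlm = load("Open-Qwen2VL")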
6 | """ 7 | import json 8 | import os 9 | from pathlib import Path 10 | from typing import List, Optional, Union 11 | 12 | from huggingface_hub import hf_hub_download 13 | 14 | from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform 15 | from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY 16 | from prismatic.models.vlms import PrismaticVLM 17 | from prismatic.overwatch import initialize_overwatch 18 | 19 | # Initialize Overwatch =>> Wraps `logging.Logger` 20 | overwatch = initialize_overwatch(__name__) 21 | 22 | 23 | # === HF Hub Repository === 24 | HF_HUB_REPO = "weizhiwang" 25 | 26 | 27 | # === Available Models === 28 | def available_model_ids() -> List[str]: 29 | return list(MODEL_REGISTRY.keys()) 30 | 31 | 32 | def available_model_ids_and_names() -> List[List[str]]: 33 | return list(GLOBAL_REGISTRY.values()) 34 | 35 | 36 | def get_model_description(model_id_or_name: str) -> str: 37 | if model_id_or_name not in GLOBAL_REGISTRY: 38 | raise ValueError(f"Couldn't find `{model_id_or_name = }; check `prismatic.available_model_names()`") 39 | 40 | # Print Description & Return 41 | print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2)) 42 | 43 | return description 44 | 45 | 46 | # === Load Pretrained Model === 47 | def load( 48 | model_id_or_path: Union[str, Path], hf_token: Optional[str] = None, cache_dir: Optional[Union[str, Path]] = None 49 | ) -> PrismaticVLM: 50 | """Loads a pretrained PrismaticVLM from either local disk or the HuggingFace Hub.""" 51 | if os.path.isdir(model_id_or_path): 52 | overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`") 53 | 54 | # Get paths for `config.json` and pretrained checkpoint 55 | config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt" 56 | assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" 57 | assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`" 58 | else: 59 | if model_id_or_path not in GLOBAL_REGISTRY: 60 | raise ValueError(f"Couldn't find `{model_id_or_path = }; check `prismatic.available_model_names()`") 61 | 62 | overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])} from HF Hub") 63 | config_json = hf_hub_download(repo_id=f"{HF_HUB_REPO}/{model_id}", filename="config.json", cache_dir=cache_dir) 64 | checkpoint_pt = hf_hub_download( 65 | repo_id=f"{HF_HUB_REPO}/{model_id}", filename="checkpoints/latest-checkpoint.pt", cache_dir=cache_dir 66 | ) 67 | 68 | # Load Model Config from `config.json` 69 | with open(config_json, "r") as f: 70 | model_cfg = json.load(f)["model"] 71 | with open(config_json, "r") as f: 72 | mount_path = json.load(f)["mount_path"] 73 | 74 | # = Load Individual Components necessary for Instantiating a VLM = 75 | # =>> Print Minimal Config 76 | overwatch.info( 77 | f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n" 78 | f" Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n" 79 | f" LLM Backbone =>> [bold]{model_cfg['llm_backbone_id']}[/]\n" 80 | f" Arch Specifier =>> [bold]{model_cfg['arch_specifier']}[/]\n" 81 | f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]" 82 | ) 83 | 84 | # Load Vision Backbone 85 | overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]") 86 | vision_backbone, image_transform = get_vision_backbone_and_transform( 87 | model_cfg["vision_backbone_id"], 88 | 
model_cfg["image_resize_strategy"], 89 | ) 90 | 91 | # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()` 92 | overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers") 93 | llm_backbone, tokenizer = get_llm_backbone_and_tokenizer( 94 | model_cfg["llm_backbone_id"], 95 | llm_max_length=model_cfg.get("llm_max_length", 4096), 96 | # hf_token=hf_token, 97 | mount_path=mount_path, 98 | inference_mode=True, 99 | ) 100 | 101 | # Load VLM using `from_pretrained` (clobbers HF syntax... eventually should reconcile) 102 | overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint; Freezing Weights 🥶") 103 | vlm = PrismaticVLM.from_pretrained( 104 | checkpoint_pt, 105 | model_cfg["model_id"], 106 | vision_backbone, 107 | llm_backbone, 108 | arch_specifier=model_cfg["arch_specifier"], 109 | ) 110 | 111 | return vlm 112 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports 5 | individual functions for clear control flow. 6 | """ 7 | from typing import Optional, Tuple 8 | 9 | from transformers import PreTrainedTokenizerBase 10 | 11 | from prismatic.models.backbones.llm import LLMBackbone, LLaMa2LLMBackbone, LLaMa3LLMBackbone, MistralLLMBackbone, Phi3LLMBackbone, Qwen2LLMBackbone 12 | from prismatic.models.backbones.vision import ( 13 | CLIPViTBackbone, 14 | DinoCLIPViTBackbone, 15 | DinoSigLIPViTBackbone, 16 | DinoV2ViTBackbone, 17 | ImageTransform, 18 | IN1KViTBackbone, 19 | SigLIPViTBackbone, 20 | VisionBackbone, 21 | ) 22 | from prismatic.models.vlms import PrismaticVLM 23 | 24 | # === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs === 25 | # fmt: off 26 | 27 | # === Vision Backbone Registry === 28 | VISION_BACKBONES = { 29 | # === 224px Backbones === 30 | "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 31 | "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 32 | "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}}, 33 | "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}}, 34 | "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 35 | 36 | # === Assorted CLIP Backbones === 37 | "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 38 | "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}}, 39 | 40 | # === Assorted SigLIP Backbones === 41 | "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, 42 | "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}}, 43 | "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 44 | "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 45 | 46 | # === Fused Backbones === 47 | "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}}, 48 | "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, 49 | } 50 | 51 | 52 | # === Language Model Registry === 53 | 
LLM_BACKBONES = { 54 | # === LLaMa-2 Pure (Non-Chat) Backbones === 55 | "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 56 | "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 57 | 58 | # === LLaMa-2 Chat Backbones === 59 | "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 60 | "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 61 | 62 | # === Vicuna-v1.5 Backbones === 63 | "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 64 | "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 65 | 66 | # === LLaMa-3 Pure (Non-Chat) Backbones === 67 | "llama3-8b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 68 | "llama3-70b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, 69 | 70 | # === LLaMa-3.2 Backbones === 71 | "llama3.1-8b-pure": {"cls": LLaMa3LLMBackbone, "kwargs": {}}, 72 | "llama3.1-8b-instruct": {"cls": LLaMa3LLMBackbone, "kwargs": {}}, 73 | 74 | # === LLaMa-3.2 Backbones === 75 | "llama3.2-3b-pure": {"cls": LLaMa3LLMBackbone, "kwargs": {}}, 76 | "llama3.2-3b-instruct": {"cls": LLaMa3LLMBackbone, "kwargs": {}}, 77 | 78 | # === Mistral-v0.1 Backbones === 79 | "mistral-7b": {"cls": MistralLLMBackbone, "kwargs": {}}, 80 | 81 | # === Phi3-v0.1 Backbones === 82 | "phi3-3b": {"cls": Phi3LLMBackbone, "kwargs": {}}, 83 | "phi3.5-3b": {"cls": Phi3LLMBackbone, "kwargs": {}}, 84 | 85 | # === Qwen2.5 Backbones === 86 | "qwen2.5-3b": {"cls": Qwen2LLMBackbone, "kwargs": {}}, 87 | "qwen2.5-1.5b": {"cls": Qwen2LLMBackbone, "kwargs": {}}, 88 | "qwen2.5-1.5b-instruct": {"cls": Qwen2LLMBackbone, "kwargs": {}}, 89 | "qwen2.5-7b-instruct": {"cls": Qwen2LLMBackbone, "kwargs": {}}, 90 | } 91 | 92 | # fmt: on 93 | 94 | 95 | def get_vision_backbone_and_transform( 96 | vision_backbone_id: str, image_resize_strategy: str 97 | ) -> Tuple[VisionBackbone, ImageTransform]: 98 | """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform.""" 99 | if vision_backbone_id in VISION_BACKBONES: 100 | vision_cfg = VISION_BACKBONES[vision_backbone_id] 101 | vision_backbone: VisionBackbone = vision_cfg["cls"]( 102 | vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"] 103 | ) 104 | image_transform = vision_backbone.get_image_transform() 105 | return vision_backbone, image_transform 106 | 107 | else: 108 | raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!") 109 | 110 | 111 | def get_llm_backbone_and_tokenizer( 112 | llm_backbone_id: str, 113 | llm_max_length: int = 4096, 114 | mount_path: Optional[str] = None, 115 | inference_mode: bool = False, 116 | ) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]: 117 | if llm_backbone_id in LLM_BACKBONES: 118 | llm_cfg = LLM_BACKBONES[llm_backbone_id] 119 | llm_backbone: LLMBackbone = llm_cfg["cls"]( 120 | llm_backbone_id, 121 | llm_max_length=llm_max_length, 122 | mount_path=mount_path, 123 | inference_mode=inference_mode, 124 | **llm_cfg["kwargs"], 125 | ) 126 | tokenizer = llm_backbone.get_tokenizer() 127 | return llm_backbone, tokenizer 128 | 129 | else: 130 | raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!") 131 | 132 | 133 | def get_vlm( 134 | model_id: str, 135 | arch_specifier: str, 136 | vision_backbone: VisionBackbone, 137 | llm_backbone: LLMBackbone, 138 | enable_mixed_precision_training: bool = True, 139 | ) -> PrismaticVLM: 140 | """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM).""" 141 | return PrismaticVLM( 142 | model_id, 143 | vision_backbone, 144 | 
llm_backbone, 145 | enable_mixed_precision_training=enable_mixed_precision_training, 146 | arch_specifier=arch_specifier, 147 | ) 148 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | registry.py 3 | 4 | Exhaustive list of pretrained VLMs (with full descriptions / links to corresponding names and sections of paper). 5 | """ 6 | 7 | 8 | # === Pretrained Model Registry === 9 | # fmt: off 10 | MODEL_REGISTRY = { 11 | # === LLaVa v1.5 Reproductions === 12 | "Open-Qwen2VL": { 13 | "model_id": "Open-Qwen2VL", 14 | "names": ["Open-Qwen2VL-2B-Instruct"], 15 | "description": { 16 | "name": "Open-Qwen2VL-2B-Instruct", 17 | "optimization_procedure": "2-stage", 18 | "visual_representation": "SigLIP ViT-SO/14 @ 384p", 19 | "image_processing": "Naive", 20 | "language_model": "Qwen2.5 1.5B Instruct", 21 | "train_epochs": 1, 22 | } 23 | }, 24 | "Open-Qwen2VL-base": { 25 | "model_id": "Open-Qwen2VL-base", 26 | "names": ["Open-Qwen2VL-2B-base"], 27 | "description": { 28 | "name": "Open-Qwen2VL-2B-Instruct", 29 | "optimization_procedure": "2-stage", 30 | "visual_representation": "SigLIP ViT-SO/14 @ 384p", 31 | "image_processing": "Naive", 32 | "language_model": "Qwen2.5 1.5B Instruct", 33 | "train_epochs": 1, 34 | } 35 | } 36 | } 37 | 38 | # Build Global Registry (Model ID, Name) -> Metadata 39 | GLOBAL_REGISTRY = {name: v for k, v in MODEL_REGISTRY.items() for name in [k] + v["names"]} 40 | 41 | # fmt: on 42 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/vlms/__init__.py: -------------------------------------------------------------------------------- 1 | from .prismatic import PrismaticVLM 2 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/models/vlms/base_vlm.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_vlm.py 3 | 4 | Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions, 5 | and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate 6 | from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS, 7 | PALI, Fuyu) in the future. 8 | 9 | We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance 10 | (e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms), 11 | prefer Protocol definitions instead. 
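Concrete implementations (e.g., `PrismaticVLM` in `prismatic.py`) subclass the `VLM` base class below and supply the
abstract hooks `from_pretrained`, `get_prompt_builder`, `freeze_backbones`, `load_from_checkpoint`,
`get_fsdp_wrapping_policy`, and `forward`.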
12 | """ 13 | from __future__ import annotations 14 | 15 | from abc import ABC, abstractmethod 16 | from pathlib import Path 17 | from typing import Callable, List, Optional 18 | 19 | import torch 20 | import torch.nn as nn 21 | from transformers import GenerationMixin, PretrainedConfig 22 | from transformers.modeling_outputs import CausalLMOutputWithPast 23 | 24 | from prismatic.models.backbones.llm import LLMBackbone 25 | from prismatic.models.backbones.llm.prompting import PromptBuilder 26 | from prismatic.models.backbones.vision import VisionBackbone 27 | 28 | 29 | # === Abstract Base Class for arbitrary Vision-Language Models === 30 | class VLM(nn.Module, GenerationMixin, ABC): 31 | def __init__( 32 | self, 33 | model_family: str, 34 | model_id: str, 35 | vision_backbone: VisionBackbone, 36 | llm_backbone: LLMBackbone, 37 | enable_mixed_precision_training: bool = True, 38 | ) -> None: 39 | super().__init__() 40 | self.model_family, self.model_id = model_family, model_id 41 | self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone 42 | self.enable_mixed_precision_training = enable_mixed_precision_training 43 | 44 | # Instance Attributes for a generic VLM 45 | self.all_module_keys, self.trainable_module_keys = None, None 46 | 47 | # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* === 48 | self.generation_config = self.llm_backbone.llm.generation_config 49 | self.main_input_name = "input_ids" 50 | 51 | @property 52 | def device(self) -> torch.device: 53 | """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!""" 54 | return next(self.parameters()).device 55 | 56 | @classmethod 57 | @abstractmethod 58 | def from_pretrained( 59 | cls, 60 | pretrained_checkpoint: Path, 61 | model_family: str, 62 | model_id: str, 63 | vision_backbone: VisionBackbone, 64 | llm_backbone: LLMBackbone, 65 | **kwargs: str, 66 | ) -> VLM: ... 67 | 68 | @abstractmethod 69 | def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ... 70 | 71 | @abstractmethod 72 | def freeze_backbones(self, stage: str) -> None: ... 73 | 74 | @abstractmethod 75 | def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ... 76 | 77 | @abstractmethod 78 | def get_fsdp_wrapping_policy(self) -> Callable: ... 79 | 80 | @abstractmethod 81 | def forward( 82 | self, 83 | input_ids: Optional[torch.LongTensor] = None, 84 | attention_mask: Optional[torch.Tensor] = None, 85 | pixel_values: Optional[torch.FloatTensor] = None, 86 | labels: Optional[torch.LongTensor] = None, 87 | inputs_embeds: Optional[torch.FloatTensor] = None, 88 | past_key_values: Optional[List[torch.FloatTensor]] = None, 89 | use_cache: Optional[bool] = None, 90 | output_attentions: Optional[bool] = None, 91 | output_hidden_states: Optional[bool] = None, 92 | return_dict: Optional[bool] = None, 93 | multimodal_indices: Optional[torch.LongTensor] = None, 94 | ) -> CausalLMOutputWithPast: ... 
95 | 96 | # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) === 97 | @staticmethod 98 | def can_generate() -> bool: 99 | return True 100 | 101 | @property 102 | def config(self) -> PretrainedConfig: 103 | return self.llm_backbone.llm.config 104 | 105 | # => Beam Search Utility 106 | def _reorder_cache(self, past_key_values, beam_idx): 107 | return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx) 108 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/overwatch/__init__.py: -------------------------------------------------------------------------------- 1 | from .overwatch import initialize_overwatch 2 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/overwatch/overwatch.py: -------------------------------------------------------------------------------- 1 | """ 2 | overwatch.py 3 | 4 | Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler. 5 | """ 6 | 7 | import logging 8 | import logging.config 9 | import os 10 | from contextlib import nullcontext 11 | from logging import LoggerAdapter 12 | from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union 13 | 14 | # Overwatch Default Format String 15 | RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]" 16 | 17 | # Set Logging Configuration 18 | LOG_CONFIG = { 19 | "version": 1, 20 | "disable_existing_loggers": True, 21 | "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}}, 22 | "handlers": { 23 | "console": { 24 | "class": "rich.logging.RichHandler", 25 | "formatter": "simple-console", 26 | "markup": True, 27 | "rich_tracebacks": True, 28 | "show_level": True, 29 | "show_path": True, 30 | "show_time": True, 31 | } 32 | }, 33 | "root": {"level": "INFO", "handlers": ["console"]}, 34 | } 35 | logging.config.dictConfig(LOG_CONFIG) 36 | 37 | 38 | # === Custom Contextual Logging Logic === 39 | class ContextAdapter(LoggerAdapter): 40 | CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}} 41 | 42 | def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]: 43 | ctx_level = kwargs.pop("ctx_level", 0) 44 | return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs 45 | 46 | 47 | class DistributedOverwatch: 48 | def __init__(self, name: str) -> None: 49 | """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`.""" 50 | from accelerate import PartialState 51 | 52 | # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun` 53 | # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all! 54 | self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState() 55 | 56 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 57 | self.debug = self.logger.debug 58 | self.info = self.logger.info 59 | self.warning = self.logger.warning 60 | self.error = self.logger.error 61 | self.critical = self.logger.critical 62 | 63 | # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others! 
64 | self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR) 65 | 66 | @property 67 | def rank_zero_only(self) -> Callable[..., Any]: 68 | return self.distributed_state.on_main_process 69 | 70 | @property 71 | def local_zero_only(self) -> Callable[..., Any]: 72 | return self.distributed_state.on_local_main_process 73 | 74 | @property 75 | def rank_zero_first(self) -> Callable[..., Any]: 76 | return self.distributed_state.main_process_first 77 | 78 | @property 79 | def local_zero_first(self) -> Callable[..., Any]: 80 | return self.distributed_state.local_main_process_first 81 | 82 | def is_rank_zero(self) -> bool: 83 | return self.distributed_state.is_main_process 84 | 85 | def rank(self) -> int: 86 | return self.distributed_state.process_index 87 | 88 | def local_rank(self) -> int: 89 | return self.distributed_state.local_process_index 90 | 91 | def world_size(self) -> int: 92 | return self.distributed_state.num_processes 93 | 94 | 95 | class PureOverwatch: 96 | def __init__(self, name: str) -> None: 97 | """Initializer for an Overwatch object that just wraps logging.""" 98 | self.logger = ContextAdapter(logging.getLogger(name), extra={}) 99 | 100 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 101 | self.debug = self.logger.debug 102 | self.info = self.logger.info 103 | self.warning = self.logger.warning 104 | self.error = self.logger.error 105 | self.critical = self.logger.critical 106 | 107 | # Logging Defaults =>> INFO 108 | self.logger.setLevel(logging.INFO) 109 | 110 | @staticmethod 111 | def get_identity_ctx() -> Callable[..., Any]: 112 | def identity(fn: Callable[..., Any]) -> Callable[..., Any]: 113 | return fn 114 | 115 | return identity 116 | 117 | @property 118 | def rank_zero_only(self) -> Callable[..., Any]: 119 | return self.get_identity_ctx() 120 | 121 | @property 122 | def local_zero_only(self) -> Callable[..., Any]: 123 | return self.get_identity_ctx() 124 | 125 | @property 126 | def rank_zero_first(self) -> Callable[..., Any]: 127 | return nullcontext 128 | 129 | @property 130 | def local_zero_first(self) -> Callable[..., Any]: 131 | return nullcontext 132 | 133 | @staticmethod 134 | def is_rank_zero() -> bool: 135 | return True 136 | 137 | @staticmethod 138 | def rank() -> int: 139 | return 0 140 | 141 | @staticmethod 142 | def world_size() -> int: 143 | return 1 144 | 145 | 146 | def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]: 147 | return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name) -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .download import convert_to_jpg, download_extract 2 | from .materialize import get_dataset_and_collator, get_wds_datainfo 3 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/preprocessing/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import AlignDataset, FinetuneDataset, PreTrainDataset, FinetuneLargeDataset 2 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/py.typed: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/prismatic-vlms/prismatic/py.typed -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .materialize import get_train_strategy 2 | from .metrics import Metrics 3 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/training/materialize.py: -------------------------------------------------------------------------------- 1 | """ 2 | materialize.py 3 | 4 | Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones, 5 | and strategy configurations. 6 | """ 7 | from typing import Callable, Optional 8 | 9 | import torch 10 | 11 | from prismatic.models.vlms import PrismaticVLM 12 | from prismatic.training.strategies import FSDPStrategy, TrainingStrategy 13 | 14 | # Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented! 15 | TRAIN_STRATEGIES = { 16 | "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}}, 17 | "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}}, 18 | } 19 | 20 | 21 | def get_train_strategy( 22 | train_strategy: str, 23 | vlm: PrismaticVLM, 24 | device_id: int, 25 | epochs: int, 26 | max_steps: Optional[int], 27 | global_batch_size: int, 28 | per_device_batch_size: int, 29 | learning_rate: float, 30 | weight_decay: float, 31 | max_grad_norm: float, 32 | lr_scheduler_type: str, 33 | warmup_ratio: float, 34 | enable_gradient_checkpointing: bool = True, 35 | enable_mixed_precision_training: bool = True, 36 | reduce_in_full_precision: bool = False, 37 | mixed_precision_dtype: torch.dtype = torch.bfloat16, 38 | worker_init_fn: Optional[Callable[[int], None]] = None, 39 | ) -> TrainingStrategy: 40 | if train_strategy in TRAIN_STRATEGIES: 41 | strategy_cfg = TRAIN_STRATEGIES[train_strategy] 42 | strategy = strategy_cfg["cls"]( 43 | vlm=vlm, 44 | device_id=device_id, 45 | epochs=epochs, 46 | max_steps=max_steps, 47 | global_batch_size=global_batch_size, 48 | per_device_batch_size=per_device_batch_size, 49 | learning_rate=learning_rate, 50 | weight_decay=weight_decay, 51 | max_grad_norm=max_grad_norm, 52 | lr_scheduler_type=lr_scheduler_type, 53 | warmup_ratio=warmup_ratio, 54 | enable_gradient_checkpointing=enable_gradient_checkpointing, 55 | enable_mixed_precision_training=enable_mixed_precision_training, 56 | reduce_in_full_precision=reduce_in_full_precision, 57 | mixed_precision_dtype=mixed_precision_dtype, 58 | worker_init_fn=worker_init_fn, 59 | **strategy_cfg["kwargs"], 60 | ) 61 | return strategy 62 | else: 63 | raise ValueError(f"Train Strategy `{train_strategy}` is not supported!") 64 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/training/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_strategy import TrainingStrategy 2 | from .ddp import DDPStrategy 3 | from .fsdp import FSDPStrategy 4 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/training/strategies/ddp.py: -------------------------------------------------------------------------------- 1 | """ 2 | ddp.py 3 | 4 | Core class definition for a 
strategy implementing Torch native Distributed Data Parallel Training; note that on most 5 | GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP. 6 | """ 7 | 8 | import shutil 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import torch 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.optim import AdamW 15 | from transformers.optimization import get_cosine_schedule_with_warmup 16 | 17 | from prismatic.overwatch import initialize_overwatch 18 | from prismatic.training.strategies.base_strategy import TrainingStrategy 19 | 20 | # Initialize Overwatch =>> Wraps `logging.Logger` 21 | overwatch = initialize_overwatch(__name__) 22 | 23 | 24 | class DDPStrategy(TrainingStrategy): 25 | @overwatch.rank_zero_only 26 | def save_checkpoint( 27 | self, 28 | run_dir: Path, 29 | global_step: int, 30 | epoch: int, 31 | train_loss: Optional[float] = None, 32 | only_trainable: bool = True, 33 | ) -> None: 34 | """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" 35 | assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!" 36 | 37 | # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`) 38 | model_state_dicts = { 39 | mkey: getattr(self.vlm.module, mkey).state_dict() 40 | for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) 41 | } 42 | optimizer_state_dict = self.optimizer.state_dict() 43 | 44 | # Set Checkpoint Path =>> Embed *minimal* training statistics! 45 | checkpoint_dir = run_dir / "checkpoints" 46 | if train_loss is None: 47 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" 48 | else: 49 | checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" 50 | 51 | # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` 52 | torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path) 53 | shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") 54 | 55 | def run_setup(self, run_dir: Path, n_train_examples: int) -> None: 56 | # Gradient Checkpointing Setup 57 | if self.enable_gradient_checkpointing: 58 | # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up 59 | # by the LLM; because we also make the explicit assumption that each LLM is derived from a HF 60 | # pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable` 61 | # on `self.llm_backbone`. 62 | # 63 | # What does it actually do? 
--> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic 64 | # => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706 65 | # 66 | # Additional Reference (to better understand gradient checkpointing in PyTorch writ large) 67 | # => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb 68 | overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1) 69 | self.vlm.llm_backbone.gradient_checkpointing_enable() 70 | 71 | # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate) 72 | overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1) 73 | self.vlm.to(self.device_id) 74 | 75 | # Wrap with Distributed Data Parallel 76 | # => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that 77 | # is the same size/dtype as the model parameters; this will *double* GPU memory! 78 | # - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel 79 | overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1) 80 | self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True) 81 | 82 | # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` 83 | # => Optimizer should only operate on parameters that are *unfrozen* / trainable! 84 | trainable_params = [param for param in self.vlm.parameters() if param.requires_grad] 85 | if self.lr_scheduler_type == "linear-warmup+cosine-decay": 86 | if self.max_steps is None: 87 | num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size 88 | else: 89 | num_training_steps = self.max_steps 90 | 91 | # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) 92 | num_warmup_steps = int(num_training_steps * self.warmup_ratio) 93 | 94 | assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" 
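            # NOTE :: right after the scheduler is created below, every param group's LR is reset to 0.0 so that the
            # very first optimizer step uses LR = 0 and the warmup schedule then ramps the LR up toward
            # `self.learning_rate`.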
95 | self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) 96 | self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) 97 | for param_group in self.optimizer.param_groups: 98 | param_group["lr"] = 0.0 99 | 100 | else: 101 | raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") 102 | 103 | # Finalize Setup =>> Log 104 | overwatch.info( 105 | "DDP Strategy =>> Finalized Training Setup:\n" 106 | f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" 107 | f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" 108 | f" |-> Distributed World Size = {overwatch.world_size()}\n" 109 | f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" 110 | f" |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" 111 | f" |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n" 112 | f" |-> Default AdamW LR = {self.learning_rate}\n" 113 | f" |-> AdamW Weight Decay = {self.weight_decay}\n" 114 | f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" 115 | f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" 116 | f" |-> Dataset Size = {n_train_examples} Examples\n" 117 | f" |-> Max Steps = {num_training_steps}\n" 118 | ) 119 | 120 | def clip_grad_norm(self) -> None: 121 | torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm) 122 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_utils import check_bloat16_supported, set_global_seed 2 | -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/util/nn_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | nn_utils.py 3 | 4 | Utility functions and PyTorch submodule definitions. 
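Defines the projection modules that map vision-backbone patch features (`vision_dim`) into the LLM embedding space
(`llm_dim`): `LinearProjector`, `MLPProjector` ("gelu-mlp"), `FusedMLPProjector` ("fused-gelu-mlp"), and
`AvgPoolProjector` (adaptive average pooling of patch features down to `query_num` tokens, followed by an MLP).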
5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from einops import rearrange 11 | 12 | 13 | # === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] === 14 | class LinearProjector(nn.Module): 15 | def __init__(self, vision_dim: int, llm_dim: int) -> None: 16 | super().__init__() 17 | self.projector = nn.Linear(vision_dim, llm_dim, bias=True) 18 | 19 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 20 | return self.projector(img_patches) 21 | 22 | 23 | class MLPProjector(nn.Module): 24 | def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None: 25 | super().__init__() 26 | if mlp_type == "gelu-mlp": 27 | self.projector = nn.Sequential( 28 | nn.Linear(vision_dim, llm_dim, bias=True), 29 | nn.GELU(), 30 | nn.Linear(llm_dim, llm_dim, bias=True), 31 | ) 32 | else: 33 | raise ValueError(f"Projector with `{mlp_type = }` is not supported!") 34 | 35 | def forward(self, img_patches: torch.Tensor) -> torch.Tensor: 36 | return self.projector(img_patches) 37 | 38 | 39 | class FusedMLPProjector(nn.Module): 40 | def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None: 41 | super().__init__() 42 | self.initial_projection_dim = fused_vision_dim * 4 43 | if mlp_type == "fused-gelu-mlp": 44 | self.projector = nn.Sequential( 45 | nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True), 46 | nn.GELU(), 47 | nn.Linear(self.initial_projection_dim, llm_dim, bias=True), 48 | nn.GELU(), 49 | nn.Linear(llm_dim, llm_dim, bias=True), 50 | ) 51 | else: 52 | raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!") 53 | 54 | def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: 55 | return self.projector(fused_img_patches) 56 | 57 | class AvgPoolProjector(nn.Module): 58 | def __init__( 59 | self, 60 | layer_num: int = 2, 61 | query_num: int = 144, 62 | mm_hidden_size: int = 1024, 63 | llm_hidden_size: int = 4096, 64 | ): 65 | super().__init__() 66 | self.layer_num = layer_num 67 | self.query_num = query_num 68 | self.mm_hidden_size = mm_hidden_size 69 | self.llm_hidden_size = llm_hidden_size 70 | self.build_net() 71 | 72 | def build_net(self): 73 | hw = int(self.query_num ** 0.5) 74 | sampler = nn.AdaptiveAvgPool2d((hw, hw)) 75 | self.sampler = sampler 76 | modules = [nn.Linear(self.mm_hidden_size, self.llm_hidden_size)] 77 | for _ in range(1, self.layer_num): 78 | modules.append(nn.GELU()) 79 | modules.append(nn.Linear(self.llm_hidden_size, self.llm_hidden_size)) 80 | self.mlp_projector = nn.Sequential(*modules) 81 | print(f"patch size {hw} average pooling layer initialized") 82 | 83 | def forward(self, visual_feat: torch.Tensor) -> torch.Tensor: 84 | batch_size, seq_len, h_dim = visual_feat.shape 85 | hw = int(seq_len ** 0.5) 86 | shaped_visual_feat = rearrange(visual_feat, "b (h w) d -> b d h w", h=hw, w=hw) # torch.Size([64, 1024, 24, 24]) 87 | pooled_visual_feat = self.sampler(shaped_visual_feat) # torch.Size([64, 1024, 12, 12]) 88 | reshaped_visual_feat = rearrange(pooled_visual_feat, "b d h w -> b (h w) d") # [64, 144, 1024] 89 | output_feat = self.mlp_projector(reshaped_visual_feat) # [64, 144, 4096]) 90 | return output_feat -------------------------------------------------------------------------------- /prismatic-vlms/prismatic/util/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | torch_utils.py 3 | 4 | General utilities for randomness, mixed precision training, 
and miscellaneous checks in PyTorch. 5 | 6 | Random `set_global_seed` functionality is taken directly from PyTorch-Lighting: 7 | > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py 8 | 9 | This is pretty important to get right if we're every randomly generating our masks (or prefix dropout) inside our 10 | Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime 11 | we inject randomness from non-PyTorch sources (e.g., numpy, random)! 12 | > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/ 13 | 14 | Terminology 15 | -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous! 16 | -> Rank :: Integer index of current process in the total world size 17 | -> Local Rank :: Local index on given node in [0, Devices per Node] 18 | """ 19 | import os 20 | import random 21 | from typing import Callable, Optional 22 | 23 | import numpy as np 24 | import torch 25 | 26 | # === Randomness === 27 | 28 | 29 | def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]: 30 | """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`""" 31 | assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!" 32 | 33 | # Set Seed as an Environment Variable 34 | os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed) 35 | random.seed(seed) 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | 39 | return worker_init_function if get_worker_init_fn else None 40 | 41 | 42 | def worker_init_function(worker_id: int) -> None: 43 | """ 44 | Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo: 45 | > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 46 | 47 | Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that 48 | you can run iterative splitting on to get new (predictable) randomness. 49 | 50 | :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question. 51 | """ 52 | # Get current `rank` (if running distributed) and `process_seed` 53 | global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed() 54 | 55 | # Back out the "base" (original) seed - the per-worker seed is set in PyTorch: 56 | # > https://pytorch.org/docs/stable/data.html#data-loading-randomness 57 | base_seed = process_seed - worker_id 58 | 59 | # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library... 60 | seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank]) 61 | 62 | # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array! 
63 | np.random.seed(seed_seq.generate_state(4)) 64 | 65 | # Spawn distinct child sequences for PyTorch (reseed) and stdlib random 66 | torch_seed_seq, random_seed_seq = seed_seq.spawn(2) 67 | 68 | # Torch Manual seed takes 64 bits (so just specify a dtype of uint64 69 | torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0]) 70 | 71 | # Use 128 Bits for `random`, but express as integer instead of as an array 72 | random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum() 73 | random.seed(random_seed) 74 | 75 | 76 | # === BFloat16 Support === 77 | 78 | 79 | def check_bloat16_supported() -> bool: 80 | try: 81 | import packaging.version 82 | import torch.cuda.nccl as nccl 83 | import torch.distributed as dist 84 | 85 | return ( 86 | (torch.version.cuda is not None) 87 | and torch.cuda.is_bf16_supported() 88 | and (packaging.version.parse(torch.version.cuda).release >= (11, 0)) 89 | and dist.is_nccl_available() 90 | and (nccl.version() >= (2, 10)) 91 | ) 92 | 93 | except Exception: 94 | return False 95 | -------------------------------------------------------------------------------- /prismatic-vlms/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "prismatic" 7 | authors = [ 8 | {name = "Siddharth Karamcheti", email="skaramcheti@cs.stanford.edu"}, 9 | {name = "Suraj Nair", email="suraj.nair@tri.global"}, 10 | {name = "Ashwin Balakrishna", email="ashwin.balakrishna@tri.global"}, 11 | ] 12 | description = "Prismatic VLMs: Investigating the Design Space of Visually-Conditioned Language Models" 13 | version = "0.0.2" 14 | readme = "README.md" 15 | requires-python = ">=3.10" 16 | keywords = ["vision-language models", "multimodal pretraining", "machine learning"] 17 | license = {file = "LICENSE"} 18 | classifiers = [ 19 | "Development Status :: 3 - Alpha", 20 | "Intended Audience :: Developers", 21 | "Intended Audience :: Education", 22 | "Intended Audience :: Science/Research", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | ] 32 | dependencies = [ 33 | "accelerate>=0.25.0", 34 | "draccus>=0.7.3", 35 | "einops", 36 | # "flash_attn>=2.3.3", # Here for documentation -- install *AFTER* editable install (follow README) 37 | "huggingface_hub", 38 | "jsonlines", 39 | "rich", 40 | "timm==0.9.10", 41 | "transformers", 42 | "wandb", 43 | "webdataset", 44 | "sentencepiece", 45 | "braceexpand", 46 | ] 47 | 48 | [project.optional-dependencies] 49 | dev = [ 50 | "black>=24.2.0", 51 | "gpustat", 52 | "ipython", 53 | "pre-commit", 54 | "ruff>=0.2.2", 55 | ] 56 | 57 | [project.urls] 58 | homepage = "https://github.com/TRI-ML/prismatic-vlms" 59 | repository = "https://github.com/TRI-ML/prismatic-vlms" 60 | documentation = "https://github.com/TRI-ML/prismatic-vlms" 61 | 62 | [tool.setuptools.packages.find] 63 | where = ["."] 64 | exclude = ["cache"] 65 | 66 | [tool.setuptools.package-data] 67 | "prismatic" = ["py.typed"] 68 | 69 | [tool.black] 70 | line-length = 121 71 | target-version = ["py38", "py39", "py310"] 72 | preview = true 73 | 74 | 
[tool.ruff] 75 | line-length = 121 76 | target-version = "py38" 77 | select = ["A", "B", "E", "F", "I", "RUF", "W"] 78 | ignore = ["F722"] 79 | 80 | [tool.ruff.lint.per-file-ignores] 81 | "__init__.py" = ["E402", "F401"] 82 | -------------------------------------------------------------------------------- /prismatic-vlms/scripts/additional-datasets/lrv_instruct.py: -------------------------------------------------------------------------------- 1 | """ 2 | scripts/additional-datasets/lrv_instruct.py 3 | 4 | Standalone script for pre-processing the LRV-Instruct data (including the chart/diagram reasoning split). This isn't 5 | full conversational chat data, but rather each example has an input prompt and output response; we'll use this structure 6 | to format the data equivalently to the LLaVa-v1.5 dataset. 7 | 8 | In general, LRV Instruct provides *both positive and negative* examples -- where a negative example is a question or 9 | instruction that is *not answerable* or *irrelevant*; the goal of this dataset is to reduce hallucinations in VLMs. 10 | 11 | This script downloads the raw instruct data (three different JSON files), as well as the image files; the non-chart 12 | images come from Visual Genome, but are hosted separately by the LRV Instruct authors and use different image IDs, so 13 | we're downloading this data (again) for simplicity. The chart images come from the LRV Instruct authors, and are sourced 14 | from statista.com. All file URLS are here: https://github.com/FuxiaoLiu/LRV-Instruction/blob/main/download.txt#L20 15 | 16 | Note that we are using the *coordinate-free* data (due to noted inaccuracies in the original coordinates). 17 | 18 | Make sure to download the images first to `data/download/llava-v1.5-instruct/lrv` 19 | => cd data/download/llava-v1.5-instruct/lrv 20 | => [Visual Genome] gdown https://drive.google.com/uc?id=1k9MNV-ImEV9BYEOeLEIb4uGEUZjd3QbM 21 | => `tar -xvf image.tar.gz; mv image lrv-vg; rm image.tar.gz` 22 | => [Chart Data] gdown https://drive.google.com/uc?id=1Dey-undzW2Nl21CYLFSkP_Y4RrfRJkYd 23 | => `unzip chart_image.zip; rm -rf __MACOSX; mv chart_image lrv-chart; rm chart_image.zip` 24 | 25 | Download the raw JSON files to the same directory - `data/download/llava-v1.5-instruct/lrv` 26 | => [LRV Instruct Pt. 1] gdown https://drive.google.com/uc?id=1pWkxE2kqpys1VdwBi99ZXN6-XY5SqhwU 27 | => `filter_cap1.json` 28 | => [LRV Instruct Pt. 
II] gdown https://drive.google.com/uc?id=1NTxkuRPlvDn7aWaJpK_yb0p5r0cxPLNZ 29 | => `filter_cap_more1.json` 30 | => [Chart Instruct] gdown https://drive.google.com/uc?id=13j2U-ectsYGR92r6J5hPdhT8T5ezItHF 31 | => `chart_release_update.json` 32 | 33 | References: "Mitigating Hallucination in Large Multi-Modal Models via Robust Instruction Tuning" 34 | => Paper: https://arxiv.org/abs/2306.14565 35 | => Github / Data: https://github.com/FuxiaoLiu/LRV-Instruction 36 | """ 37 | import json 38 | import random 39 | from pathlib import Path 40 | 41 | from tqdm import tqdm 42 | 43 | # === Constants === 44 | BASE_DIR = Path("data/download/llava-v1.5-instruct") 45 | LRV_DIR = BASE_DIR / "lrv" 46 | 47 | VG_JSON_FILES, VG_IMG_DIR = [LRV_DIR / "filter_cap1.json", LRV_DIR / "filter_cap_more1.json"], LRV_DIR / "lrv-vg" 48 | CHART_JSON_FILE, CHART_IMG_DIR = LRV_DIR / "chart_release_update.json", LRV_DIR / "lrv-chart" 49 | 50 | # JSON Files for "merged" variants fo the dataset (with `llava_v1_5_mix665k.json` and `llava_v1_5_lvis4v_mix888k.json` 51 | BASE_JSON_FILE = BASE_DIR / "llava_v1_5_mix665k.json" 52 | BASE_LVIS_JSON_FILE = BASE_DIR / "llava_v1_5_lvis4v_mix888k.json" 53 | 54 | MERGED_BASE_LRV_JSON_FILE = BASE_DIR / "llava_v1_5_lrv_mix1008k.json" 55 | MERGED_BASE_LVIS_LRV_JSON_FILE = BASE_DIR / "llava_v1_5_lvis4v_lrv_mix1231k.json" 56 | 57 | 58 | def build_lrv_instruct() -> None: 59 | print("[*] Downloading and Formatting `LRV-Instruct` Dataset!") 60 | 61 | # Set Random Seed 62 | random.seed(7) 63 | 64 | # Open VG JSON Files 65 | vg_examples = [] 66 | for fn in VG_JSON_FILES: 67 | with open(fn, "r") as f: 68 | vg_examples.extend(json.load(f)) 69 | 70 | # Iterate through VG Examples & Verify Image Existence 71 | for example in tqdm(vg_examples, desc="[*] Verifying all VG Images in LRV Instruct"): 72 | image_id = example["image_id"] 73 | assert (VG_IMG_DIR / f"{image_id}.jpg").exists(), f"Missing Image `{image_id}.jpg`" 74 | 75 | # Open Chart JSON File 76 | with open(CHART_JSON_FILE, "r") as f: 77 | chart_examples = json.load(f) 78 | 79 | # Iterate through Chart Examples & Verify Image Existence 80 | for example in tqdm(chart_examples, desc="[*] Verifying all Chart Images in LRV Instruct"): 81 | image_path = example["image_id"] 82 | assert (CHART_IMG_DIR / image_path).exists(), f"Missing Image `{image_path}`" 83 | 84 | # Reformat VG Examples as LLaVa "Chat" Style => List[Entry] where each Entry is a Dictionary: 85 | # => "id": str 86 | # => "image": str -- Relative path from `BASE_DIR` 87 | # => "conversations: List[Turn] where Turn is a Dictionary: 88 | # => {"from": "human", "value": "\n{VG_EXAMPLE['question']}"} 89 | # => {"from": "gpt", "value": "{VG_EXAMPLE['answer']}"} 90 | vg_chat_json = [] 91 | for vg_example in tqdm(vg_examples, desc="[*] Converting all VG Examples to LLaVa Format"): 92 | vg_chat_json.append( 93 | { 94 | "id": vg_example["image_id"], 95 | "image": f"lrv/lrv-vg/{vg_example['image_id']}.jpg", 96 | "conversations": [ 97 | {"from": "human", "value": f"\n{vg_example['question'].strip()}"}, 98 | {"from": "gpt", "value": vg_example["answer"].strip()}, 99 | ], 100 | } 101 | ) 102 | 103 | # Reformat Chart Examples as LLaVa "Chat" Style 104 | chart_chat_json = [] 105 | for chart_example in tqdm(chart_examples, desc="[*] Converting all Chart Examples to LLaVa Format"): 106 | chart_chat_json.append( 107 | { 108 | "id": Path(chart_example["image_id"]).stem, 109 | "image": f"lrv/lrv-chart/{chart_example['image_id']}", 110 | "conversations": [ 111 | {"from": "human", "value": 
f"\n{chart_example['question'].strip()}"}, 112 | {"from": "gpt", "value": chart_example["answer"].strip()}, 113 | ], 114 | } 115 | ) 116 | 117 | # Merge and Create Full LRV Chat Data =>> Total of 342,799 Examples 118 | lrv_data = vg_chat_json + chart_chat_json 119 | 120 | # Create Stacked Datasets =>> Shuffle for Good Measure! 121 | print("[*] Loading LLaVa v1.5 Data!") 122 | with open(BASE_JSON_FILE, "r") as f: 123 | llava_v15_data = json.load(f) 124 | 125 | # Combine & Shuffle & Write 126 | llava_lrv_data = llava_v15_data + lrv_data 127 | 128 | random.shuffle(llava_lrv_data) 129 | random.shuffle(llava_lrv_data) 130 | random.shuffle(llava_lrv_data) 131 | 132 | with open(MERGED_BASE_LRV_JSON_FILE, "w") as f: 133 | json.dump(llava_lrv_data, f) 134 | 135 | print("[*] Loading LLaVa v1.5 + LVIS-4V Instruct Data!") 136 | with open(BASE_LVIS_JSON_FILE, "r") as f: 137 | llava_v15_lvis_data = json.load(f) 138 | 139 | # Combine & Shuffle & Write 140 | full_data = llava_v15_lvis_data + lrv_data 141 | 142 | random.shuffle(full_data) 143 | random.shuffle(full_data) 144 | random.shuffle(full_data) 145 | 146 | with open(MERGED_BASE_LVIS_LRV_JSON_FILE, "w") as f: 147 | json.dump(full_data, f) 148 | 149 | 150 | if __name__ == "__main__": 151 | build_lrv_instruct() 152 | -------------------------------------------------------------------------------- /prismatic-vlms/scripts/additional-datasets/lvis_instruct_4v.py: -------------------------------------------------------------------------------- 1 | """ 2 | scripts/additional-datasets/lvis_instruct4v.py 3 | 4 | Standalone script for pre-processing the LVIS-Instruct4V (language/chat) data (`lvis_instruct4v_220k.json`). This 5 | dataset is curated from LVIS images (subset of COCO yet again), but chat data is synthesized from GPT4-Vision. 6 | 7 | This script downloads the raw data, merges with the LLaVa v15 data, and performs any other data normalization, saving 8 | the resulting `.json` file(s) to the `data/download/llava-v1.5-instruct/` directory. 
9 | 10 | Make sure to download the COCO Val 2017 (LVIS) data to `data/download/llava-v1.5-instruct/coco`: 11 | => cd data/download/llava-v1.5-instruct/coco 12 | => wget http://images.cocodataset.org/zips/val2017.zip 13 | => unzip val2017.zip; rm val2017.zip 14 | 15 | References: "To See is to Believe: Prompting GPT-4V for Better Visual Instruction Tuning" 16 | => Paper: https://arxiv.org/abs/2311.07574 17 | => Github / Data: https://github.com/X2FD/LVIS-INSTRUCT4V || https://huggingface.co/datasets/X2FD/LVIS-Instruct4V 18 | """ 19 | import json 20 | import os 21 | import random 22 | from pathlib import Path 23 | 24 | from tqdm import tqdm 25 | 26 | from prismatic.preprocessing.download import download_with_progress 27 | 28 | # === Constants === 29 | DATA_URL = "https://huggingface.co/datasets/X2FD/LVIS-Instruct4V/resolve/main/lvis_instruct4v_220k.json" 30 | DOWNLOAD_DIR = Path("data/download/llava-v1.5-instruct") 31 | RAW_JSON_FILE = DOWNLOAD_DIR / "lvis_instruct4v_220k.json" 32 | 33 | # JSON Files for "merged" variant of the dataset (with `llava_v1_5_mix665k.json`) 34 | BASE_JSON_FILE = DOWNLOAD_DIR / "llava_v1_5_mix665k.json" 35 | MERGED_JSON_FILE = DOWNLOAD_DIR / "llava_v1_5_lvis4v_mix888k.json" 36 | 37 | 38 | def build_lvis_instruct_4v() -> None: 39 | print("[*] Downloading and Formatting `LVIS-Instruct-4V` Dataset!") 40 | 41 | # Set Random Seed 42 | random.seed(7) 43 | 44 | # Download Dataset JSON 45 | os.makedirs(DOWNLOAD_DIR, exist_ok=True) 46 | if not RAW_JSON_FILE.exists(): 47 | download_with_progress(DATA_URL, DOWNLOAD_DIR) 48 | 49 | # Open JSON File --> verify image existence! 50 | print("[*] Loading LVIS Instruct4V Data!") 51 | with open(RAW_JSON_FILE, "r") as f: 52 | data = json.load(f) 53 | 54 | # Iterate & Verify 55 | for example in tqdm(data, desc="[*] Verifying all Images in LVIS Instruct4V"): 56 | image_path = example["image"] 57 | assert (DOWNLOAD_DIR / image_path).exists(), f"Missing Image `{image_path}`" 58 | 59 | # Create Stacked Dataset =>> Shuffle for Good Measure! 60 | print("[*] Loading LLaVa v1.5 Data!") 61 | with open(BASE_JSON_FILE, "r") as f: 62 | llava_v15_data = json.load(f) 63 | 64 | # Combine & Shuffle & Write 65 | full_data = llava_v15_data + data 66 | 67 | random.shuffle(full_data) 68 | random.shuffle(full_data) 69 | random.shuffle(full_data) 70 | 71 | with open(MERGED_JSON_FILE, "w") as f: 72 | json.dump(full_data, f) 73 | 74 | 75 | if __name__ == "__main__": 76 | build_lvis_instruct_4v() 77 | -------------------------------------------------------------------------------- /prismatic-vlms/scripts/generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate.py 3 | 4 | Simple CLI script to interactively test generating from a pretrained VLM; provides a minimal REPL for specify image 5 | URLs, prompts, and language generation parameters. 
6 | 7 | Run with: python scripts/generate.py --model_path 8 | """ 9 | import os 10 | from dataclasses import dataclass 11 | from pathlib import Path 12 | from typing import Union 13 | 14 | import draccus 15 | import requests 16 | import torch 17 | from PIL import Image 18 | 19 | from prismatic import load 20 | from prismatic.overwatch import initialize_overwatch 21 | 22 | # Initialize Overwatch =>> Wraps `logging.Logger` 23 | overwatch = initialize_overwatch(__name__) 24 | 25 | 26 | # Default Image URL (Beignets) 27 | DEFAULT_IMAGE_URL = ( 28 | "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" 29 | ) 30 | 31 | 32 | @dataclass 33 | class GenerateConfig: 34 | # fmt: off 35 | model_path: Union[str, Path] = ( # Path to Pretrained VLM (on disk or HF Hub) 36 | "prism-dinosiglip+7b" 37 | ) 38 | 39 | # HF Hub Credentials (required for Gated Models like LLaMa-2) 40 | hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token 41 | 42 | # Default Generation Parameters =>> subscribes to HuggingFace's GenerateMixIn API 43 | do_sample: bool = False 44 | temperature: float = 1.0 45 | max_new_tokens: int = 512 46 | min_length: int = 1 47 | 48 | # fmt: on 49 | 50 | 51 | @draccus.wrap() 52 | def generate(cfg: GenerateConfig) -> None: 53 | overwatch.info(f"Initializing Generation Playground with Prismatic Model `{cfg.model_path}`") 54 | hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token] 55 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 56 | 57 | # Load the pretrained VLM --> uses default `load()` function 58 | vlm = load(cfg.model_path, hf_token=hf_token) 59 | vlm.to(device, dtype=torch.bfloat16) 60 | 61 | # Initial Setup 62 | image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB") 63 | prompt_builder = vlm.get_prompt_builder() 64 | system_prompt = prompt_builder.system_prompt 65 | 66 | # REPL Welcome Message 67 | print( 68 | "[*] Dropping into Prismatic VLM REPL with Default Generation Setup => Initial Conditions:\n" 69 | f" => Prompt Template:\n\n{prompt_builder.get_potential_prompt('')}\n\n" 70 | f" => Default Image URL: `{DEFAULT_IMAGE_URL}`\n===\n" 71 | ) 72 | 73 | # REPL 74 | repl_prompt = ( 75 | "|=>> Enter (i)mage to fetch image from URL, (p)rompt to update prompt template, (q)uit to exit, or any other" 76 | " key to enter input questions: " 77 | ) 78 | while True: 79 | user_input = input(repl_prompt) 80 | 81 | if user_input.lower().startswith("q"): 82 | print("\n|=>> Received (q)uit signal => Exiting...") 83 | return 84 | 85 | elif user_input.lower().startswith("i"): 86 | # Note => a new image starts a _new_ conversation (for now) 87 | url = input("\n|=>> Enter Image URL: ") 88 | image = Image.open(requests.get(url, stream=True).raw).convert("RGB") 89 | prompt_builder = vlm.get_prompt_builder(system_prompt=system_prompt) 90 | 91 | elif user_input.lower().startswith("p"): 92 | if system_prompt is None: 93 | print("\n|=>> Model does not support `system_prompt`!") 94 | continue 95 | 96 | # Note => a new system prompt starts a _new_ conversation 97 | system_prompt = input("\n|=>> Enter New System Prompt: ") 98 | prompt_builder = vlm.get_prompt_builder(system_prompt=system_prompt) 99 | print( 100 | "\n[*] Set New System Prompt:\n" 101 | f" => Prompt Template:\n{prompt_builder.get_potential_prompt('')}\n\n" 102 | ) 103 | 104 | else: 105 | print("\n[*] Entering Chat Session - CTRL-C to start 
afresh!\n===\n") 106 | try: 107 | while True: 108 | message = input("|=>> Enter Prompt: ") 109 | 110 | # Build Prompt 111 | prompt_builder.add_turn(role="human", message=message) 112 | prompt_text = prompt_builder.get_prompt() 113 | 114 | # Generate from the VLM 115 | generated_text = vlm.generate( 116 | image, 117 | prompt_text, 118 | do_sample=cfg.do_sample, 119 | temperature=cfg.temperature, 120 | max_new_tokens=cfg.max_new_tokens, 121 | min_length=cfg.min_length, 122 | ) 123 | prompt_builder.add_turn(role="gpt", message=generated_text) 124 | print(f"\t|=>> VLM Response >>> {generated_text}\n") 125 | 126 | except KeyboardInterrupt: 127 | print("\n===\n") 128 | continue 129 | 130 | 131 | if __name__ == "__main__": 132 | generate() 133 | -------------------------------------------------------------------------------- /prismatic-vlms/scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | preprocess.py 3 | 4 | Core script for automatically downloading raw VLM pretraining datasets. Supports downloading the following datasets: 5 | - LLaVA v1.5 Datasets (for both training stages) [`llava-laion-cc-sbu-558k`, `llava-v1.5-instruct`] 6 | - Stage 1 :: Projection Matrix Alignment between Vision Encoder & Pretrained LLM on CC-3M-595K (Custom) 7 | - Stage 2 :: Projection & LLM Finetuning on LLaVa v1.5 Instruct (including various vision-language train sets) 8 | 9 | By default, runs download & extraction automatically. 10 | 11 | Run with: `python scripts/preprocess.py --dataset_id ` 12 | """ 13 | from dataclasses import dataclass 14 | from pathlib import Path 15 | 16 | import draccus 17 | 18 | from prismatic.overwatch import initialize_overwatch 19 | from prismatic.preprocessing import convert_to_jpg, download_extract 20 | 21 | # Initialize Overwatch =>> Wraps `logging.Logger` 22 | overwatch = initialize_overwatch(__name__) 23 | 24 | 25 | @dataclass 26 | class PreprocessConfig: 27 | # fmt: off 28 | dataset_id: str = "llava-v1.5-instruct" # Unique identifier for dataset to process (see above) 29 | root_dir: Path = Path("data") # Path to root directory for storing datasets 30 | 31 | # fmt: on 32 | 33 | 34 | @draccus.wrap() 35 | def preprocess(cfg: PreprocessConfig) -> None: 36 | overwatch.info(f"Downloading & Extracting `{cfg.dataset_id}` to `{cfg.root_dir / 'download'}") 37 | download_extract(cfg.dataset_id, root_dir=cfg.root_dir) 38 | 39 | # Special Handling for OCR VQA Images (for `llava-v1.5-instruct`) --> convert GIFs/PNGs to JPG 40 | if cfg.dataset_id == "llava-v1.5-instruct": 41 | convert_to_jpg(cfg.root_dir / "download" / cfg.dataset_id / "ocr_vqa" / "images") 42 | 43 | 44 | if __name__ == "__main__": 45 | preprocess() 46 | -------------------------------------------------------------------------------- /prismatic-vlms/train.sh: -------------------------------------------------------------------------------- 1 | CKPTID=$1 2 | STAGE=$2 3 | BSZ=$3 4 | PER_GPU_BSZ=$4 5 | 6 | torchrun --nproc_per_node 8 scripts/pretrain.py \ 7 | --stage ${STAGE} \ 8 | --model.type "one-stage+7b" \ 9 | --model.model_id qwen2.5-1.5b-instruct-continue-training-${CKPTID} \ 10 | --model.arch_specifier "no-align+avgpool" \ 11 | --model.vision_backbone_id "siglip-vit-so400m-384px" \ 12 | --model.image_resize_strategy "resize-naive" \ 13 | --model.llm_backbone_id "qwen2.5-1.5b-instruct" \ 14 | --model.pretrain_global_batch_size ${BSZ} \ 15 | --model.pretrain_per_device_batch_size ${PER_GPU_BSZ} \ 16 | --model.pretrain_epochs 1 \ 17 | --mount_path Qwen \ 18 | 
--run_root_dir checkpoints/ \ 19 | --dataset.type "pretrain" \ 20 | --dataset.dataset_root_dir data/datacomp/datacomp_hq_single_pkl_pil:data/ccs/ccs_single_pkl_pil/:data/laion/laion_single_pkl_pil/ -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import torch 3 | from PIL import Image 4 | from prismatic import load 5 | 6 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 7 | 8 | # Load a pretrained VLM (either local path, or ID to auto-download from the HF Hub) 9 | vlm = load("Open-Qwen2VL") 10 | vlm.to(device, dtype=torch.bfloat16) 11 | 12 | # Download an image and specify a prompt 13 | image_url = "https://huggingface.co/adept/fuyu-8b/resolve/main/bus.png" 14 | # image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB") 15 | image = [vlm.vision_backbone.image_transform(Image.open(requests.get(image_url, stream=True).raw).convert("RGB")).unsqueeze(0)] 16 | user_prompt = '' + '\n' + "Describe the image." 17 | 18 | # Generate! 19 | generated_text = vlm.generate_batch( 20 | image, 21 | [user_prompt], 22 | do_sample=False, 23 | max_new_tokens=512, 24 | min_length=1, 25 | ) 26 | print(generated_text[0]) -------------------------------------------------------------------------------- /vlm-evaluation/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Logs 105 | serve_images/ 106 | *conv.json 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # Ruff 136 | .ruff_cache/ 137 | 138 | # IDE caches 139 | .idea/ 140 | .vscode/ 141 | 142 | # Mac OS 143 | .DS_Store 144 | 145 | # Tokens 146 | # .hf_token 147 | 148 | # Scratch & Caches 149 | __scratch/ 150 | scratch/ 151 | cache/ 152 | results/ -------------------------------------------------------------------------------- /vlm-evaluation/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | exclude: ".git" 4 | 5 | repos: 6 | - repo: https://github.com/charliermarsh/ruff-pre-commit 7 | rev: v0.0.252 8 | hooks: 9 | - id: ruff 10 | args: [ --fix, --exit-non-zero-on-fix ] 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 23.1.0 14 | hooks: 15 | - id: black 16 | 17 | - repo: https://github.com/pre-commit/pre-commit-hooks 18 | rev: v4.4.0 19 | hooks: 20 | - id: check-added-large-files 21 | - id: check-ast 22 | - id: check-case-conflict 23 | - id: check-merge-conflict 24 | - id: check-toml 25 | - id: check-yaml 26 | - id: end-of-file-fixer 27 | - id: trailing-whitespace 28 | -------------------------------------------------------------------------------- /vlm-evaluation/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021-present, Siddharth Karamcheti, Suraj Nair, Ashwin Balakrishna 4 | and Toyota Research Institute. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /vlm-evaluation/eval.sh: -------------------------------------------------------------------------------- 1 | # task=$2 2 | DATAPATH=./ 3 | 4 | # supported tasks: gqa vqa-v2 vizwiz okvqa ai2d text-vqa pope mmmu mmbench seedbench-image mathvista mmstar mantis mmlu 5 | 6 | for task in ai2d text-vqa pope mmmu mmbench seedbench mmstar mathvista ; do 7 | subset="full" 8 | rm ${DATAPATH}/vlm-evaluation/datasets/${task}/* 9 | python scripts/datasets/prepare.py --dataset_family ${task} --root_dir ${DATAPATH}/vlm-evaluation/ --shots 0 10 | 11 | rm -r results/${task}/${task}-${subset}/ 12 | 13 | accelerate launch --main_process_port 29511 --num_processes=8 scripts/evaluate.py --model_dir $1 --dataset.type ${task}-${subset} --dataset.root_dir ${DATAPATH}/vlm-evaluation/ --results_dir ./results --model_id prism-siglip-sft 14 | 15 | python scripts/score.py --model_id prism-siglip-sft --dataset.type ${task}-${subset} --dataset.root_dir ${DATAPATH}/vlm-evaluation/ --results_dir ./results 16 | done 17 | -------------------------------------------------------------------------------- /vlm-evaluation/images/03-evaluation-suite-med-res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/images/03-evaluation-suite-med-res.png -------------------------------------------------------------------------------- /vlm-evaluation/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "vlm_eval" 7 | authors = [ 8 | {name = "Siddharth Karamcheti", email="skaramcheti@cs.stanford.edu"}, 9 | {name = "Suraj Nair", email="suraj.nair@tri.global"}, 10 | {name = "Ashwin Balakrishna", email="ashwin.balakrishna@tri.global"} 11 | ] 12 | description = "VLM Eval: Benchmark for VLMs, spanning text generation tasks from VQA to Captioning" 13 | version = "0.0.1" 14 | readme = "README.md" 15 | requires-python = ">=3.8" 16 | keywords = ["machine learning"] 17 | license = {file = "LICENSE"} 18 | classifiers = [ 19 | "Development Status :: 3 - Alpha", 20 | "Intended Audience :: Developers", 21 | "Intended Audience :: Education", 22 | "Intended Audience :: Science/Research", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | ] 32 | dependencies = [ 33 | "accelerate", 34 | "ascii_magic", 35 | "jsonlines", 36 | "jinja2==3.0.3", 37 | "llava @ git+https://github.com/Victorwz/LLaVA-Unified", 38 | "mosaicml-streaming", 39 | "openai", 40 | "pycocotools", 41 | "rich", 42 | "scikit-image", 43 | # "salesforce-lavis @ git+https://github.com/siddk/LAVIS", 44 | # "webdataset", 45 | "pymongo", 46 | "spacy", 47 | "timm==0.9.10", 48 | "numpy==1.26.4", 49 | ] 50 | 51 | [project.optional-dependencies] 52 | dev = [ 53 | "black", 54 | "gpustat", 55 | "ipython", 56 | "pre-commit", 57 | "ruff", 58 | ] 59 | 60 | [project.urls] 61 | homepage = "https://github.com/TRI-ML/vlm-evaluation" 62 | repository = 
"https://github.com/TRI-ML/vlm-evaluation" 63 | documentation = "https://github.com/TRI-ML/vlm-evaluation" 64 | 65 | [tool.setuptools.packages.find] 66 | where = ["."] 67 | exclude = ["cache"] 68 | 69 | [tool.black] 70 | line-length = 121 71 | target-version = ["py38", "py39", "py310"] 72 | preview = true 73 | 74 | [tool.ruff] 75 | line-length = 121 76 | target-version = "py38" 77 | select = ["A", "B", "C90", "E", "F", "I", "RUF", "W"] 78 | ignore = ["B008", "F722"] 79 | 80 | [tool.ruff.per-file-ignores] 81 | "__init__.py" = ["E402", "F401"] 82 | -------------------------------------------------------------------------------- /vlm-evaluation/scripts/datasets/prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | prepare.py 3 | 4 | Entry point for dataset downloading & preparation -- handles all aspects of the raw data acquisition, extraction, and 5 | verification process, writing both WebDataset and Mosaic Streaming (MDS) versions of the data. 6 | """ 7 | import os 8 | from dataclasses import dataclass 9 | from pathlib import Path 10 | from typing import Tuple, Union 11 | 12 | import draccus 13 | 14 | from vlm_eval.overwatch import initialize_overwatch 15 | from vlm_eval.tasks import build_index_datasets, download_extract 16 | 17 | # Initialize Overwatch =>> Wraps `logging.Logger` 18 | overwatch = initialize_overwatch(__name__) 19 | 20 | 21 | @dataclass 22 | class DatasetPreparationConfig: 23 | # fmt: off 24 | dataset_family: str = "ai2d" # Dataset family to prepare 25 | 26 | # Processing Parameters 27 | create_slim_dataset: bool = True # Whether to create "slim" (minified) dataset(s) 28 | slim_dataset_sizes: Tuple[int, ...] = ( # Number of examples for the slim dataset(s) 29 | 1024, 4096, 30 | ) 31 | export_formats: Tuple[str, ...] 
= ( # Formats for export (always writes a "Map" Dataset) 32 | "webdataset", 33 | "mosaic-streaming", 34 | ) 35 | 36 | # Format-Specific Parameters 37 | max_shard_size_bytes: int = 64000000 # Maximum size for a shard in bytes (default: 64 MB) 38 | wds_examples_per_shard: int = 1024 # [WebDataset] Number of examples per `tar` shard 39 | mds_hashes: Tuple[str, str] = ("sha1", "xxh64") # [Mosaic] Pair of (crypto, non-crypto) hash functions 40 | 41 | # Path Parameters 42 | root_dir: Path = Path( # Path to root directory for storing datasets 43 | # "datasets/vlm-evaluation" 44 | "/mnt/fsx/skaramcheti/datasets/vlm-evaluation" 45 | ) 46 | 47 | # HF Hub Credentials (for LLaMa-2) 48 | hf_token: Union[str, Path] = Path(".hf_token") # Env Variable or Path to HF Token 49 | 50 | # Randomness 51 | seed: int = 21 # Random Seed (for slim datasets, augmentations) 52 | 53 | # Number of demonstrations 54 | shots: int = 8 55 | # fmt: on 56 | 57 | 58 | @draccus.wrap() 59 | def prepare(cfg: DatasetPreparationConfig) -> None: 60 | overwatch.info(f"Downloading and Preparing VLM Evaluation Dataset `{cfg.dataset_family}`") 61 | 62 | # Phase 1 :: Download & Extract Raw Data to `cfg.data_dir` / cfg.dataset_id / "download" 63 | overwatch.info(f"Phase 1 =>> Downloading & Extracting `{cfg.dataset_family}` to {cfg.root_dir / 'download'}") 64 | hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token] 65 | download_extract(cfg.dataset_family, cfg.root_dir, hf_token) 66 | 67 | # Phase 2 :: Assemble Index Dataset(s) (always builds metadata from local disk, then used to export other formats) 68 | overwatch.info(f"Phase 2 =>> Building Index Dataset(s) for `{cfg.dataset_family}` at {cfg.root_dir / 'datasets'}") 69 | index_datasets = build_index_datasets( 70 | cfg.dataset_family, 71 | cfg.root_dir, 72 | slim_dataset_sizes=cfg.slim_dataset_sizes if cfg.create_slim_dataset else None, 73 | seed=cfg.seed, 74 | shots=cfg.shots, 75 | ) 76 | 77 | # Phase 3 :: Build Streaming / Iterable Datasets in the desired format(s) 78 | return index_datasets 79 | 80 | 81 | if __name__ == "__main__": 82 | prepare() 83 | -------------------------------------------------------------------------------- /vlm-evaluation/scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | evaluate.py 3 | 4 | Entry point for all VLM-Evaluation evaluations; specify model and dataset, get results. 
5 | 6 | Run with `accelerate` from repository root (for naive parallelization): 7 | =>> [Single-GPU] CUDA_VISIBLE_DEVICES={0-7} accelerate launch --num_processes=1 scripts/evaluate.py < args > 8 | =>> [Multi-GPU] accelerate launch --num_processes={>1} scripts/evaluate.py < args > 9 | """ 10 | import os 11 | from dataclasses import dataclass, field 12 | from pathlib import Path 13 | from typing import Union, Optional 14 | 15 | import draccus 16 | from accelerate.utils import set_seed 17 | 18 | from vlm_eval.conf import DatasetConfig, DatasetRegistry 19 | from vlm_eval.models import load_vlm 20 | from vlm_eval.overwatch import initialize_overwatch 21 | from vlm_eval.tasks import get_task_runner 22 | 23 | # Sane Defaults 24 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 25 | 26 | 27 | # Initialize Overwatch =>> Wraps `logging.Logger` and `accelerate.PartialState` 28 | overwatch = initialize_overwatch(__name__) 29 | 30 | 31 | @dataclass 32 | class EvaluationConfig: 33 | # fmt: off 34 | 35 | # DatasetConfig from `vlm_eval/conf/datasets.py`; override with --dataset.type `DatasetRegistry..dataset_id` 36 | dataset: DatasetConfig = field( 37 | default_factory=DatasetConfig.get_choice_class(DatasetRegistry.AI2D_FULL.dataset_id) 38 | ) 39 | 40 | # === Model Parameters =>> Prismatic === 41 | model_family: str = "prismatic" # Model family to load from in < `prismatic` | `llava-v15` | ... > 42 | model_id: Optional[str] = ( # Model ID to load and run (instance of `model_family`) 43 | "prism-clip+7b" 44 | ) 45 | model_dir: Optional[Path] = None # Path to model checkpoint to load --> should be self-contained 46 | 47 | # === Model Parameters =>> Official LLaVa === 48 | # model_family: str = "llava-v15" 49 | # model_id: str = "llava-v1.5-7b" 50 | # model_dir: Path = "liuhaotian/llava-v1.5-7b" 51 | 52 | # === Model Parameters =>> Official InstructBLIP === 53 | # model_family: str = "instruct-blip" 54 | # model_id: str = "instructblip-vicuna-7b" 55 | # model_dir: Path = "Salesforce/instructblip-vicuna-7b" 56 | 57 | # Inference Parameters 58 | device_batch_size: int = 1 # Device Batch Size set to 1 until LLaVa/HF LLaMa fixes bugs! 
59 | num_workers: int = 2 # Number of Dataloader Workers (on each process) 60 | 61 | # Artifact Parameters 62 | results_dir: Path = Path( # Path to results directory (writing predicted output, metrics) 63 | "results" 64 | ) 65 | 66 | # HF Hub Credentials (for LLaMa-2) 67 | hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token 68 | 69 | # Randomness 70 | seed: int = 21 # Random Seed (for reproducibility) 71 | 72 | def __post_init__(self) -> None: 73 | self.run_dir = self.model_dir 74 | 75 | # fmt: on 76 | 77 | 78 | @draccus.wrap() 79 | def evaluate(cfg: EvaluationConfig) -> None: 80 | overwatch.info(f"Starting Evaluation for Dataset `{cfg.dataset.dataset_id}` w/ Model `{cfg.model_id}`") 81 | set_seed(cfg.seed) 82 | 83 | # Short-Circuit (if results/metrics already exist) 84 | task_results_dir = cfg.results_dir / cfg.dataset.dataset_family / cfg.dataset.dataset_id / cfg.model_id 85 | if (task_results_dir / "metrics.json").exists(): 86 | overwatch.info(f"Metrics for `{cfg.dataset.dataset_id}` w/ `{cfg.model_id}` exist =>> exiting!") 87 | return 88 | 89 | # Build the VLM --> Download/Load Pretrained Model from Checkpoint 90 | overwatch.info("Initializing VLM =>> Bundling Models, Image Processors, and Tokenizer") 91 | hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token] 92 | vlm = load_vlm(cfg.model_family, cfg.model_id, cfg.run_dir, hf_token=hf_token, ocr=cfg.dataset.ocr) 93 | 94 | # Create Task Runner 95 | overwatch.info(f"Building Evaluation Runner for Dataset `{cfg.dataset.dataset_id}`") 96 | task_runner = get_task_runner( 97 | cfg.dataset.dataset_family, 98 | cfg.dataset.root_dir, 99 | cfg.dataset.index_file, 100 | task_results_dir, 101 | cfg.model_id, 102 | prompt_fn=vlm.get_prompt_fn(cfg.dataset.dataset_family), 103 | image_processor=vlm.image_processor, 104 | prompt_builder=vlm.get_prompt_builder, 105 | ) 106 | 107 | # Run Evaluation 108 | overwatch.info("Starting (Distributed) Evaluation Loop") 109 | task_runner.evaluate(vlm, cfg.device_batch_size, cfg.num_workers) 110 | 111 | 112 | if __name__ == "__main__": 113 | evaluate() 114 | -------------------------------------------------------------------------------- /vlm-evaluation/scripts/score.py: -------------------------------------------------------------------------------- 1 | """ 2 | score.py 3 | 4 | Aggregation & "official scoring" for all VLM-Bench evaluations; to be run after dumping generations via `evaluate.py`. 5 | 6 | Where possible, uses the *official evaluation script* to evaluate the performance of a model on a validation/testdev 7 | split --> as an example, using the official GQA `eval.py` or VQAv2 leaderboard script for evaluating VQA performance. 
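Per-rank result files (results+rank*.json) are merged and length-checked against the dataset's expected example count before scoring.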
8 | 9 | Run from the repository root: 10 | => python scripts/score.py < args > 11 | """ 12 | import json 13 | from dataclasses import dataclass, field 14 | from pathlib import Path 15 | from typing import Dict, List, Optional, Union 16 | 17 | import draccus 18 | import yaml 19 | 20 | from vlm_eval.conf import DatasetConfig, DatasetRegistry 21 | from vlm_eval.overwatch import initialize_overwatch 22 | from vlm_eval.tasks import get_scorer 23 | 24 | # === By default - official scoring scripts only support the `-full` dataset variants; add overrides below === 25 | VALID_DATASET_ID_OVERRIDES = {} 26 | 27 | 28 | # Initialize Overwatch =>> Wraps `logging.Logger` and `accelerate.PartialState` 29 | overwatch = initialize_overwatch(__name__) 30 | 31 | @dataclass 32 | class ScoreConfig: 33 | # fmt: off 34 | 35 | # DatasetConfig from `vlm_eval/conf/datasets.py`; override with --dataset.type `DatasetRegistry..dataset_id` 36 | dataset: DatasetConfig = field( 37 | default_factory=DatasetConfig.get_choice_class(DatasetRegistry.AI2D_FULL.dataset_id) 38 | ) 39 | 40 | # === Model Parameters =>> Prismatic === 41 | model_id: str = "prism-clip+7b" # Model ID to load and run (instance of `model_family`) 42 | 43 | # === Model Parameters =>> Official LLaVa === 44 | # model_id: str = "llava-v1.5-7b" 45 | 46 | # === Model Parameters =>> Official InstructBLIP === 47 | # model_id: str = "instructblip-vicuna-7b" 48 | 49 | config_yaml: Optional[Path] = None 50 | 51 | # Artifact Parameters 52 | results_dir: Path = Path( # Path to results directory (writing predicted output, metrics) 53 | "results" 54 | ) 55 | 56 | # fmt: on 57 | 58 | 59 | @draccus.wrap() 60 | def score(cfg: ScoreConfig) -> None: 61 | overwatch.info(f"Starting Official Scoring for Dataset `{cfg.dataset.dataset_id}` => Model `{cfg.model_id}`") 62 | 63 | # Short-Circuit (if results/metrics already exist) 64 | dataset_family, dataset_id = cfg.dataset.dataset_family, cfg.dataset.dataset_id 65 | task_results_dir = cfg.results_dir / cfg.dataset.dataset_family / cfg.dataset.dataset_id / cfg.model_id 66 | if (metrics_json := task_results_dir / "metrics.json").exists(): 67 | overwatch.info(f"Metrics JSON already exists at `{metrics_json}` =>> Exiting!") 68 | with open(metrics_json, "r") as f: 69 | metrics = json.load(f) 70 | model, dataset, split, summary, experiment_tags = ( 71 | metrics["model"], 72 | metrics["dataset"], 73 | cfg.dataset.split, 74 | metrics["summary"], 75 | metrics["experiment_tags"], 76 | ) 77 | accuracy_keys = [k for k in metrics["summary"].keys() if (k.startswith("accuracy__") or k == "accuracy")] 78 | if len(accuracy_keys) == 1: 79 | result_string = ( 80 | f"Results for Model `{model}` on {dataset} (Split = {split})\n" 81 | f" => Accuracy (Official): {summary['accuracy']:.3f}" 82 | ) 83 | else: 84 | dataset_names = [k.split("__")[1] for k in accuracy_keys] 85 | result_string = ( 86 | f"Results for Model `{model}` on {dataset} ({'/'.join(dataset_names)}) (Split = {split})\n" 87 | ) 88 | for d in dataset_names: 89 | result_string += f" => {d} Accuracy (Official): {summary[f'accuracy__{d}']:.3f}\n" 90 | 91 | # Log to Console 92 | overwatch.info(result_string.strip()) 93 | 94 | return 95 | 96 | # Merge per-Rank Results & Assert on Expected Length 97 | full_results = {} 98 | for rank_json in task_results_dir.glob("results+rank*.json"): 99 | with open(rank_json, "r") as f: 100 | full_results.update(json.load(f)) 101 | 102 | # Validate on Expected # of Examples 103 | assert ( 104 | len(full_results) == cfg.dataset.expected_examples 105 | ), 
f"Expected {cfg.dataset.expected_examples} model outputs, only found {len(full_results)}!" 106 | 107 | # Per-Family Dataset Handling 108 | root_dir = cfg.dataset.root_dir 109 | scorer = get_scorer( 110 | dataset_family, 111 | dataset_id, 112 | task_results_dir, 113 | full_results, 114 | annotations_file=root_dir / cfg.dataset.annotations_file, 115 | questions_file=root_dir / cfg.dataset.questions_file if cfg.dataset.questions_file is not None else None, 116 | split=cfg.dataset.split, 117 | ) 118 | summary_scores = scorer.score(cfg.model_id) 119 | 120 | # Open Model Config =>> `config.yaml` 121 | if cfg.config_yaml is not None: 122 | with open(cfg.config_yaml, "r") as f: 123 | full_cfg = yaml.safe_load(f) 124 | 125 | # Extract Experiment "Tag" Parameters =>> for Leaderboard Display 126 | experiment_tags = {k: full_cfg["model"][k] for k in ["experiment_tag", "config_line_no", "model_split"]} 127 | else: 128 | # Experiment Tags for "Official Models" don't make sense; set to empty 129 | experiment_tags = {} 130 | 131 | # Finalize Metrics & Write to Disk 132 | metrics = { 133 | "dataset": cfg.dataset.dataset_id, 134 | "n_examples": cfg.dataset.expected_examples, 135 | "model": cfg.model_id, 136 | "experiment_tags": experiment_tags, 137 | "summary": summary_scores, 138 | "examples": full_results, 139 | } 140 | with open(metrics_json, "w") as f: 141 | json.dump(metrics, f, indent=2) 142 | 143 | 144 | if __name__ == "__main__": 145 | score() 146 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/__init__.py -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DatasetConfig, DatasetRegistry 2 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/models/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | 4 | from vlm_eval.util.interfaces import VLM 5 | 6 | from .instructblip import InstructBLIP 7 | from .llava import LLaVa 8 | from .prismatic import PrismaticVLM 9 | 10 | # === Initializer Dispatch by Family === 11 | FAMILY2INITIALIZER = {"instruct-blip": InstructBLIP, "llava-v15": LLaVa, "prismatic": PrismaticVLM} 12 | 13 | 14 | def load_vlm( 15 | model_family: str, 16 | model_id: str, 17 | run_dir: Path, 18 | hf_token: Optional[str] = None, 19 | ocr: Optional[bool] = False, 20 | load_precision: str = "bf16", 21 | max_length=128, 22 | temperature=1.0, 23 | ) -> VLM: 24 | assert model_family in FAMILY2INITIALIZER, f"Model family `{model_family}` not supported!" 
25 | return FAMILY2INITIALIZER[model_family]( 26 | model_family=model_family, 27 | model_id=model_id, 28 | run_dir=run_dir, 29 | hf_token=hf_token, 30 | load_precision=load_precision, 31 | max_length=max_length, 32 | temperature=temperature, 33 | ocr=ocr, 34 | ) 35 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/overwatch/__init__.py: -------------------------------------------------------------------------------- 1 | from .overwatch import initialize_overwatch 2 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/overwatch/overwatch.py: -------------------------------------------------------------------------------- 1 | """ 2 | overwatch.py 3 | 4 | Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler. 5 | """ 6 | import logging 7 | import logging.config 8 | import os 9 | from logging import LoggerAdapter 10 | from typing import Union 11 | 12 | # Overwatch Default Format String 13 | RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]" 14 | 15 | # Set Logging Configuration 16 | LOG_CONFIG = { 17 | "version": 1, 18 | "disable_existing_loggers": False, 19 | "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}}, 20 | "handlers": { 21 | "console": { 22 | "class": "rich.logging.RichHandler", 23 | "formatter": "simple-console", 24 | "markup": True, 25 | "rich_tracebacks": True, 26 | "show_level": True, 27 | "show_path": True, 28 | "show_time": True, 29 | } 30 | }, 31 | "root": {"level": "INFO", "handlers": ["console"]}, 32 | } 33 | logging.config.dictConfig(LOG_CONFIG) 34 | 35 | 36 | # === Custom Contextual Logging Logic === 37 | class ContextAdapter(LoggerAdapter): 38 | CTX_PREFIXES = {0: "[*] "} | {idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]} 39 | 40 | def process(self, msg, kwargs): 41 | ctx_level = kwargs.pop("ctx_level", 0) 42 | return f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs 43 | 44 | 45 | class DistributedOverwatch: 46 | def __init__(self, name: str) -> None: 47 | """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`.""" 48 | from accelerate import PartialState 49 | 50 | # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun` 51 | # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all! 52 | self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name)), PartialState() 53 | 54 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 55 | self.debug = self.logger.debug 56 | self.info = self.logger.info 57 | self.warning = self.logger.warning 58 | self.error = self.logger.error 59 | self.critical = self.logger.critical 60 | 61 | # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others! 
62 | self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR) 63 | 64 | 65 | class PureOverwatch: 66 | def __init__(self, name: str) -> None: 67 | """Initializer for an Overwatch object that just wraps logging.""" 68 | self.logger = ContextAdapter(logging.getLogger(name)) 69 | 70 | # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) 71 | self.debug = self.logger.debug 72 | self.info = self.logger.info 73 | self.warning = self.logger.warning 74 | self.error = self.logger.error 75 | self.critical = self.logger.critical 76 | 77 | # Logging Defaults =>> INFO 78 | self.logger.setLevel(logging.INFO) 79 | 80 | 81 | def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]: 82 | return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) > 1 else PureOverwatch(name) 83 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/serve/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | # Arrange keys in display priority order (high --> low) 5 | MODEL_ID_TO_NAME = OrderedDict( 6 | [ 7 | ( 8 | "prism-dinosiglip+13b", 9 | "Prism 13B", 10 | ), 11 | ( 12 | "prism-dinosiglip+7b", 13 | "Prism 7B", 14 | ), 15 | ( 16 | "prism-dinosiglip-controlled+13b", 17 | "Prism 13B (Controlled)", 18 | ), 19 | ( 20 | "prism-dinosiglip-controlled+7b", 21 | "Prism 7B (Controlled)", 22 | ), 23 | ("llava-v1.5-13b", "LLaVA 1.5 13B"), 24 | ("llava-v1.5-7b", "LLaVA 1.5 7B"), 25 | ("instructblip-vicuna-7b", "InstructBLIP 7B"), 26 | ] 27 | ) 28 | 29 | INTERACTION_MODES_MAP = OrderedDict( 30 | [ 31 | ("Chat", "chat"), 32 | ("Captioning", "captioning"), 33 | ("Bounding Box Prediction", "bbox_pred"), 34 | ("Visual Question Answering", "vqa"), 35 | ("True/False Visual Question Answering", "true_false"), 36 | ] 37 | ) 38 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/serve/examples/cows_in_pasture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/serve/examples/cows_in_pasture.png -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/serve/examples/monkey_knives.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/serve/examples/monkey_knives.png -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .builders import build_index_datasets 2 | from .download import download_extract 3 | from .harnesses import get_scorer, get_task_runner 4 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/tasks/builders.py: -------------------------------------------------------------------------------- 1 | """ 2 | builders.py 3 | 4 | Utility functions for writing Map and Iterable (WebDataset, Mosaic Streaming) variants of various evaluation datasets. 
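Each supported dataset family is wired up through the BUILDER_DISPATCH registry below, which pairs an index-building function with its Map-style IndexDataset class.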
5 | """ 6 | from pathlib import Path 7 | from typing import Callable, Dict, List, Optional, Tuple 8 | 9 | from torch.utils.data import Dataset 10 | 11 | from vlm_eval.overwatch import initialize_overwatch 12 | from vlm_eval.tasks.harnesses.gqa import GQAIndexDataset, build_gqa_indices 13 | from vlm_eval.tasks.harnesses.okvqa import OKVQAIndexDataset, build_okvqa_indices 14 | from vlm_eval.tasks.harnesses.ocidref import OCIDRefIndexDataset, build_ocidref_indices 15 | from vlm_eval.tasks.harnesses.pope import PopeIndexDataset, build_pope_indices 16 | from vlm_eval.tasks.harnesses.refcoco import RefCOCOIndexDataset, build_refcoco_indices 17 | from vlm_eval.tasks.harnesses.tallyqa import TallyQAIndexDataset, build_tallyqa_indices 18 | from vlm_eval.tasks.harnesses.textvqa import TextVQAIndexDataset, build_textvqa_indices 19 | from vlm_eval.tasks.harnesses.vizwiz import VizWizIndexDataset, build_vizwiz_indices 20 | from vlm_eval.tasks.harnesses.vqav2 import VQAv2IndexDataset, build_vqav2_indices 21 | from vlm_eval.tasks.harnesses.vsr import VSRIndexDataset, build_vsr_indices 22 | from vlm_eval.tasks.harnesses.ai2d import AI2DIndexDataset, build_ai2d_indices 23 | from vlm_eval.tasks.harnesses.mmmu import MMMUIndexDataset, build_mmmu_indices 24 | from vlm_eval.tasks.harnesses.mathvista import MathVistaIndexDataset, build_mathvista_indices 25 | from vlm_eval.tasks.harnesses.mmbench import MMBenchIndexDataset, build_mmbench_indices 26 | from vlm_eval.tasks.harnesses.seedbench import SEEDBenchIndexDataset, build_seedbench_indices 27 | from vlm_eval.tasks.harnesses.mantis import MantisIndexDataset, build_mantis_indices 28 | from vlm_eval.tasks.harnesses.mmstar import MMStarIndexDataset, build_mmstar_indices 29 | from vlm_eval.tasks.harnesses.mscoco_karpathy import MSCOCOIndexDataset, build_mscoco_karpathy_indices 30 | from vlm_eval.tasks.harnesses.mmlu import MMLUIndexDataset, build_mmlu_indices 31 | 32 | # Initialize Overwatch =>> Wraps `logging.Logger` 33 | overwatch = initialize_overwatch(__name__) 34 | 35 | 36 | # === Define Dispatch Registry for each Task (Dataset Family) === 37 | BUILDER_DISPATCH: Dict[str, Dict[str, Callable]] = { 38 | # fmt: off 39 | 40 | # "Standard" Datasets (from Literature) 41 | "vqa-v2": {"build_indices": build_vqav2_indices, "get_index_datasets": VQAv2IndexDataset}, 42 | "gqa": {"build_indices": build_gqa_indices, "get_index_datasets": GQAIndexDataset}, 43 | "okvqa": {"build_indices": build_okvqa_indices, "get_index_datasets": OKVQAIndexDataset}, 44 | "vizwiz": {"build_indices": build_vizwiz_indices, "get_index_datasets": VizWizIndexDataset}, 45 | "pope": {"build_indices": build_pope_indices, "get_index_datasets": PopeIndexDataset}, 46 | "text-vqa": {"build_indices": build_textvqa_indices, "get_index_datasets": TextVQAIndexDataset}, 47 | "vsr": {"build_indices": build_vsr_indices, "get_index_datasets": VSRIndexDataset}, 48 | "refcoco": {"build_indices": build_refcoco_indices, "get_index_datasets": RefCOCOIndexDataset}, 49 | "ocid-ref": {"build_indices": build_ocidref_indices, "get_index_datasets": OCIDRefIndexDataset}, 50 | "tally-qa": {"build_indices": build_tallyqa_indices, "get_index_datasets": TallyQAIndexDataset}, 51 | "ai2d": {"build_indices": build_ai2d_indices, "get_index_datasets": AI2DIndexDataset}, 52 | "mmmu": {"build_indices": build_mmmu_indices, "get_index_datasets": MMMUIndexDataset}, 53 | "mmbench": {"build_indices": build_mmbench_indices, "get_index_datasets": MMBenchIndexDataset}, 54 | "mathvista": {"build_indices": build_mathvista_indices, 
"get_index_datasets": MathVistaIndexDataset}, 55 | "seedbench": {"build_indices": build_seedbench_indices, "get_index_datasets": SEEDBenchIndexDataset}, 56 | "mantis": {"build_indices": build_mantis_indices, "get_index_datasets": MantisIndexDataset}, 57 | "mmstar": {"build_indices": build_mmstar_indices, "get_index_datasets": MMStarIndexDataset}, 58 | "mscoco_karpathy": {"build_indices": build_mscoco_karpathy_indices, "get_index_datasets": MSCOCOIndexDataset}, 59 | "mmlu": {"build_indices": build_mmlu_indices, "get_index_datasets": MMLUIndexDataset}, 60 | 61 | # fmt: on 62 | } 63 | 64 | 65 | def build_index_datasets( 66 | dataset_family: str, root_dir: Path, slim_dataset_sizes: Optional[Tuple[int, ...]] = None, seed: int = 21, shots: int = 4 67 | ) -> List[Dataset]: 68 | """ 69 | Given a dataset identifier and optional list of dataset sizes, return a set of PyTorch Map-style Datasets 70 | (building metadata/index files if necessary) that wrap the metadata fields of the given dataset (e.g., returning 71 | image paths and strings instead of processed image tensors or tokens). 72 | 73 | These "index" datasets are to be used for local debugging, and more importantly for synthesizing Iterable, 74 | compressed datasets (WebDataset, Mosaic Streaming). 75 | 76 | To enable this, the underlying assumptions we make for each "index" dataset are as follows: 77 | 1) Entire dataset metadata fits into RAM (to enable quick indexing) 78 | 2) Individual media files (images) exist on local disk, allowing for random access given an image path. 79 | 80 | We define the properties/attributes of each dataset type below (... denotes dataset-specific metadata): 81 | + `dataset_type = vqa`: 82 | -> metadata{-`n_slim in slim_datasets_sizes`}.json 83 | {"question_id" -> {"question_id": str, "question": str, "img_path": Path, "answer": str, ...}} 84 | 85 | :param dataset_family: Dataset family (e.g., "vqa-v2" | "nocaps" | ...) to load from `DATASET_REGISTRY` 86 | :param root_dir: Absolute path to the project's default root directory with task/downloaded data 87 | :param slim_dataset_sizes: List of "slim" dataset sizes to build (each "slim" dataset is a subset of the larger) 88 | :param seed: Random seed for setting initial order of examples in each dataset (some datasets sort questions) 89 | 90 | :return: List of "index" datasets (Pytorch Dataset) of length (1 + len(slim_dataset_sizes)) 91 | """ 92 | overwatch.info(f"Building Index Files for Dataset Family `{dataset_family}`", ctx_level=1) 93 | assert dataset_family in BUILDER_DISPATCH, f"Dataset Family `{dataset_family}` does not have a valid IndexDataset!" 94 | index_files = BUILDER_DISPATCH[dataset_family]["build_indices"](root_dir, slim_dataset_sizes, seed=seed, shots=shots) 95 | 96 | overwatch.info("Assembling Map-Style Datasets from Index Files", ctx_level=1) 97 | index_datasets = [BUILDER_DISPATCH[dataset_family]["get_index_datasets"](root_dir, f) for f in index_files] 98 | 99 | return index_datasets 100 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/tasks/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | download.py 3 | 4 | Utility functions for downloading and extracting various datasets to (local) disk. 
5 | """ 6 | import os 7 | import re 8 | import shutil 9 | import subprocess 10 | import tarfile 11 | from pathlib import Path 12 | from zipfile import ZipFile 13 | 14 | import requests 15 | from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn 16 | 17 | from vlm_eval.overwatch import initialize_overwatch 18 | from vlm_eval.tasks.registry import DATASET_REGISTRY 19 | 20 | # Initialize Overwatch =>> Wraps `logging.Logger` 21 | overwatch = initialize_overwatch(__name__) 22 | 23 | 24 | def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024, hf_token=None) -> Path: 25 | """Utility function for downloading files from the internet, with a handy Rich-based progress bar.""" 26 | 27 | # Fire an HTTP Request, with `stream = True` => we at least want the Request Headers to validate! 28 | if hf_token is None: 29 | response = requests.get(url, stream=True) 30 | else: 31 | response = requests.get(url, headers={"Authorization": f"Bearer {hf_token}"}, stream=True) 32 | 33 | # Handle Filename Parsing (if not clear from URL) 34 | dest_path = download_dir / Path(url).name.split("?")[0] if "drive.google" not in url else "" 35 | if dest_path == "": 36 | # Parse Content-Headers --> "Content-Disposition" --> filename 37 | filename = re.findall('filename="(.+)"', response.headers["content-disposition"])[0] 38 | dest_path = download_dir / filename 39 | 40 | # Download / Short-Circuit if exists 41 | overwatch.info(f"Downloading {dest_path} from `{url}`", ctx_level=1) 42 | if dest_path.exists(): 43 | return dest_path 44 | 45 | # Download w/ Transfer-Aware Progress 46 | # => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py 47 | with Progress( 48 | TextColumn("[bold]{task.description} - {task.fields[fname]}"), 49 | BarColumn(bar_width=None), 50 | "[progress.percentage]{task.percentage:>3.1f}%", 51 | "•", 52 | DownloadColumn(), 53 | "•", 54 | TransferSpeedColumn(), 55 | transient=True, 56 | ) as dl_progress: 57 | total_raw = response.headers.get("content-length", None) 58 | total = int(total_raw) if total_raw is not None else None 59 | dl_tid = dl_progress.add_task("Downloading", fname=dest_path.name, total=total) 60 | with open(dest_path, "wb") as f: 61 | for data in response.iter_content(chunk_size=chunk_size_bytes): 62 | dl_progress.advance(dl_tid, f.write(data)) 63 | 64 | return dest_path 65 | 66 | 67 | def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = True) -> Path: 68 | """Utility function for extracting compressed archives, with a handy Rich-based progress bar.""" 69 | ## Semi-hacky naming fix for Ocid-ref because the download file is named differently 70 | if "ocid-ref" in download_dir.as_posix(): 71 | renamed = "/".join(archive_path.as_posix().split("/")[:-1]) + "/OCID-dataset.tar.gz" 72 | os.rename(archive_path.as_posix(), renamed) 73 | archive_path = Path(renamed) 74 | 75 | assert archive_path.suffix in {".gz", ".tar", ".zip"}, f"Invalid compressed archive `{archive_path}`!" 
76 | overwatch.info(f"Extracting {archive_path.name} to `{download_dir}`", ctx_level=1) 77 | 78 | # Extract w/ Progress 79 | with Progress( 80 | TextColumn("[bold]{task.description} - {task.fields[aname]}"), 81 | BarColumn(bar_width=None), 82 | "[progress.percentage]{task.percentage:>3.1f}%", 83 | "•", 84 | MofNCompleteColumn(), 85 | transient=True, 86 | ) as progress: 87 | if archive_path.suffix == ".zip": 88 | with ZipFile(archive_path) as zf: 89 | tid = progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist())) 90 | extract_path = Path(zf.extract(members[0], download_dir)) 91 | if extract_type == "file": 92 | assert ( 93 | len(members) == 1 94 | ), f"Archive `{archive_path}` with extract type `{extract_type} has > 1 member!" 95 | elif extract_type == "directory": 96 | for member in members[1:]: 97 | zf.extract(member, download_dir) 98 | progress.advance(tid) 99 | else: 100 | raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!") 101 | 102 | elif archive_path.suffix in {".tar", ".gz"}: 103 | assert extract_type == "directory", f"Unexpected `{extract_type = }` for `tar` archive!" 104 | extract_path = download_dir / archive_path.stem.split(".")[0] 105 | with tarfile.open(archive_path) as tf: 106 | tid = progress.add_task("Extracting", aname=archive_path.name, total=len(members := tf.getmembers())) 107 | for member in members: 108 | tf.extract(member=member, path=download_dir) 109 | progress.advance(tid) 110 | 111 | # Cleanup (if specified) 112 | if cleanup: 113 | archive_path.unlink() 114 | 115 | return extract_path 116 | 117 | 118 | 119 | def download_extract(dataset_family: str, root_dir: Path, hf_token: str) -> None: 120 | """Download all files for a given dataset (querying registry above), extracting archives if necessary.""" 121 | os.makedirs(download_dir := root_dir / "download" / dataset_family, exist_ok=True) 122 | 123 | # Download Files => Single-Threaded, with Progress Bar 124 | dl_tasks = [d for d in DATASET_REGISTRY[dataset_family]["download"] if not (download_dir / d["name"]).exists()] 125 | for dl_task in dl_tasks: 126 | if dl_task.get("hf_download", False): 127 | dl_path = download_with_progress(dl_task["url"], download_dir, hf_token=hf_token) 128 | else: 129 | dl_path = download_with_progress(dl_task["url"], download_dir) 130 | 131 | # Extract Files (if specified) 132 | if dl_task["extract"]: 133 | dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"]) 134 | 135 | # Rename Path --> dl_task["name"] 136 | if dl_task["do_rename"]: 137 | shutil.move(dl_path, download_dir / dl_task["name"]) 138 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/tasks/harnesses/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Dict, Optional, Protocol 3 | 4 | from vlm_eval.util.interfaces import VLM, ImageProcessor 5 | 6 | from .gqa import GQAScorer, GQATaskRunner 7 | from .okvqa import OKVQAScorer, OKVQATaskRunner 8 | from .ocidref import OCIDRefScorer, OCIDRefTaskRunner 9 | from .pope import PopeScorer, PopeTaskRunner 10 | from .refcoco import RefCOCOScorer, RefCOCOTaskRunner 11 | from .tallyqa import TallyQAScorer, TallyQATaskRunner 12 | from .textvqa import TextVQAScorer, TextVQATaskRunner 13 | from .vizwiz import VizWizScorer, VizWizTaskRunner 14 | from .vqav2 import VQAv2Scorer, VQAv2TaskRunner 15 | from .vsr import VSRScorer, 
VSRTaskRunner 16 | from .ai2d import AI2DScorer, AI2DTaskRunner 17 | from .mmmu import MMMUScorer, MMMUTaskRunner 18 | from .mmbench import MMBenchScorer, MMBenchTaskRunner 19 | from .mathvista import MathVistaScorer, MathVistaTaskRunner 20 | from .seedbench import SEEDBenchScorer, SEEDBenchTaskRunner 21 | from .mantis import MantisScorer, MantisTaskRunner 22 | from .mmstar import MMStarScorer, MMStarTaskRunner 23 | from .mscoco_karpathy import MSCOCOScorer, MSCOCOTaskRunner 24 | from .mmlu import MMLUScorer, MMLUTaskRunner 25 | 26 | # === Protocol Definitions === 27 | class TaskRunner(Protocol): 28 | def evaluate(self, vlm: VLM, device_batch_size: int, num_workers: int) -> None: 29 | ... 30 | 31 | 32 | class Scorer(Protocol): 33 | def score(self, model_id: str) -> Dict[str, float]: 34 | ... 35 | 36 | 37 | # === Task Runner Dispatch by Dataset Family === 38 | DATASET2RUNNER = { 39 | "vqa-v2": VQAv2TaskRunner, 40 | "gqa": GQATaskRunner, 41 | "okvqa": OKVQATaskRunner, 42 | "vizwiz": VizWizTaskRunner, 43 | "pope": PopeTaskRunner, 44 | "text-vqa": TextVQATaskRunner, 45 | "vsr": VSRTaskRunner, 46 | "tally-qa": TallyQATaskRunner, 47 | "refcoco": RefCOCOTaskRunner, 48 | "ocid-ref": OCIDRefTaskRunner, 49 | "ai2d": AI2DTaskRunner, 50 | "mmmu": MMMUTaskRunner, 51 | "mmbench": MMBenchTaskRunner, 52 | "mathvista": MathVistaTaskRunner, 53 | "seedbench": SEEDBenchTaskRunner, 54 | "mantis": MantisTaskRunner, 55 | "mmstar": MMStarTaskRunner, 56 | "mscoco_karpathy": MSCOCOTaskRunner, 57 | "mmlu": MMLUTaskRunner, 58 | } 59 | 60 | # === Score Function Dispatch by Dataset Family === 61 | DATASET2SCORER = { 62 | "vqa-v2": VQAv2Scorer, 63 | "gqa": GQAScorer, 64 | "okvqa": OKVQAScorer, 65 | "vizwiz": VizWizScorer, 66 | "pope": PopeScorer, 67 | "text-vqa": TextVQAScorer, 68 | "vsr": VSRScorer, 69 | "tally-qa": TallyQAScorer, 70 | "refcoco": RefCOCOScorer, 71 | "ocid-ref": OCIDRefScorer, 72 | "ai2d": AI2DScorer, 73 | "mmmu": MMMUScorer, 74 | "mmbench": MMBenchScorer, 75 | "mathvista": MathVistaScorer, 76 | "seedbench": SEEDBenchScorer, 77 | "mantis": MantisScorer, 78 | "mmstar": MMStarScorer, 79 | "mscoco_karpathy": MSCOCOScorer, 80 | "mmlu": MMLUScorer, 81 | } 82 | 83 | 84 | def get_task_runner( 85 | dataset_family: str, 86 | root_dir: Path, 87 | index_file: Path, 88 | task_results_dir: Path, 89 | model_id: str, 90 | prompt_fn: Callable[[str], str], 91 | image_processor: ImageProcessor, 92 | prompt_builder, 93 | ) -> TaskRunner: 94 | assert dataset_family in DATASET2RUNNER, f"Dataset Family `{dataset_family}` not supported!" 95 | return DATASET2RUNNER[dataset_family](root_dir, index_file, task_results_dir, model_id, prompt_fn, image_processor, prompt_builder) 96 | 97 | 98 | def get_scorer( 99 | dataset_family: str, 100 | dataset_id: str, 101 | task_results_dir: Path, 102 | full_results: Dict[str, Dict], 103 | annotations_file: Path, 104 | questions_file: Optional[Path] = None, 105 | split: str = "val", 106 | ) -> Scorer: 107 | assert dataset_family in DATASET2SCORER, f"Dataset Family `{dataset_family}` not supported!" 
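    # Scorer classes share a single constructor signature; `questions_file` stays None for families that do not ship a separate questions file.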
108 | return DATASET2SCORER[dataset_family]( 109 | dataset_id, task_results_dir, full_results, annotations_file, questions_file=questions_file, split=split 110 | ) 111 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/util/__init__.py -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/util/evaluation/__init__.py -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/evaluation/gqa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/util/evaluation/gqa/__init__.py -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/evaluation/mmmu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/util/evaluation/mmmu/__init__.py -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/evaluation/nocaps/metrics.py: -------------------------------------------------------------------------------- 1 | r""" 2 | This module is a collection of metrics commonly used during pretraining and 3 | downstream evaluation. Two main classes here are: 4 | 5 | - :class:`TopkAccuracy` used for ImageNet linear classification evaluation. 6 | - :class:`CocoCaptionsEvaluator` used for caption evaluation (CIDEr and SPICE). 7 | 8 | Parts of this module (:meth:`tokenize`, :meth:`cider` and :meth:`spice`) are 9 | adapted from `coco-captions evaluation code `_. 
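(This trimmed copy retains only the CIDEr helpers: to_ngrams, counts2vec, sim, and cider.)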
10 | """ 11 | from collections import defaultdict 12 | from typing import Dict, List 13 | 14 | import numpy as np 15 | 16 | 17 | # ------------------------------------------------------------------------- 18 | def to_ngrams(sentence: str, n: int = 4): 19 | r"""Convert a sentence into n-grams and their counts.""" 20 | words = sentence.split() 21 | counts = defaultdict(int) # type: ignore 22 | for k in range(1, n + 1): 23 | for i in range(len(words) - k + 1): 24 | ngram = tuple(words[i : i + k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | 29 | def counts2vec(cnts, document_frequency, log_reference_length, n: int = 4): 30 | r"""Function maps counts of ngram to vector of tfidf weights.""" 31 | vec = [defaultdict(float) for _ in range(n)] 32 | length = 0 33 | norm = [0.0 for _ in range(n)] 34 | for ngram, term_freq in cnts.items(): 35 | df = np.log(max(1.0, document_frequency[ngram])) 36 | # tf (term_freq) * idf (precomputed idf) for n-grams 37 | vec[len(ngram) - 1][ngram] = float(term_freq) * (log_reference_length - df) 38 | # Compute norm for the vector: will be used for computing similarity 39 | norm[len(ngram) - 1] += pow(vec[len(ngram) - 1][ngram], 2) 40 | 41 | if len(ngram) == 2: 42 | length += term_freq 43 | norm = [np.sqrt(nn) for nn in norm] 44 | return vec, norm, length 45 | 46 | 47 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref, n: int = 4, sigma: float = 6.0): 48 | r"""Compute the cosine similarity of two vectors.""" 49 | delta = float(length_hyp - length_ref) 50 | val = np.array([0.0 for _ in range(n)]) 51 | for nn in range(n): 52 | for ngram, _count in vec_hyp[nn].items(): 53 | val[nn] += min(vec_hyp[nn][ngram], vec_ref[nn][ngram]) * vec_ref[nn][ngram] 54 | 55 | val[nn] /= (norm_hyp[nn] * norm_ref[nn]) or 1 56 | val[nn] *= np.e ** (-(delta**2) / (2 * sigma**2)) 57 | return val 58 | 59 | 60 | # ------------------------------------------------------------------------- 61 | 62 | 63 | def cider( 64 | predictions: Dict[int, List[str]], 65 | ground_truth: Dict[int, List[str]], 66 | n: int = 4, 67 | sigma: float = 6.0, 68 | ) -> float: 69 | r"""Compute CIDEr score given ground truth captions and predictions.""" 70 | 71 | ctest = [to_ngrams(predictions[image_id][0]) for image_id in ground_truth] 72 | crefs = [[to_ngrams(gt) for gt in ground_truth[image_id]] for image_id in ground_truth] 73 | # Build document frequency and compute IDF. 74 | document_frequency = defaultdict(float) 75 | for refs in crefs: 76 | # refs, k ref captions of one image 77 | for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]): 78 | document_frequency[ngram] += 1 79 | 80 | # Compute log reference length. 81 | log_reference_length = np.log(float(len(crefs))) 82 | 83 | scores = [] 84 | for test, refs in zip(ctest, crefs): 85 | # Compute vector for test captions. 86 | vec, norm, length = counts2vec(test, document_frequency, log_reference_length, n) 87 | # Compute vector for ref captions. 
88 | score = np.array([0.0 for _ in range(n)]) 89 | for ref in refs: 90 | vec_ref, norm_ref, length_ref = counts2vec(ref, document_frequency, log_reference_length, n) 91 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref, n, sigma) 92 | 93 | score_avg = np.mean(score) 94 | score_avg /= len(refs) 95 | score_avg *= 10.0 96 | scores.append(score_avg) 97 | 98 | return np.mean(scores) 99 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/evaluation/textvqa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/util/evaluation/textvqa/__init__.py -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/evaluation/vizwiz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/util/evaluation/vizwiz/__init__.py -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/evaluation/vqav2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/util/evaluation/vqav2/__init__.py -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/interfaces.py: -------------------------------------------------------------------------------- 1 | """ 2 | interfaces.py 3 | 4 | Protocol/type definitions for the common parts of the VLM training & inference pipelines, from base processors 5 | (e.g., ImageProcessors) to complete vision-language models (VLMs). 6 | """ 7 | from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | from PIL.Image import Image 12 | from transformers.tokenization_utils import BatchEncoding 13 | 14 | 15 | # === Processor & Tokenizer Interface Definitions === 16 | class Tokenizer(Protocol): 17 | padding_side: str 18 | pad_token_id: int 19 | 20 | def __call__(self, text: Union[str, Sequence[str]], return_tensors: str = "pt", **kwargs) -> BatchEncoding: 21 | ... 22 | 23 | def encode(self, inputs: str, add_special_tokens: bool = False) -> List[int]: 24 | ... 25 | 26 | def decode(self, output_ids: Union[torch.Tensor, Sequence[int]], **kwargs) -> str: 27 | ... 28 | 29 | def batch_decode(self, output_ids: Union[torch.Tensor, Sequence[Sequence[int]]], **kwargs) -> List[str]: 30 | ... 31 | 32 | 33 | class ImageProcessor(Protocol): 34 | def __call__(self, img: Image, **kwargs) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: 35 | ... 36 | 37 | 38 | # === General VLM Inference Interface === 39 | class VLM(Protocol): 40 | image_processor: ImageProcessor 41 | 42 | def load(self) -> Tuple[nn.Module, Tokenizer, ImageProcessor]: 43 | ... 44 | 45 | def get_prompt_builder(self, system_prompt: Optional[str] = None, use_qa_builder: Optional[bool] = True) -> Any: 46 | ... 47 | 48 | def get_prompt_fn(self, dataset_family: str = "vqa-v2") -> Callable[[str], str]: 49 | ... 
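    # The two generation entry points below differ in scope: `generate_answer` runs batched inference (pixel tensors and prompt strings in, answer strings or per-choice probabilities out), while `generate` decodes a single PIL image with explicit sampling controls.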
50 | 51 | def generate_answer( 52 | self, 53 | pixel_values: torch.Tensor, 54 | question_prompts: List[str], 55 | return_string_probabilities: Optional[List[str]] = None, 56 | ) -> Union[List[str], List[List[float]]]: 57 | ... 58 | 59 | def generate( 60 | self, 61 | image: Image, 62 | input_text: str, 63 | do_sample: bool, 64 | temperature: float, 65 | max_new_tokens: int, 66 | min_length: int, 67 | length_penalty: float, 68 | **kwargs, 69 | ) -> str: 70 | ... 71 | -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/loading/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Victorwz/Open-Qwen2VL/f7a2ebc649086cc254a135fab3d2d4adcd680add/vlm-evaluation/vlm_eval/util/loading/__init__.py -------------------------------------------------------------------------------- /vlm-evaluation/vlm_eval/util/preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | preprocessing.py 3 | 4 | Various preprocessing utilities for handling/cleaning text (questions, punctuation), paths & temporary files, and 5 | general functions for better quality-of-life. 6 | """ 7 | import re 8 | 9 | 10 | # Lifted from LAVIS (`lavis.processors.blip_processors.py :: BLIPQuestionProcessor.pre_question`) 11 | def process_question(question: str, max_words: int = 128) -> str: 12 | question = re.sub(r"([.!\"()*#:;~])", "", question.lower()) 13 | question = question.rstrip(" ") 14 | 15 | # Truncate Question 16 | question_words = question.split(" ") 17 | if len(question_words) > max_words: 18 | question = " ".join(question_words[:max_words]) 19 | 20 | return question 21 | --------------------------------------------------------------------------------
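A minimal usage sketch for the preprocessing helper above (the example string is hypothetical, not taken from the repository's tests):

from vlm_eval.util.preprocessing import process_question

# Lower-cases, strips the punctuation listed in the regex (.!"()*#:;~) while leaving commas
# and question marks untouched, and truncates to `max_words` whitespace-delimited tokens.
print(process_question('What *color* is the bus, exactly?!'))
# -> "what color is the bus, exactly?"

Similarly, cider(predictions, ground_truth) in the metrics module expects dicts keyed by image id mapping to lists of caption strings; note that with a single image the tf-idf weights collapse to zero (the reference length term is log(1)), so the score is only meaningful over a full reference corpus.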