├── .gitignore ├── LICENSE ├── README.md ├── assets └── pipeline.png ├── demo.py ├── docs ├── .DS_Store ├── index.html ├── sitemap.xml └── static │ ├── .DS_Store │ ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css │ ├── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ └── index.js │ └── proxyv │ ├── mask_ratio_revised.png │ ├── non_spatial_proxyv.png │ ├── proxyv.png │ ├── results.png │ ├── table1.png │ ├── table2.png │ └── table3.png ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_phi3.py │ │ └── llava_qwen.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── dev_eva_clip │ │ │ ├── eva_clip │ │ │ │ ├── __init__.py │ │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ │ ├── constants.py │ │ │ │ ├── eva_vit_model.py │ │ │ │ ├── factory.py │ │ │ │ ├── hf_configs.py │ │ │ │ ├── hf_model.py │ │ │ │ ├── loss.py │ │ │ │ ├── model.py │ │ │ │ ├── model_configs │ │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ │ │ ├── modified_resnet.py │ │ │ │ ├── openai.py │ │ │ │ ├── pretrained.py │ │ │ │ ├── rope.py │ │ │ │ ├── timm_model.py │ │ │ │ ├── tokenizer.py │ │ │ │ ├── transform.py │ │ │ │ ├── transformer.py │ │ │ │ └── utils.py │ │ │ └── eva_vit.py │ │ ├── eva_clip │ │ │ ├── eva_clip_encoder.py │ │ │ ├── eva_clip_processors.py │ │ │ ├── eva_vit.py │ │ │ ├── factory.py │ │ │ └── model_configs │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ ├── hf_vision.py │ │ ├── imagebind.py │ │ ├── open_clip_encoder.py │ │ └── siglip_encoder.py │ ├── multimodal_projector │ │ ├── builder.py │ │ └── pooler_projector.py │ ├── multimodal_resampler │ │ ├── builder.py │ │ ├── masked_drop.py │ │ ├── perceiver.py │ │ ├── qformer.py │ │ └── spatial_pool.py │ ├── utils.py │ └── vision_mlp.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── llava_trainer_eval.py │ ├── train.py │ └── train_mem.py └── utils.py ├── pyproject.toml └── scripts ├── finetune ├── finetune_vicuna7b_baseline.sh └── finetune_vicuna7b_proxyv.sh ├── pretrain ├── pretrain_llama8b_baseline.sh ├── pretrain_llama8b_proxyv.sh ├── pretrain_phi3b_baseline.sh ├── pretrain_phi3b_proxyv.sh ├── pretrain_qwen7b_baseline.sh ├── pretrain_qwen7b_proxyv.sh ├── pretrain_vicuna13b_baselin.sh ├── pretrain_vicuna13b_proxyv.sh ├── pretrain_vicuna7b_baseline.sh └── 
pretrain_vicuna7b_proxyv.sh ├── zero2.json ├── zero2_fused_adamw.json ├── zero2_offload.json ├── zero3.json ├── zero3_offload.json └── zero3pp.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProxyV: Streamline Without Sacrifice - Squeeze out Computation Redundancy in LMM 2 | 3 | [![Static Badge](https://img.shields.io/badge/proxyv-paper-red)](https://arxiv.org/abs/2505.15816) 4 | [![Static Badge](https://img.shields.io/badge/proxyv-project_page-green)](https://penghao-wu.github.io/ProxyV/) 5 | [![Static Badge](https://img.shields.io/badge/proxyv-model-blue)](https://huggingface.co/craigwu/proxyv_vicuna_7b_layer12) 6 | 7 | ![pipeline](assets/pipeline.png) 8 | 9 | ## Contents: 10 | 1. [Getting Started](#start) 11 | 2. [Image Encoding Scheme](#encoding) 12 | 3. [Training](#training) 13 | 4. [Evaluation](#evaluation) 14 | 5. [License](#license) 15 | 6. [Citation](#citation) 16 | 7. [Acknowledgement](#acknowledgement) 17 | 18 | ## Getting Started 19 | 20 | ### Installation 21 | ``` 22 | conda create -n proxyv python=3.10 -y 23 | conda activate proxyv 24 | pip install --upgrade pip # Enable PEP 660 support. 25 | pip install -e ".[train]" 26 | ``` 27 | 28 | ### Training Dataset 29 | For the pre-training stage, we use the 1.2M ShareGPT4V data, which can be downloaded at this [link](https://huggingface.co/datasets/Lin-Chen/ShareGPT4V). 30 | For the fine-tuning stage, we use the public LLaVA-NeXT data, which can be downloaded at this [link](https://huggingface.co/datasets/lmms-lab/LLaVA-NeXT-Data). 31 | 32 | ## Image Encoding Scheme 33 | In our current implementation, we adopt the [AnyRes](https://github.com/LLaVA-VL/LLaVA-NeXT) strategy. The image features within each crop are flattened in raster order and concatenated crop by crop, similar to the [UniRes](https://github.com/EvolvingLMMs-Lab/LongVA) strategy. We also append a newline separator token after each crop. 34 | To process the vision tokens more conveniently, we pack tokens in the **\[vision tokens; proxy tokens; newline separator tokens; text tokens\]** order, and modify the position_ids and attention_masks accordingly to preserve their original relative order.
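The sketch below makes this packing concrete for a single sample. It is only a minimal illustration of the idea: the function name `pack_for_proxyv`, the tensor shapes, and the assumed placement of the proxy tokens in the original order are hypothetical and do not reproduce the exact implementation in this repo.
```python
# Illustrative sketch only; names, shapes, and the assumed original ordering are
# hypothetical and not taken from this repo's implementation.
import torch

def pack_for_proxyv(crop_feats, proxy_embeds, newline_embed, text_embeds):
    """crop_feats: list of (n_i, d) tensors, one per AnyRes crop, already flattened
    in raster order; proxy_embeds: (p, d); newline_embed: (d,); text_embeds: (t, d).

    Regroups the embeddings as [vision; proxy; newline; text] while assigning every
    token the position id it holds in the assumed original order
    (crop_0, newline, crop_1, newline, ..., proxy tokens, text tokens)."""
    vision, newlines, vis_pos, nl_pos, pos = [], [], [], [], 0
    for feat in crop_feats:
        vision.append(feat)                                     # vision tokens of this crop
        vis_pos.append(torch.arange(pos, pos + feat.shape[0]))
        pos += feat.shape[0]
        newlines.append(newline_embed[None, :])                 # one separator token per crop
        nl_pos.append(torch.tensor([pos]))
        pos += 1
    proxy_pos = torch.arange(pos, pos + proxy_embeds.shape[0])
    pos += proxy_embeds.shape[0]
    text_pos = torch.arange(pos, pos + text_embeds.shape[0])

    # Physically regroup the tokens, but keep the original position ids so that
    # RoPE (together with a correspondingly adjusted attention mask) still sees
    # the original relative order.
    packed = torch.cat([torch.cat(vision), proxy_embeds, torch.cat(newlines), text_embeds])
    position_ids = torch.cat([torch.cat(vis_pos), proxy_pos, torch.cat(nl_pos), text_pos])
    return packed, position_ids
```
The attention mask has to be permuted consistently with the same reordering; that bookkeeping is omitted from the sketch.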
35 | 36 | ## Training 37 | The pre-training scripts can be found within the ``scripts/pretrain`` folder, and fine-tuning example scripts are provided under the ``scripts/finetune`` folder. 38 | To enable ProxyV, set ``--proxyv`` to ``true`` in the script and set ``--proxyv_start_layer`` to the desired layer index. 39 | 40 | ## Evaluation 41 | The Vicuna-1.5-7B ProxyV layer-12 model studied in the paper is provided at this [link](https://huggingface.co/craigwu/proxyv_vicuna_7b_layer12). 42 | A simple inference example script is provided at ``demo.py``. 43 | All benchmark evaluations can be directly conducted using [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) with `--model` set to `llava`. 44 | 45 | ## License 46 | 47 | This project is under the Apache-2.0 license. See [LICENSE](LICENSE) for details. 48 | 49 | ## Citation 50 | Please consider citing our paper if you find this project helpful for your research: 51 | 52 | ```bibtex 53 | @article{ProxyV, 54 | author = {Wu, Penghao and Lu, Lewei and Liu, Ziwei}, 55 | title = {Streamline Without Sacrifice - Squeeze out Computation Redundancy in LMM}, 56 | journal={arXiv preprint arXiv:2505.15816}, 57 | year={2025}} 58 | ``` 59 | 60 | ## Acknowledgement 61 | - This work is built upon [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT). 62 | 63 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/assets/pipeline.png -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | 4 | 5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 6 | from llava.conversation import conv_templates 7 | from llava.model.builder import load_pretrained_model 8 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 9 | 10 | 11 | 12 | def process(image, question, tokenizer, image_processor, model_config): 13 | qs = question 14 | if model_config.mm_use_im_start_end: 15 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 16 | else: 17 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 18 | 19 | conv = conv_templates[conv_mode].copy() 20 | conv.append_message(conv.roles[0], qs) 21 | conv.append_message(conv.roles[1], None) 22 | prompt = conv.get_prompt() 23 | print(prompt) 24 | 25 | image_size = [image.size] 26 | image_tensor = process_images([image], image_processor, model_config).to(torch.float16) 27 | # image_tensor = process_images([image], image_processor, model_config).to(torch.bfloat16) 28 | 29 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 30 | 31 | return input_ids, image_tensor, image_size, prompt 32 | 33 | 34 | conv_mode = "v1" 35 | temperature = 0.0 36 | model_path = "craigwu/proxyv_vicuna_7b_layer12" 37 | model_name = get_model_name_from_path(model_path) 38 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, multimodal=True,attn_implementation='sdpa') 39 | 40 | while True: 41 | image_path = input("image path: ") 42 | image = Image.open(image_path).convert('RGB') 43 | question = input("question: ") 44 | 45 | input_ids, image_tensor, image_sizes,
prompt = process(image, question, tokenizer, image_processor, model.config) 46 | input_ids = input_ids.to(device='cuda', non_blocking=True) 47 | 48 | with torch.inference_mode(): 49 | output_ids = model.generate( 50 | input_ids, 51 | images=image_tensor, 52 | image_sizes=image_sizes, 53 | do_sample=True if temperature > 0 else False, 54 | temperature=temperature, 55 | num_beams=1, 56 | max_new_tokens=512, 57 | use_cache=True) 58 | 59 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 60 | 61 | print(outputs) -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/.DS_Store -------------------------------------------------------------------------------- /docs/sitemap.xml: -------------------------------------------------------------------------------- 1 | 2 | 7 | 8 | 9 | 10 | 11 | https://amap-ml.github.io/UniVG-R1-page/ 12 | 2025-05-20T08:20:38+00:00 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /docs/static/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/.DS_Store -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform 
.3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | .content-flex-row { 5 | display: flex; 6 | align-items: flex-start; 7 | justify-content: center; 8 | gap: 32px; /* 左右间距可调整 */ 9 | margin-bottom: 2rem; 10 | } 11 | 12 | .footer .icon-link { 13 | font-size: 25px; 14 | color: #000; 15 | } 16 | 17 | .link-block a { 18 | margin-top: 5px; 19 | margin-bottom: 5px; 20 | } 21 | 22 | .dnerf { 23 | font-variant: small-caps; 24 | } 25 | 26 | 27 | .teaser .hero-body { 28 | padding-top: 0; 29 | padding-bottom: 3rem; 30 | } 31 | 32 | .teaser { 33 | font-family: 'Google Sans', sans-serif; 34 | } 35 | 36 | 37 | .publication-title { 38 | } 39 | 40 | .publication-banner { 41 | max-height: parent; 42 | 43 | } 44 | 45 | .publication-banner video { 46 | position: relative; 47 | left: auto; 48 | top: auto; 49 | transform: none; 50 | object-fit: fit; 51 | } 52 | 53 | .publication-header .hero-body { 54 | } 55 | 56 | .publication-title { 57 | font-family: 'Google Sans', sans-serif; 58 | } 59 | 60 | .publication-authors { 61 | font-family: 'Google Sans', sans-serif; 62 | } 63 | 64 | .publication-venue { 65 | color: #555; 66 | width: fit-content; 67 | font-weight: bold; 68 | } 69 | 70 | .publication-awards { 71 | color: #ff3860; 72 | width: fit-content; 73 | font-weight: bolder; 74 | } 75 | 76 | .publication-authors { 77 | } 78 | 79 | .publication-authors a { 80 | color: hsl(204, 86%, 53%) !important; 81 | } 82 | 83 | .publication-authors a:hover { 84 | text-decoration: underline; 85 | } 86 | 87 | .author-block { 88 | display: inline-block; 89 | } 90 | 91 | .publication-banner 
img { 92 | } 93 | 94 | .publication-authors { 95 | /*color: #4286f4;*/ 96 | } 97 | 98 | .publication-video { 99 | position: relative; 100 | width: 100%; 101 | height: 0; 102 | padding-bottom: 56.25%; 103 | 104 | overflow: hidden; 105 | border-radius: 10px !important; 106 | } 107 | 108 | .publication-video iframe { 109 | position: absolute; 110 | top: 0; 111 | left: 0; 112 | width: 100%; 113 | height: 100%; 114 | } 115 | 116 | .publication-body img { 117 | } 118 | 119 | .results-carousel { 120 | overflow: hidden; 121 | } 122 | 123 | .results-carousel .item { 124 | margin: 5px; 125 | overflow: hidden; 126 | border: 1px solid #bbb; 127 | border-radius: 10px; 128 | padding: 0; 129 | font-size: 0; 130 | } 131 | 132 | .results-carousel video { 133 | margin: 0; 134 | } 135 | 136 | 137 | .interpolation-panel { 138 | background: #f5f5f5; 139 | border-radius: 10px; 140 | } 141 | 142 | .interpolation-panel .interpolation-image { 143 | width: 100%; 144 | border-radius: 5px; 145 | } 146 | 147 | .interpolation-video-column { 148 | } 149 | 150 | .interpolation-panel .slider { 151 | margin: 0 !important; 152 | } 153 | 154 | .interpolation-panel .slider { 155 | margin: 0 !important; 156 | } 157 | 158 | #interpolation-image-wrapper { 159 | width: 100%; 160 | } 161 | #interpolation-image-wrapper img { 162 | border-radius: 5px; 163 | } 164 | -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 
25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) 79 | -------------------------------------------------------------------------------- /docs/static/proxyv/mask_ratio_revised.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/mask_ratio_revised.png -------------------------------------------------------------------------------- /docs/static/proxyv/non_spatial_proxyv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/non_spatial_proxyv.png -------------------------------------------------------------------------------- /docs/static/proxyv/proxyv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/proxyv.png -------------------------------------------------------------------------------- /docs/static/proxyv/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/results.png -------------------------------------------------------------------------------- /docs/static/proxyv/table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/table1.png -------------------------------------------------------------------------------- /docs/static/proxyv/table2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/table2.png -------------------------------------------------------------------------------- /docs/static/proxyv/table3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/table3.png -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | AVAILABLE_MODELS = { 4 | "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig", 5 | "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig", 6 | "llava_phi3": "LlavaPhiForCausalLM, LlavaPhiConfig", 7 | # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig", 8 | # Add other models as needed 9 | } 10 | 11 | for model_name, model_classes in AVAILABLE_MODELS.items(): 12 | exec(f"from .language_model.{model_name} import {model_classes}") 13 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava import LlavaLlamaForCausalLM 12 | 13 | 14 | def apply_delta(base_model_path, target_model_path, delta_path): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 33 | 34 | print("Saving target model") 35 | delta.save_pretrained(target_model_path) 36 |
delta_tokenizer.save_pretrained(target_model_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--base-model-path", type=str, required=True) 42 | parser.add_argument("--target-model-path", type=str, required=True) 43 | parser.add_argument("--delta-path", type=str, required=True) 44 | 45 | args = parser.parse_args() 46 | 47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 48 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model import * 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def consolidate_ckpt(src_path, dst_path): 15 | print("Loading model") 16 | auto_upgrade(src_path) 17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 19 | src_model.save_pretrained(dst_path) 20 | src_tokenizer.save_pretrained(dst_path) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--src", type=str, required=True) 26 | parser.add_argument("--dst", type=str, required=True) 27 | 28 | args = parser.parse_args() 29 | 30 | consolidate_ckpt(args.src, args.dst) 31 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | 
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .imagebind import ImageBindWrapper 4 | from .open_clip_encoder import OpenCLIPVisionTower 5 | from .hf_vision import HFVisionTower 6 | from .siglip_encoder import SigLipVisionTower 7 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 8 | 9 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower 10 | # from .dev_eva_clip.eva_vit import EvaViTWrapper 11 | 12 | 13 | def build_vision_tower(vision_tower_cfg, **kwargs): 14 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) 15 | is_absolute_path_exists = os.path.exists(vision_tower) 16 | use_s2 = getattr(vision_tower_cfg, "s2", False) 17 | if 'clip' in vision_tower or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 18 | if use_s2: 19 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 20 | else: 21 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 22 | elif "siglip" in vision_tower: 23 | return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) 24 | elif vision_tower.startswith("hf:"): 25 | return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 26 | elif vision_tower in ["imagebind_huge"]: 27 | return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 28 | elif vision_tower.startswith("open_clip_hub"): 29 | return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 30 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): 31 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 32 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: 33 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 34 | 35 | raise ValueError(f"Unknown vision tower: {vision_tower}") 36 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from llava.utils import rank0_print 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | try: 7 | from s2wrapper import forward as multiscale_forward 8 | except: 9 | pass 10 | 11 | 12 | class CLIPVisionTower(nn.Module): 13 | def __init__(self, vision_tower, args, delay_load=False): 14 | super().__init__() 15 | 16 | self.is_loaded = False 17 | 18 | self.vision_tower_name = vision_tower 19 | self.select_layer = args.mm_vision_select_layer 20 | self.select_feature = getattr(args, 
"mm_vision_select_feature", "patch") 21 | 22 | if not delay_load: 23 | rank0_print(f"Loading vision tower: {vision_tower}") 24 | self.load_model() 25 | elif getattr(args, "unfreeze_mm_vision_tower", False): 26 | # TODO: better detector is needed. 27 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 28 | self.load_model() 29 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 30 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 31 | self.load_model() 32 | else: 33 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 34 | 35 | def load_model(self, device_map=None): 36 | if self.is_loaded: 37 | rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) 38 | return 39 | 40 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 41 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 42 | self.vision_tower.requires_grad_(False) 43 | 44 | self.is_loaded = True 45 | 46 | def feature_select(self, image_forward_outs): 47 | select_feature_type = self.select_feature 48 | 49 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 50 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4 51 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) 52 | select_feature_type = select_feature_type.replace("slicefour_", "") 53 | elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: 54 | select_layers = [-2, -5, -8, -11, 6] 55 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1) 56 | select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") 57 | else: 58 | image_features = image_forward_outs.hidden_states[self.select_layer] 59 | 60 | if select_feature_type == "patch": 61 | image_features = image_features[:, 1:] 62 | elif select_feature_type == "cls_patch": 63 | image_features = image_features 64 | else: 65 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 66 | return image_features 67 | 68 | def forward(self, images): 69 | if type(images) is list: 70 | image_features = [] 71 | for image in images: 72 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 73 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 74 | image_features.append(image_feature) 75 | else: 76 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 77 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 78 | 79 | return image_features 80 | 81 | @property 82 | def dummy_feature(self): 83 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 84 | 85 | @property 86 | def dtype(self): 87 | return self.vision_tower.dtype 88 | 89 | @property 90 | def device(self): 91 | return self.vision_tower.device 92 | 93 | @property 94 | def config(self): 95 | if self.is_loaded: 96 | return self.vision_tower.config 97 | else: 98 | return self.cfg_only 99 | 100 | @property 101 | def hidden_size(self): 102 | _hidden_size = self.config.hidden_size 103 | if "slicefour" in 
self.select_feature: 104 | _hidden_size *= 4 105 | if "slice_m25811_f6" in self.select_feature: 106 | _hidden_size *= 5 107 | return _hidden_size 108 | 109 | @property 110 | def num_patches_per_side(self): 111 | return self.config.image_size // self.config.patch_size 112 | 113 | @property 114 | def num_patches(self): 115 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2 116 | if "cls_patch" in self.select_feature: 117 | _num_patches += 1 118 | return _num_patches 119 | 120 | @property 121 | def image_size(self): 122 | return self.config.image_size 123 | 124 | 125 | class CLIPVisionTowerS2(CLIPVisionTower): 126 | def __init__(self, vision_tower, args, delay_load=False): 127 | 128 | self.s2_scales = getattr(args, "s2_scales", "336,672,1008") 129 | self.s2_scales = list(map(int, self.s2_scales.split(","))) 130 | self.s2_scales.sort() 131 | self.s2_split_size = self.s2_scales[0] 132 | self.s2_image_size = self.s2_scales[-1] 133 | 134 | super().__init__(vision_tower, args, delay_load) 135 | 136 | # change resize/crop size in preprocessing to the largest image size in s2_scale 137 | if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False): 138 | self.image_processor.size["shortest_edge"] = self.s2_image_size 139 | self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size 140 | 141 | def load_model(self, device_map=None): 142 | if self.is_loaded: 143 | rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) 144 | return 145 | 146 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 147 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 148 | self.vision_tower.requires_grad_(False) 149 | 150 | self.image_processor.size["shortest_edge"] = self.s2_image_size 151 | self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size 152 | 153 | self.is_loaded = True 154 | 155 | def forward_feature(self, images): 156 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 157 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 158 | return image_features 159 | 160 | def forward(self, images): 161 | if type(images) is list: 162 | image_features = [] 163 | for image in images: 164 | image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) 165 | image_features.append(image_feature) 166 | else: 167 | image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) 168 | 169 | return image_features 170 | 171 | @property 172 | def hidden_size(self): 173 | return self.config.hidden_size * len(self.s2_scales) 174 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, 
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 6 | from .openai import load_openai_model, list_openai_models 7 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 8 | from .tokenizer import SimpleTokenizer, tokenize 9 | from .transform import image_transform 10 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings", 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings", 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens", 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings", 54 | }, 55 | "pooler": "mean_pooler", 56 | }, 57 | } 58 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | 
import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import torch.distributed.nn 8 | from torch import distributed as dist 9 | 10 | has_distributed = True 11 | except ImportError: 12 | has_distributed = False 13 | 14 | try: 15 | import horovod.torch as hvd 16 | except ImportError: 17 | hvd = None 18 | 19 | from timm.loss import LabelSmoothingCrossEntropy 20 | 21 | 22 | def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False): 23 | assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support." 24 | if use_horovod: 25 | assert hvd is not None, "Please install horovod" 26 | if gather_with_grad: 27 | all_image_features = hvd.allgather(image_features) 28 | all_text_features = hvd.allgather(text_features) 29 | else: 30 | with torch.no_grad(): 31 | all_image_features = hvd.allgather(image_features) 32 | all_text_features = hvd.allgather(text_features) 33 | if not local_loss: 34 | # ensure grads for local rank when all_* features don't have a gradient 35 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 36 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 37 | gathered_image_features[rank] = image_features 38 | gathered_text_features[rank] = text_features 39 | all_image_features = torch.cat(gathered_image_features, dim=0) 40 | all_text_features = torch.cat(gathered_text_features, dim=0) 41 | else: 42 | # We gather tensors from all gpus 43 | if gather_with_grad: 44 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 45 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 46 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 47 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 48 | else: 49 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 50 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 51 | dist.all_gather(gathered_image_features, image_features) 52 | dist.all_gather(gathered_text_features, text_features) 53 | if not local_loss: 54 | # ensure grads for local rank when all_* features don't have a gradient 55 | gathered_image_features[rank] = image_features 56 | gathered_text_features[rank] = text_features 57 | all_image_features = torch.cat(gathered_image_features, dim=0) 58 | all_text_features = torch.cat(gathered_text_features, dim=0) 59 | 60 | return all_image_features, all_text_features 61 | 62 | 63 | class ClipLoss(nn.Module): 64 | 65 | def __init__( 66 | self, 67 | local_loss=False, 68 | gather_with_grad=False, 69 | cache_labels=False, 70 | rank=0, 71 | world_size=1, 72 | use_horovod=False, 73 | smoothing=0.0, 74 | ): 75 | super().__init__() 76 | self.local_loss = local_loss 77 | self.gather_with_grad = gather_with_grad 78 | self.cache_labels = cache_labels 79 | self.rank = rank 80 | self.world_size = world_size 81 | self.use_horovod = use_horovod 82 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 83 | 84 | # cache state 85 | self.prev_num_logits = 0 86 | self.labels = {} 87 | 88 | def forward(self, image_features, text_features, logit_scale=1.0): 89 | device = image_features.device 90 | if self.world_size > 1: 91 | all_image_features, 
all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 92 | 93 | if self.local_loss: 94 | logits_per_image = logit_scale * image_features @ all_text_features.T 95 | logits_per_text = logit_scale * text_features @ all_image_features.T 96 | else: 97 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 98 | logits_per_text = logits_per_image.T 99 | else: 100 | logits_per_image = logit_scale * image_features @ text_features.T 101 | logits_per_text = logit_scale * text_features @ image_features.T 102 | # calculated ground-truth and cache if enabled 103 | num_logits = logits_per_image.shape[0] 104 | if self.prev_num_logits != num_logits or device not in self.labels: 105 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 106 | if self.world_size > 1 and self.local_loss: 107 | labels = labels + num_logits * self.rank 108 | if self.cache_labels: 109 | self.labels[device] = labels 110 | self.prev_num_logits = num_logits 111 | else: 112 | labels = self.labels[device] 113 | 114 | if self.label_smoothing_cross_entropy: 115 | total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2 116 | else: 117 | total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 118 | 119 | acc = None 120 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 121 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 122 | acc = {"i2t": i2t_acc, "t2i": t2i_acc} 123 | return total_loss, acc 124 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": 
true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 
1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))])) 37 | 38 | def forward(self, x: torch.Tensor): 39 | identity = x 40 | 41 | out = self.act1(self.bn1(self.conv1(x))) 42 | out = self.act2(self.bn2(self.conv2(out))) 43 | out = self.avgpool(out) 44 | out = self.bn3(self.conv3(out)) 45 | 46 | if self.downsample is not None: 47 | identity = self.downsample(x) 48 | 49 | out += identity 50 | out = self.act3(out) 51 | return out 52 | 53 | 54 | class AttentionPool2d(nn.Module): 55 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 56 | super().__init__() 57 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) 58 | self.k_proj = nn.Linear(embed_dim, embed_dim) 59 | self.q_proj = nn.Linear(embed_dim, embed_dim) 60 | self.v_proj = nn.Linear(embed_dim, embed_dim) 61 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 62 | self.num_heads = num_heads 63 | 64 | def forward(self, x): 65 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 66 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 67 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 68 | x, _ = F.multi_head_attention_forward( 69 | query=x, 70 | key=x, 71 | value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0.0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False, 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.image_size = image_size 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.act1 = nn.ReLU(inplace=True) 110 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 111 | self.bn2 = nn.BatchNorm2d(width // 2) 112 | self.act2 = nn.ReLU(inplace=True) 113 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 114 | self.bn3 = nn.BatchNorm2d(width) 115 | self.act3 = nn.ReLU(inplace=True) 116 | self.avgpool = nn.AvgPool2d(2) 117 | 118 | # residual layers 119 | self._inplanes = width # this is a *mutable* variable used during construction 120 | self.layer1 = self._make_layer(width, layers[0]) 121 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 122 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 123 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 124 | 125 | embed_dim = width * 32 # the ResNet feature dimension 126 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 127 | 128 | self.init_parameters() 129 | 130 | def _make_layer(self, planes, blocks, stride=1): 131 | layers = [Bottleneck(self._inplanes, planes, stride)] 132 | 133 | self._inplanes = planes * Bottleneck.expansion 134 | for _ in range(1, blocks): 135 | layers.append(Bottleneck(self._inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def init_parameters(self): 140 | if self.attnpool is not None: 141 | std = self.attnpool.c_proj.in_features**-0.5 142 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 143 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 144 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 146 | 147 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 148 | for name, param in resnet_block.named_parameters(): 149 | if name.endswith("bn3.weight"): 150 | nn.init.zeros_(param) 151 | 152 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 153 | assert unlocked_groups == 0, "partial locking not currently supported for this model" 154 | for param in self.parameters(): 155 | param.requires_grad = False 156 | if freeze_bn_stats: 157 | freeze_batch_norm_2d(self) 158 | 159 | @torch.jit.ignore 160 | def set_grad_checkpointing(self, enable=True): 161 | # FIXME support for non-transformer 162 | pass 163 | 164 | def stem(self, x): 165 | x = self.act1(self.bn1(self.conv1(x))) 166 | x = self.act2(self.bn2(self.conv2(x))) 167 | x = self.act3(self.bn3(self.conv3(x))) 168 | x = self.avgpool(x) 169 | return x 170 | 171 | def forward(self, x): 172 | x = self.stem(x) 173 | x = self.layer1(x) 174 | x = self.layer2(x) 175 | x = self.layer3(x) 176 | x = self.layer4(x) 177 | x = self.attnpool(x) 178 | 179 | return x 180 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from 
https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag("openai") 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = "fp32" if device == "cpu" else "fp16" 56 | 57 | if get_pretrained_url(name, "openai"): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith("amp") or precision == "fp32": 87 | model.float() 88 | elif precision == "bf16": 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == "fp32": 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | 8 | def broadcat(tensors, dim=-1): 9 | num_tensors = len(tensors) 10 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 11 | assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" 12 | shape_len = list(shape_lens)[0] 13 | dim = (dim + shape_len) if dim < 0 else dim 14 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 15 | 
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 16 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation" 17 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 18 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 19 | expanded_dims.insert(dim, (dim, dims[dim])) 20 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 21 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 22 | return torch.cat(tensors, dim=dim) 23 | 24 | 25 | def rotate_half(x): 26 | x = rearrange(x, "... (d r) -> ... d r", r=2) 27 | x1, x2 = x.unbind(dim=-1) 28 | x = torch.stack((-x2, x1), dim=-1) 29 | return rearrange(x, "... d r -> ... (d r)") 30 | 31 | 32 | class VisionRotaryEmbedding(nn.Module): 33 | def __init__( 34 | self, 35 | dim, 36 | pt_seq_len, 37 | ft_seq_len=None, 38 | custom_freqs=None, 39 | freqs_for="lang", 40 | theta=10000, 41 | max_freq=10, 42 | num_freqs=1, 43 | ): 44 | super().__init__() 45 | if custom_freqs: 46 | freqs = custom_freqs 47 | elif freqs_for == "lang": 48 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 49 | elif freqs_for == "pixel": 50 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 51 | elif freqs_for == "constant": 52 | freqs = torch.ones(num_freqs).float() 53 | else: 54 | raise ValueError(f"unknown modality {freqs_for}") 55 | 56 | if ft_seq_len is None: 57 | ft_seq_len = pt_seq_len 58 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 59 | 60 | freqs_h = torch.einsum("..., f -> ... f", t, freqs) 61 | freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) 62 | 63 | freqs_w = torch.einsum("..., f -> ... f", t, freqs) 64 | freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) 65 | 66 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) 67 | 68 | self.register_buffer("freqs_cos", freqs.cos()) 69 | self.register_buffer("freqs_sin", freqs.sin()) 70 | 71 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 72 | 73 | def forward(self, t, start_index=0): 74 | rot_dim = self.freqs_cos.shape[-1] 75 | end_index = start_index + rot_dim 76 | assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" 77 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 78 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 79 | 80 | return torch.cat((t_left, t, t_right), dim=-1) 81 | 82 | 83 | class VisionRotaryEmbeddingFast(nn.Module): 84 | def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0): 85 | super().__init__() 86 | if custom_freqs: 87 | freqs = custom_freqs 88 | elif freqs_for == "lang": 89 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 90 | elif freqs_for == "pixel": 91 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 92 | elif freqs_for == "constant": 93 | freqs = torch.ones(num_freqs).float() 94 | else: 95 | raise ValueError(f"unknown modality {freqs_for}") 96 | 97 | if ft_seq_len is None: 98 | ft_seq_len = pt_seq_len 99 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 100 | 101 | freqs = torch.einsum("..., f -> ... f", t, freqs) 102 | freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) 103 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) 104 | 105 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 106 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 107 | 108 | self.patch_dropout = patch_dropout 109 | 110 | self.register_buffer("freqs_cos", freqs_cos) 111 | self.register_buffer("freqs_sin", freqs_sin) 112 | 113 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 114 | 115 | def forward(self, t, patch_indices_keep=None): 116 | if patch_indices_keep is not None: 117 | batch = t.size()[0] 118 | batch_indices = torch.arange(batch) 119 | batch_indices = batch_indices[..., None] 120 | 121 | freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 122 | freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 123 | 124 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 125 | freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j") 126 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 127 | freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j") 128 | 129 | return t * freqs_cos + rotate_half(t) * freqs_sin 130 | 131 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin 132 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | 6 | import logging 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | try: 13 | import timm 14 | from timm.models.layers import Mlp, to_2tuple 15 | 16 | try: 17 | # old timm imports < 0.8.1 18 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 19 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 20 | except ImportError: 21 | # new timm imports >= 0.8.1 22 | from timm.layers import RotAttentionPool2d 23 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 24 | except ImportError: 25 | timm = None 26 | 27 | from .utils import freeze_batch_norm_2d 28 | 29 | 30 | class TimmModel(nn.Module): 31 | """timm model adapter 32 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 33 | """ 34 | 35 | def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False): 36 | super().__init__() 37 | if timm is None: 38 | raise RuntimeError("Please `pip install timm` to use timm models.") 39 | 40 | self.image_size = to_2tuple(image_size) 41 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 42 | feat_size = self.trunk.default_cfg.get("pool_size", None) 43 | feature_ndim = 1 if not feat_size else 2 44 | if pool in ("abs_attn", "rot_attn"): 45 | assert feature_ndim == 2 46 | # if attn pooling used, remove both classifier and default pool 47 | self.trunk.reset_classifier(0, global_pool="") 48 | else: 49 | # reset global pool if pool config set, otherwise leave as network default 50 | reset_kwargs = dict(global_pool=pool) if pool else {} 51 | self.trunk.reset_classifier(0, **reset_kwargs) 52 | prev_chs = self.trunk.num_features 53 | 54 | head_layers = OrderedDict() 55 | if pool == "abs_attn": 56 | head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 57 | prev_chs = 
embed_dim 58 | elif pool == "rot_attn": 59 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 60 | prev_chs = embed_dim 61 | else: 62 | assert proj, "projection layer needed if non-attention pooling is used." 63 | 64 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 65 | if proj == "linear": 66 | head_layers["drop"] = nn.Dropout(drop) 67 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 68 | elif proj == "mlp": 69 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 70 | 71 | self.head = nn.Sequential(head_layers) 72 | 73 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 74 | """lock modules 75 | Args: 76 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 77 | """ 78 | if not unlocked_groups: 79 | # lock full model 80 | for param in self.trunk.parameters(): 81 | param.requires_grad = False 82 | if freeze_bn_stats: 83 | freeze_batch_norm_2d(self.trunk) 84 | else: 85 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 86 | try: 87 | # FIXME import here until API stable and in an official release 88 | from timm.models.helpers import group_parameters, group_modules 89 | except ImportError: 90 | raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`") 91 | matcher = self.trunk.group_matcher() 92 | gparams = group_parameters(self.trunk, matcher) 93 | max_layer_id = max(gparams.keys()) 94 | max_layer_id = max_layer_id - unlocked_groups 95 | for group_idx in range(max_layer_id + 1): 96 | group = gparams[group_idx] 97 | for param in group: 98 | self.trunk.get_parameter(param).requires_grad = False 99 | if freeze_bn_stats: 100 | gmodules = group_modules(self.trunk, matcher, reverse=True) 101 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 102 | freeze_batch_norm_2d(self.trunk, gmodules) 103 | 104 | @torch.jit.ignore 105 | def set_grad_checkpointing(self, enable=True): 106 | try: 107 | self.trunk.set_grad_checkpointing(enable) 108 | except Exception as e: 109 | logging.warning("grad checkpointing not supported for this timm image tower, continuing without...") 110 | 111 | def forward(self, x): 112 | x = self.trunk(x) 113 | x = self.head(x) 114 | return x 115 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop 8 | 9 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 10 | 11 | 12 | class ResizeMaxSize(nn.Module): 13 | 14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): 15 | super().__init__() 16 | if not isinstance(max_size, int): 17 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 18 | self.max_size = max_size 19 | self.interpolation = interpolation 20 | self.fn = min if fn == "min" else min 21 | self.fill = fill 22 | 23 | def forward(self, img): 24 | if isinstance(img, torch.Tensor): 25 | height, width = img.shape[:2] 26 | else: 27 | width, height = img.size 28 | scale = self.max_size / float(max(height, width)) 29 | if scale != 1.0: 30 | new_size = tuple(round(dim * scale) for dim in (height, width)) 31 | img = F.resize(img, new_size, self.interpolation) 32 | pad_h = self.max_size - new_size[0] 33 | pad_w = self.max_size - new_size[1] 34 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) 35 | return img 36 | 37 | 38 | def _convert_to_rgb(image): 39 | return image.convert("RGB") 40 | 41 | 42 | # class CatGen(nn.Module): 43 | # def __init__(self, num=4): 44 | # self.num = num 45 | # def mixgen_batch(image, text): 46 | # batch_size = image.shape[0] 47 | # index = np.random.permutation(batch_size) 48 | 49 | # cat_images = [] 50 | # for i in range(batch_size): 51 | # # image mixup 52 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 53 | # # text concat 54 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 55 | # text = torch.stack(text) 56 | # return image, text 57 | 58 | 59 | def image_transform( 60 | image_size: int, 61 | is_train: bool, 62 | mean: Optional[Tuple[float, ...]] = None, 63 | std: Optional[Tuple[float, ...]] = None, 64 | resize_longest_max: bool = False, 65 | fill_color: int = 0, 66 | ): 67 | mean = mean or OPENAI_DATASET_MEAN 68 | if not isinstance(mean, (list, tuple)): 69 | mean = (mean,) * 3 70 | 71 | std = std or OPENAI_DATASET_STD 72 | if not isinstance(std, (list, tuple)): 73 | std = (std,) * 3 74 | 75 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 76 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 77 | image_size = image_size[0] 78 | 79 | normalize = Normalize(mean=mean, std=std) 80 | if is_train: 81 | return Compose( 82 | [ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ] 88 | ) 89 | else: 90 | if resize_longest_max: 91 | transforms = [ResizeMaxSize(image_size, fill=fill_color)] 92 | else: 93 | transforms = [ 94 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 95 | CenterCrop(image_size), 96 | ] 97 | transforms.extend( 98 | [ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ] 103 | ) 104 | return Compose(transforms) 105 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py: -------------------------------------------------------------------------------- 1 | # Based on EVA, BEIT, timm and DeiT code bases 2 | # https://github.com/baaivision/EVA 3 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 4 | # https://github.com/microsoft/unilm/tree/master/beit 5 | # https://github.com/facebookresearch/deit/ 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | # not tested yet 9 | import math 10 | from transformers import CLIPImageProcessor 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint as checkpoint 16 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 17 | from .eva_clip 
import create_model_and_transforms, get_model_config 18 | import torch 19 | import torchvision 20 | import time 21 | 22 | from llava.utils import rank0_print 23 | 24 | 25 | class EvaViTWrapper(nn.Module): 26 | def __init__(self, vision_tower, args, delay_load=False): 27 | super().__init__() 28 | 29 | self.is_loaded = False 30 | self.vision_tower_name = vision_tower 31 | self.pretrained = args.vision_tower_pretrained 32 | self.args = args 33 | 34 | self.select_layer = args.mm_vision_select_layer 35 | if self.select_layer < -1: 36 | self.select_layer += 1 37 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 38 | 39 | self.model_config = get_model_config(self.vision_tower_name) 40 | 41 | if not delay_load: 42 | rank0_print(f"Loading vision tower: {vision_tower}") 43 | self.load_model() 44 | elif getattr(args, "unfreeze_mm_vision_tower", False): 45 | # TODO: better detector is needed. 46 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 47 | self.load_model() 48 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 49 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 50 | self.load_model() 51 | 52 | def load_model(self): 53 | rank0_print(f"Loading: {self.vision_tower_name}") 54 | rank0_print(f"Pretrained: {self.pretrained}") 55 | time_start = time.time() 56 | model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16") 57 | time_end = time.time() 58 | rank0_print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s") 59 | self.device = next(model.parameters()).device 60 | self.dtype = next(model.parameters()).dtype 61 | if self.device.type != "meta": 62 | model = model.to("cuda") 63 | self.vision_tower = model.visual 64 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] 65 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] 66 | self.resize_transform_size = resize_transform.size 67 | self.image_processor = CLIPImageProcessor.from_pretrained( 68 | "openai/clip-vit-large-patch14", 69 | crop_size=resize_transform.size, 70 | size={"shortest_edge": resize_transform.size}, 71 | image_mean=list(normalize_transform.mean), 72 | image_std=list(normalize_transform.std), 73 | ) 74 | rank0_print(f"Loaded image processor: {self.image_processor}") 75 | self.vision_tower.requires_grad_(False) 76 | self.is_loaded = True 77 | 78 | def feature_select(self, image_features): 79 | select_feature_type = self.select_feature 80 | 81 | # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 82 | # select_every_k_layer = len(image_features) // 4 83 | # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1) 84 | # select_feature_type = select_feature_type.replace("slicefour_", "") 85 | # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: 86 | # select_layers = [-1, -4, -7, -10, 6] 87 | # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1) 88 | # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") 89 | # else: 90 | # image_features = image_features[self.select_layer] 91 | 92 | if select_feature_type == "patch": 93 | 
image_features = image_features[:, 1:] 94 | elif select_feature_type == "cls_patch": 95 | image_features = image_features 96 | else: 97 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 98 | return image_features 99 | 100 | def train(self, mode=True): 101 | self.training = mode 102 | 103 | if self.is_loaded: 104 | self.vision_tower.eval() 105 | 106 | def forward(self, images): 107 | if type(images) is list: 108 | image_features = [] 109 | for image in images: 110 | image_feature = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True) 111 | image_feature = self.feature_select(image_feature).to(self.dtype) 112 | image_features.append(image_feature) 113 | else: 114 | image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True) 115 | image_features = self.feature_select(image_features).to(self.dtype) 116 | 117 | return image_features 118 | 119 | @property 120 | def dummy_feature(self): 121 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 122 | 123 | @property 124 | def hidden_size(self): 125 | return self.model_config["vision_cfg"]["width"] 126 | 127 | @property 128 | def num_patches(self): 129 | return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2 130 | 131 | @property 132 | def num_patches_per_side(self): 133 | return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] 134 | 135 | @property 136 | def config(self): 137 | return self.model_config 138 | 139 | @property 140 | def image_size(self): 141 | return self.model_config["vision_cfg"]["image_size"] 142 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .eva_clip_processors import EvaClipImageTrainProcessor 5 | from .eva_vit import EVAEncoderWrapper 6 | from .factory import list_models, add_model_config, get_model_config 7 | 8 | from llava.utils import rank0_print 9 | 10 | 11 | class EvaClipVisionTower(nn.Module): 12 | def __init__(self, vision_tower, args, delay_load=False): 13 | super().__init__() 14 | 15 | self.is_loaded = False 16 | self.vision_tower_name = vision_tower 17 | self.vision_tower_pretrained = args.vision_tower_pretrained 18 | self.config = get_model_config(vision_tower) 19 | 20 | if not delay_load: 21 | rank0_print(f"Loading EVA ViT: {self.vision_tower_name}") 22 | self.load_model() 23 | elif getattr(args, "unfreeze_mm_vision_tower", False): 24 | # TODO: better detector is needed.
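# --- illustrative sketch (not part of the repository dump) -------------------
# What feature_select() in the encoder wrappers above does: a CLIP-style ViT
# returns one CLS token followed by the patch tokens, so "patch" keeps tokens
# [1:] while "cls_patch" keeps everything. Shapes below are hypothetical
# (batch=2, 1 CLS + 576 patches, width=1024).
import torch
features = torch.randn(2, 1 + 576, 1024)
patch_only = features[:, 1:]   # (2, 576, 1024) -> select_feature == "patch"
cls_and_patch = features       # (2, 577, 1024) -> select_feature == "cls_patch"
print(patch_only.shape, cls_and_patch.shape)
# ------------------------------------------------------------------------------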
25 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 26 | self.load_model() 27 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 28 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 29 | self.load_model() 30 | else: 31 | self.cfg_only = self.config 32 | 33 | def load_model(self, device_map=None): 34 | rank0_print(f"Pretrained: {self.vision_tower_pretrained}") 35 | self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"]) 36 | self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config) 37 | rank0_print(f"Loaded image processor: {self.image_processor}") 38 | self.vision_tower.requires_grad_(False) 39 | self.is_loaded = True 40 | 41 | def forward(self, images): 42 | if type(images) is list: 43 | image_features = [] 44 | for image in images: 45 | image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype) 49 | 50 | return image_features 51 | 52 | @property 53 | def dtype(self): 54 | return self.vision_tower.dtype 55 | 56 | @property 57 | def device(self): 58 | return self.vision_tower.device 59 | 60 | @property 61 | def hidden_size(self): 62 | return self.config["vision_cfg"]["width"] 63 | 64 | @property 65 | def num_patches(self): 66 | return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2 67 | 68 | @property 69 | def num_patches_per_side(self): 70 | return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"] 71 | 72 | @property 73 | def image_size(self): 74 | return self.config["vision_cfg"]["image_size"] 75 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP 3 | """ 4 | 5 | from torchvision import transforms 6 | from torchvision.transforms.functional import InterpolationMode 7 | from transformers.image_processing_utils import BatchFeature 8 | from PIL import Image 9 | from transformers.image_transforms import convert_to_rgb 10 | 11 | 12 | class BaseProcessor: 13 | def __init__(self): 14 | self.transform = lambda x: x 15 | return 16 | 17 | def __call__(self, item): 18 | return self.transform(item) 19 | 20 | 21 | class EvaClipImageBaseProcessor(BaseProcessor): 22 | def __init__(self, mean=None, std=None): 23 | self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean 24 | self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std 25 | 26 | self.normalize = transforms.Normalize(self.mean, self.std) 27 | 28 | @property 29 | def image_mean(self): 30 | return self.mean 31 | 32 | 33 | class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor): 34 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 35 | super().__init__(mean=mean, std=std) 36 | 37 | self.transform = transforms.Compose( 38 | [ 39 | convert_to_rgb, 40 | transforms.Resize( 41 | image_size, 42 | interpolation=InterpolationMode.BICUBIC, 43 | ), 44 | transforms.CenterCrop(image_size), 45 | transforms.ToTensor(), 46 | 
self.normalize, 47 | ] 48 | ) 49 | 50 | self.image_size = image_size 51 | 52 | def preprocess(self, images, return_tensors): 53 | if isinstance(images, Image.Image): 54 | images = [images] 55 | else: 56 | assert isinstance(images, list) 57 | 58 | transformed_images = [self.transform(image).numpy() for image in images] 59 | data = {"pixel_values": transformed_images} 60 | 61 | return BatchFeature(data=data, tensor_type=return_tensors) 62 | 63 | def __call__(self, item): 64 | return self.transform(item) 65 | 66 | @property 67 | def crop_size(self): 68 | return {"height": self.image_size, "width": self.image_size} 69 | 70 | @property 71 | def size(self): 72 | return {"shortest_edge": self.image_size} 73 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Optional, Tuple, Union, Dict, Any 9 | import torch 10 | 11 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 12 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 13 | 14 | 15 | def _natural_key(string_): 16 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 17 | 18 | 19 | def _rescan_model_configs(): 20 | global _MODEL_CONFIGS 21 | 22 | config_ext = (".json",) 23 | config_files = [] 24 | for config_path in _MODEL_CONFIG_PATHS: 25 | if config_path.is_file() and config_path.suffix in config_ext: 26 | config_files.append(config_path) 27 | elif config_path.is_dir(): 28 | for ext in config_ext: 29 | config_files.extend(config_path.glob(f"*{ext}")) 30 | 31 | for cf in config_files: 32 | with open(cf, "r", encoding="utf8") as f: 33 | model_cfg = json.load(f) 34 | if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): 35 | _MODEL_CONFIGS[cf.stem] = model_cfg 36 | 37 | _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) 38 | 39 | 40 | _rescan_model_configs() # initial populate of model config registry 41 | 42 | 43 | def list_models(): 44 | """enumerate available model architectures based on config files""" 45 | return list(_MODEL_CONFIGS.keys()) 46 | 47 | 48 | def add_model_config(path): 49 | """add model config path or file and update registry""" 50 | if not isinstance(path, Path): 51 | path = Path(path) 52 | _MODEL_CONFIG_PATHS.append(path) 53 | _rescan_model_configs() 54 | 55 | 56 | def get_model_config(model_name): 57 | if model_name in _MODEL_CONFIGS: 58 | return deepcopy(_MODEL_CONFIGS[model_name]) 59 | else: 60 | return None 61 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | 
"fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 
49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: 
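# --- illustrative sketch (not part of the repository dump) -------------------
# How the vision_cfg fields in these JSON configs translate into token counts:
# the encoder wrappers expose num_patches = (image_size // patch_size) ** 2 and
# hidden_size = vision_cfg["width"]. The values below mirror EVA02-CLIP-L-14-336
# and are shown only as a worked example.
vision_cfg = {"image_size": 336, "patch_size": 14, "width": 1024}
num_patches_per_side = vision_cfg["image_size"] // vision_cfg["patch_size"]  # 24
num_patches = num_patches_per_side ** 2                                      # 576
hidden_size = vision_cfg["width"]                                            # 1024
print(num_patches_per_side, num_patches, hidden_size)
# ------------------------------------------------------------------------------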
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/hf_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor 5 | from llava.utils import rank0_print 6 | 7 | 8 | class HFVisionTower(nn.Module): 9 | def __init__(self, vision_tower, args, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | 14 | self.vision_tower_name = vision_tower.replace("hf:", "", 1) 15 | self.select_layer = args.mm_vision_select_layer 16 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 17 | 18 | if not delay_load: 19 | self.load_model() 20 | else: 21 | self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) 22 | 23 | def load_model(self): 24 | try: 25 | self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name) 26 | except Exception as e: 27 | if "448" in self.vision_tower_name: 28 | image_size = 448 29 | # use image processor with conig 30 | self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size) 31 | else: 32 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 33 | rank0_print(f"Loaded image processor: 
{self.image_processor}") 34 | self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda") 35 | self.device = self.vision_tower.device 36 | self.dtype = self.vision_tower.dtype 37 | self.config = self.vision_tower.config 38 | 39 | if hasattr(self.vision_tower, "vision_model"): 40 | self.vision_tower = self.vision_tower.vision_model 41 | self.vision_tower.requires_grad_(False) 42 | # self.vision_tower.eval() 43 | self.is_loaded = True 44 | 45 | def feature_select(self, image_forward_outs): 46 | select_feature_type = self.select_feature 47 | 48 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 49 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4 50 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) 51 | select_feature_type = select_feature_type.replace("slicefour_", "") 52 | else: 53 | image_features = image_forward_outs.hidden_states[self.select_layer] 54 | 55 | if select_feature_type == "patch": 56 | image_features = image_features[:, 1:] 57 | elif select_feature_type == "cls_patch": 58 | image_features = image_features 59 | else: 60 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 61 | return image_features 62 | 63 | def forward(self, images): 64 | if type(images) is list: 65 | image_features = [] 66 | for image in images: 67 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 68 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 69 | image_features.append(image_feature) 70 | else: 71 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 72 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 73 | 74 | return image_features 75 | 76 | @property 77 | def dummy_feature(self): 78 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 79 | 80 | # @property 81 | # def dtype(self): 82 | # return self.vision_tower.dtype 83 | 84 | # @property 85 | # def device(self): 86 | # return self.vision_tower.device 87 | 88 | @property 89 | def hidden_size(self): 90 | try: 91 | _hidden_size = self.config.hidden_size 92 | except: 93 | _hidden_size = self.config.vision_config.hidden_size 94 | if "slicefour" in self.select_feature: 95 | _hidden_size *= 4 96 | return _hidden_size 97 | 98 | @property 99 | def num_patches(self): 100 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2 101 | if "cls_patch" in self.select_feature: 102 | _num_patches += 1 103 | return _num_patches 104 | 105 | @property 106 | def num_patches_per_side(self): 107 | return self.config.image_size // self.config.patch_size 108 | 109 | @property 110 | def image_size(self): 111 | return self.config.image_size 112 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/imagebind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPImageProcessor 5 | 6 | try: 7 | from imagebind.models import imagebind_model 8 | from imagebind.models.imagebind_model import ModalityType 9 | from imagebind.data import load_and_transform_audio_data 10 | except ImportError: 11 | pass 12 | 13 | 14 | 
class ImageBindWrapper(nn.Module): 15 | def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): 16 | super().__init__() 17 | 18 | self.is_loaded = False 19 | 20 | self.vision_tower_name = vision_tower 21 | self.select_layer = select_layer 22 | self.select_feature = select_feature 23 | 24 | if not delay_load: 25 | self.load_model() 26 | 27 | def load_model(self): 28 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 29 | self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) 30 | for p in self.vision_tower.parameters(): 31 | p.requires_grad = False 32 | self.vision_tower.eval() 33 | self.is_loaded = True 34 | 35 | def train(self, mode=True): 36 | self.training = mode 37 | 38 | if self.is_loaded: 39 | self.vision_tower.eval() 40 | 41 | @torch.no_grad() 42 | def forward(self, x): 43 | if type(x) == dict: 44 | if x["audios"] is not None: 45 | inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} 46 | embeddings = self.vision_tower(inputs) 47 | audio_embedding = embeddings[ModalityType.AUDIO] 48 | return audio_embedding.unsqueeze(1) 49 | else: 50 | inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} 51 | embeddings = self.vision_tower(inputs) 52 | vision_embedding = embeddings[ModalityType.VISION] 53 | if vision_embedding.ndim == 2: 54 | return vision_embedding.unsqueeze(1) 55 | if vision_embedding.shape[1] == 257: 56 | return vision_embedding[:, 1:] 57 | raise ValueError(f"Unexpected shape: {vision_embedding.shape}") 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.modality_preprocessors.vision.cls_token.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.modality_preprocessors.vision.cls_token.device 70 | 71 | @property 72 | def hidden_size(self): 73 | return 1024 74 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/open_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPImageProcessor 4 | from llava.utils import rank0_print 5 | 6 | try: 7 | import open_clip 8 | import torchvision 9 | from open_clip.transformer import _expand_token 10 | except ImportError: 11 | print("OpenCLIP not installed") 12 | open_clip = None 13 | 14 | HIDDEN_SIZE_DICT = { 15 | "ViT-H-14-378-quickgelu": 1280, 16 | } 17 | 18 | 19 | class OpenCLIPVisionTower(nn.Module): 20 | def __init__(self, vision_tower, args, delay_load=False): 21 | super().__init__() 22 | 23 | self.is_loaded = False 24 | self.model_name = vision_tower.replace("open_clip_hub:", "") 25 | self.pretrained = args.vision_tower_pretrained 26 | self.select_layer = args.mm_vision_select_layer 27 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 28 | 29 | if not delay_load: 30 | rank0_print(f"Loading vision tower: {vision_tower}") 31 | self.load_model() 32 | elif getattr(args, "unfreeze_mm_vision_tower", False): 33 | # TODO: better detector is needed. 
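# This branch and the next one load the vision tower eagerly whenever its weights are
# meant to be trained or restored from a checkpoint (`unfreeze_mm_vision_tower` is set,
# or `mm_tunable_parts` contains `mm_vision_tower`); otherwise loading is deferred until
# `load_model()` is called explicitly.
# Hypothetical usage sketch (assuming `args` carries the fields read above):
#   tower = OpenCLIPVisionTower("open_clip_hub:ViT-H-14-378-quickgelu", args, delay_load=True)
#   tower.load_model()     # deferred load happens here
#   feats = tower(images)  # (B, num_patches, hidden_size) for select_feature="patch"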
34 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 35 | self.load_model() 36 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 37 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 38 | self.load_model() 39 | 40 | def load_model(self, device_map="auto"): 41 | rank0_print(f"Loading OpenCLIP model: {self.model_name}") 42 | rank0_print(f"Pretrained: {self.pretrained}") 43 | vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda") 44 | 45 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] 46 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] 47 | self.resize_transform_size = resize_transform.size # 224 or 384 48 | self.patch_size = vision_tower.visual.conv1.kernel_size[0] # 14 or 16 49 | 50 | self.image_processor = CLIPImageProcessor.from_pretrained( 51 | "openai/clip-vit-large-patch14", 52 | crop_size=resize_transform.size, 53 | size={"shortest_edge": resize_transform.size}, 54 | image_mean=list(normalize_transform.mean), 55 | image_std=list(normalize_transform.std), 56 | ) 57 | rank0_print(f"Loaded image processor: {self.image_processor}") 58 | self.vision_tower = vision_tower.visual 59 | self.vision_tower.requires_grad_(False) 60 | 61 | self.is_loaded = True 62 | 63 | def feature_select(self, image_forward_outs): 64 | image_features = image_forward_outs[self.select_layer] 65 | if self.select_feature == "patch": 66 | image_features = image_features[:, 1:] 67 | elif self.select_feature == "cls_patch": 68 | image_features = image_features 69 | elif self.select_feature == "conv_flatten": 70 | image_features = image_features.flatten(2).transpose(1, 2) 71 | else: 72 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 73 | return image_features 74 | 75 | def forward_visual(self, x, output_hidden_states=False): 76 | if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"): 77 | return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer)) 78 | else: 79 | 80 | def forward_openclip(self, x: torch.Tensor): 81 | features = [] 82 | x = self.conv1(x) # shape = [*, width, grid, grid] 83 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 84 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 85 | 86 | # class embeddings and positional embeddings 87 | x = torch.cat( 88 | [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], 89 | dim=1, 90 | ) 91 | # shape = [*, grid ** 2 + 1, width] 92 | x = x + self.positional_embedding.to(x.dtype) 93 | 94 | x = self.patch_dropout(x) 95 | x = self.ln_pre(x) 96 | 97 | x = x.permute(1, 0, 2) # NLD -> LND 98 | for r in self.transformer.resblocks: 99 | x = r(x, attn_mask=None) 100 | features.append(x) 101 | return features 102 | 103 | return forward_openclip(self.vision_tower, x) 104 | 105 | def forward(self, images): 106 | if type(images) is list: 107 | image_features = [] 108 | for image in images: 109 | image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True) 110 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 111 | image_features.append(image_feature) 112 | else: 113 | image_forward_outs = 
self.forward_visual(images.to(self.dtype), output_hidden_states=True) 114 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 115 | 116 | return image_features 117 | 118 | @property 119 | def dummy_feature(self): 120 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 121 | 122 | @property 123 | def dtype(self): 124 | if hasattr(self.vision_tower, "conv1"): 125 | return self.vision_tower.conv1.weight.dtype 126 | if hasattr(self.vision_tower, "trunk"): 127 | return self.vision_tower.trunk.patch_embed.proj.weight.dtype 128 | raise NotImplementedError 129 | 130 | @property 131 | def device(self): 132 | if hasattr(self.vision_tower, "conv1"): 133 | return self.vision_tower.conv1.weight.device 134 | if hasattr(self.vision_tower, "trunk"): 135 | return self.vision_tower.trunk.patch_embed.proj.weight.device 136 | raise NotImplementedError 137 | 138 | @property 139 | def config(self): 140 | return None 141 | 142 | @property 143 | def hidden_size(self): 144 | if self.model_name in HIDDEN_SIZE_DICT: 145 | return HIDDEN_SIZE_DICT[self.model_name] 146 | else: 147 | raise NotImplementedError 148 | 149 | @property 150 | def num_patches(self): 151 | image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0] 152 | _num_patches = (image_size // self.patch_size) ** 2 153 | if "cls_patch" in self.select_feature: 154 | _num_patches += 1 155 | return _num_patches 156 | 157 | @property 158 | def image_size(self): 159 | return self.resize_transform_size 160 | 161 | @property 162 | def num_patches_per_side(self): 163 | return self.resize_transform_size // self.patch_size 164 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | from .pooler_projector import PoolerProjector 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": "identity"} 18 | 19 | 20 | class SimpleResBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.pre_norm = nn.LayerNorm(channels) 24 | 25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) 26 | 27 | def forward(self, x): 28 | x = self.pre_norm(x) 29 | return x + self.proj(x) 30 | 31 | 32 | def build_vision_projector(config, delay_load=False, **kwargs): 33 | projector_type = getattr(config, "mm_projector_type", "linear") 34 | 35 | if projector_type == "linear": 36 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | 38 | if projector_type == "pooler": 39 | return PoolerProjector(config, kwargs["vision_cfg"]) 40 | 41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) 42 | if mlp_gelu_match: 43 | mlp_depth = int(mlp_gelu_match.group(1)) 44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 45 | for _ in range(1, mlp_depth): 46 | modules.append(nn.GELU()) 47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 48 | return nn.Sequential(*modules) 49 | 50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) 51 | if mlp_gelu_resnet_match: 52 | mlp_depth = int(mlp_gelu_resnet_match.group(1)) 53 | res_depth = 
int(mlp_gelu_resnet_match.group(2)) 54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 55 | for _ in range(1, mlp_depth): 56 | modules.append(nn.GELU()) 57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 58 | for _ in range(res_depth): 59 | modules.append(SimpleResBlock(config.hidden_size)) 60 | return nn.Sequential(*modules) 61 | 62 | if projector_type == "identity": 63 | return IdentityMap() 64 | 65 | raise ValueError(f"Unknown projector type: {projector_type}") 66 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/pooler_projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | from transformers.models.clip.modeling_clip import CLIPVisionModel 7 | 8 | 9 | class PoolerProjector(nn.Module): 10 | def __init__(self, config, vision_cfg): 11 | super().__init__() 12 | self._config = config 13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size 14 | 15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) 16 | 17 | self.proj = nn.Sequential( 18 | nn.GELU(), 19 | nn.Linear(config.hidden_size, config.hidden_size), 20 | ) 21 | 22 | def forward(self, x, *args, **kwargs): 23 | height = width = self.hw 24 | assert height * width == x.shape[1] 25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) 26 | x = self.conv_pool(x) 27 | x = x.flatten(2).transpose(1, 2) 28 | x = self.proj(x) 29 | return x 30 | 31 | @property 32 | def config(self): 33 | return {"mm_projector_type": "pooler"} 34 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .masked_drop import MaskedDrop 4 | from .spatial_pool import SpatialPool 5 | from .perceiver import PerceiverResampler 6 | from .qformer import Qformer 7 | 8 | 9 | class IdentityMap(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | @property 17 | def config(self): 18 | return {"mm_resampler_type": None} 19 | 20 | 21 | def build_vision_resampler(model_args, delay_load=False, **kwargs): 22 | resampler_type = getattr(model_args, "mm_resampler_type", None) 23 | if resampler_type == "masked_drop": 24 | return MaskedDrop(model_args) 25 | elif resampler_type == "spatial_pool": 26 | return SpatialPool(model_args, **kwargs) 27 | elif resampler_type == "perceiver": 28 | return PerceiverResampler(model_args, **kwargs) 29 | elif resampler_type == "qformer": 30 | return Qformer(model_args, **kwargs) 31 | elif resampler_type is None: 32 | return IdentityMap() 33 | 34 | raise ValueError(f"Unknown resampler type: {resampler_type}") 35 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/masked_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import random 5 | 6 | 7 | class MaskedDrop(nn.Module): 8 | def __init__(self, model_args): 9 | super().__init__() 10 | 11 | self.mode = model_args.mm_mask_drop_mode 12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage 13 | self.ratio = model_args.mm_mask_drop_ratio 14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper 15 | 
self.ratio_lower = model_args.mm_mask_drop_ratio_lower 16 | 17 | def forward(self, image_features, *args, **kwargs): 18 | 19 | if not self.training: 20 | return image_features 21 | 22 | if self.skip_percentage > random.random(): 23 | return image_features 24 | 25 | masked_features = [] 26 | 27 | for image_feature in image_features: 28 | num_tokens = image_feature.shape[0] 29 | if self.mode == "fixed": 30 | num_keep = int(num_tokens * self.ratio) 31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) 32 | elif self.mode == "range": 33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) 34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) 35 | elif self.mode == "cls_only": 36 | masked_features.append(image_feature[0:1]) 37 | else: 38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}") 39 | 40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): 41 | masked_features = torch.stack(masked_features, dim=0) 42 | 43 | return masked_features 44 | 45 | @property 46 | def config(self): 47 | return { 48 | "mm_resampler_type": "masked_drop", 49 | "mm_mask_drop_mode": self.mode, 50 | "mm_mask_drop_skip_percentage": self.skip_percentage, 51 | "mm_mask_drop_ratio": self.ratio, 52 | "mm_mask_drop_ratio_upper": self.ratio_upper, 53 | "mm_mask_drop_ratio_lower": self.ratio_lower, 54 | } 55 | 56 | def random_masking(self, x, len_keep): 57 | """ 58 | Perform per-sample random masking by per-sample shuffling. 59 | Per-sample shuffling is done by argsort random noise. 60 | x: [N, L, D], sequence 61 | """ 62 | N, L, D = x.shape # batch, length, dim 63 | 64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 65 | 66 | # sort noise for each sample 67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 68 | ids_restore = torch.argsort(ids_shuffle, dim=1) 69 | 70 | # keep the first subset 71 | ids_keep = ids_shuffle[:, :len_keep] 72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 73 | 74 | # generate the binary mask: 0 is keep, 1 is remove 75 | mask = torch.ones([N, L], device=x.device) 76 | mask[:, :len_keep] = 0 77 | # unshuffle to get the binary mask 78 | mask = torch.gather(mask, dim=1, index=ids_restore) 79 | 80 | return x_masked, mask, ids_restore 81 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/perceiver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | try: 9 | from einops_exts import rearrange_many 10 | except: 11 | pass 12 | 13 | from torch import einsum, nn 14 | 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | 20 | def FeedForward(dim, mult=4): 21 | inner_dim = int(dim * mult) 22 | return nn.Sequential( 23 | nn.LayerNorm(dim), 24 | nn.Linear(dim, inner_dim, bias=False), 25 | nn.GELU(), 26 | nn.Linear(inner_dim, dim, bias=False), 27 | ) 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.heads = heads 35 | inner_dim = dim_head * heads 36 | 37 | self.norm_media = nn.LayerNorm(dim) 38 | self.norm_latents = nn.LayerNorm(dim) 39 | 40 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 41 | 
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 42 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 43 | 44 | def forward(self, x, latents): 45 | """ 46 | Args: 47 | x (torch.Tensor): image features 48 | shape (b, T, n1, D) 49 | latent (torch.Tensor): latent features 50 | shape (b, T, n2, D) 51 | """ 52 | x = self.norm_media(x) 53 | latents = self.norm_latents(latents) 54 | 55 | h = self.heads 56 | 57 | q = self.to_q(latents) 58 | kv_input = torch.cat((x, latents), dim=-2) 59 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 60 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 61 | q = q * self.scale 62 | 63 | # attention 64 | sim = einsum("... i d, ... j d -> ... i j", q, k) 65 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 66 | attn = sim.softmax(dim=-1) 67 | 68 | out = einsum("... i j, ... j d -> ... i d", attn, v) 69 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 70 | return self.to_out(out) 71 | 72 | 73 | class PerceiverResamplerModule(nn.Module): 74 | def __init__( 75 | self, 76 | *, 77 | dim, 78 | depth=6, 79 | dim_head=64, 80 | heads=8, 81 | num_latents=64, 82 | max_num_media=None, 83 | max_num_frames=None, 84 | ff_mult=4, 85 | ): 86 | super().__init__() 87 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 88 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None 89 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None 90 | 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append( 94 | nn.ModuleList( 95 | [ 96 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 97 | FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), 98 | ] 99 | ) 100 | ) 101 | 102 | self.norm = nn.LayerNorm(dim) 103 | 104 | def forward(self, x): 105 | """ 106 | Args: 107 | x (torch.Tensor): image features 108 | shape (b, T, F, v, D) 109 | Returns: 110 | shape (b, T, n, D) where n is self.num_latents 111 | """ 112 | b, T, F, v = x.shape[:4] 113 | 114 | # frame and media time embeddings 115 | if exists(self.frame_embs): 116 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 117 | x = x + frame_embs 118 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions 119 | if exists(self.media_time_embs): 120 | x = x + self.media_time_embs[:T] 121 | 122 | # blocks 123 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 124 | for attn, ff in self.layers: 125 | latents = attn(x, latents) + latents 126 | latents = ff(latents) + latents 127 | return self.norm(latents) 128 | 129 | 130 | class PerceiverResampler(nn.Module): 131 | def __init__(self, model_args, vision_tower): 132 | super().__init__() 133 | 134 | self.depth = model_args.mm_perceiver_depth 135 | self.num_latents = model_args.mm_perceiver_latents 136 | self.ff_mult = model_args.mm_perceiver_ff_mult 137 | self.pretrained = model_args.mm_perceiver_pretrained 138 | 139 | self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) 140 | 141 | if self.pretrained is not None: 142 | self.load_state_dict(torch.load(self.pretrained)) 143 | 144 | def forward(self, image_features, *args, **kwargs): 145 | return self.perceiver(image_features[:, None, None]).squeeze(1) 146 | 147 | @property 148 | def config(self): 149 | return { 150 | "mm_resampler_type": "perceiver", 151 | "mm_perceiver_depth": self.depth, 152 | 
"mm_perceiver_latents": self.num_latents, 153 | "mm_perceiver_ff_mult": self.ff_mult, 154 | "mm_perceiver_pretrained": self.pretrained, 155 | } 156 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/spatial_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class SpatialPool(nn.Module): 7 | def __init__(self, model_args, vision_tower): 8 | super().__init__() 9 | 10 | self.mode = model_args.mm_spatial_pool_mode 11 | self.stride = model_args.mm_spatial_pool_stride 12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) 13 | 14 | if self.mode == "average": 15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) 16 | elif self.mode == "max": 17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) 18 | elif self.mode == "conv": 19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) 20 | else: 21 | raise ValueError(f"Unknown pooling mode: {self.pool}.") 22 | 23 | def forward(self, image_features, images, *args, **kwargs): 24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) 25 | ori_H = int(ori_W * images.shape[2] // images.shape[3]) 26 | 27 | B, _, F = image_features.shape 28 | 29 | image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2) 30 | image_features_spatial_pool = self.pool(image_features_spatial) 31 | 32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() 33 | 34 | @property 35 | def config(self): 36 | return { 37 | "mm_resampler_type": "spatial_pool", 38 | "mm_spatial_pool_stride": self.stride, 39 | "mm_spatial_pool_mode": self.mode, 40 | "mm_spatial_pool_out_channels": self.out_channels, 41 | } 42 | 43 | @property 44 | def hidden_size(self): 45 | return self.out_channels 46 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/model/vision_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 4 | 5 | import numpy as np 6 | 7 | class ProxyInitializer(nn.Module): 8 | def __init__(self, d_model, d_hidden, n): 9 | """ 10 | d_model: embedding dimension of the full tokens (and of the pooled proxy outputs); d_hidden: dimension of the proxy queries and keys; n: number of learnable proxy tokens 11 | """ 12 | super().__init__() 13 | self.Wk = nn.Linear(d_model, d_hidden, bias=False) 14 | 15 | self.proxyv_tokens = nn.Parameter(torch.randn((n, d_hidden))/torch.sqrt(torch.tensor(d_hidden*4))) 16 | 17 | def forward(self, full_tokens, compress_reduce_factor, single_crop_len): 18 | """ 19 | full_tokens: Tensor of shape (batch_size, N, d_model) 20 | The learned proxy tokens (self.proxyv_tokens) are broadcast to (batch_size, n, d_hidden) and used as queries. 21 | 22 | Returns: 23 | proxyv_out: proxy tokens pooled from the full tokens, (batch_size, n, d_model) 24 | attn_T: full->proxy attention, (batch_size, N, n), later used to splat proxy features back onto the full tokens 25 | """ 26 | proxyv_tokens = self.proxyv_tokens.unsqueeze(0).repeat(full_tokens.shape[0], 1, 1) 27 | Q = proxyv_tokens 28 | K = self.Wk(full_tokens) 29 | V = full_tokens 30 | 31 | d_hidden = Q.shape[-1] 32 | attn_logits = torch.bmm(Q, K.transpose(1, 2)) / (d_hidden ** 0.5) 33 | 34 | attn = nn.functional.softmax(attn_logits, dim=-1) 35 | 36 | attn_logits_T = attn_logits.transpose(1, 2) 37 | attn_T = nn.functional.softmax(attn_logits_T, dim=-1) 38 | 39 | # Update the proxy tokens by pooling full tokens 40 | proxyv_out = torch.bmm(attn, V) 41 | return proxyv_out, attn_T 42 | 43 | def splat_proxyv_tokens(proxyv_tokens, attn): 44 | return torch.bmm(attn, proxyv_tokens) 45 | 46 | 47 | class VisionMLP(nn.Module): 48 | def __init__(self, config, intermediate_size): 49 | super().__init__() 50 | self.context_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False) 51 | self.input_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False) 52 | self.proj = nn.Sequential( 53 | nn.Linear(intermediate_size*2, intermediate_size, bias=False), 54 | nn.SiLU(), 55 | nn.Linear(intermediate_size, config.hidden_size, bias=False) 56 | ) 57 | self.layernorm_post = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 58 | 59 | def forward(self, image_full, image_compress, compress_reduce_factor, per_crop_token_len=576, learn_proxy=False, proxyv_attn=None): 60 | side_len_full = int(per_crop_token_len**0.5) 61 | side_len_compress = side_len_full // compress_reduce_factor 62 | 63 | num_image_crops = image_full.shape[1]//per_crop_token_len 64 | bs = image_full.shape[0] 65 | 66 | if learn_proxy: 67 | image_full = image_full.view(bs*num_image_crops, side_len_full*side_len_full, -1) 68 | image_compress = image_compress.view(bs*num_image_crops, side_len_compress*side_len_compress, -1) 69 | image_compress = splat_proxyv_tokens(image_compress, proxyv_attn) 70 | image_compress = self.context_proj(image_compress) 71 | else: 72 | image_full = image_full.view(bs*num_image_crops, side_len_full, side_len_full, -1) 73 | image_compress = image_compress.view(bs*num_image_crops, side_len_compress, side_len_compress, 
-1) 74 | image_compress = self.context_proj(image_compress) 75 | image_compress = image_compress.repeat_interleave(compress_reduce_factor, 1).repeat_interleave(compress_reduce_factor, 2) 76 | residual = image_full 77 | image_full = self.input_proj(image_full) 78 | image_full = torch.cat([image_full, image_compress], -1) 79 | image_full = self.layernorm_post(self.proj(image_full) + residual) 80 | 81 | image_full = image_full.view(bs, num_image_crops*side_len_full*side_len_full, -1) 82 | 83 | return image_full 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | padding_mask: Optional[torch.Tensor] = None, 25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 26 | if output_attentions: 27 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") 28 | 29 | bsz, q_len, _ = hidden_states.size() 30 | 31 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 32 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 33 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim) 34 | 35 | kv_seq_len = key_states.shape[-2] 36 | if past_key_value is not None: 37 | kv_seq_len += past_key_value[0].shape[-2] 38 | 39 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 40 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 41 | 42 | if past_key_value is not None: 43 | # reuse k, v 44 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 45 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 46 | 47 | past_key_value = (key_states, value_states) if use_cache else None 48 | 49 | # repeat k/v heads if n_kv_heads < n_heads 50 | key_states = repeat_kv(key_states, self.num_key_value_groups) 51 | value_states = repeat_kv(value_states, self.num_key_value_groups) 52 | 53 | # Transform the data into the format required by flash attention 54 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 55 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 56 | key_padding_mask = attention_mask 57 | 58 | if key_padding_mask is None: 59 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 60 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) 61 | max_s = q_len 62 | output = 
flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 63 | output = output.view(bsz, q_len, -1) 64 | else: 65 | qkv = qkv.reshape(bsz, q_len, -1) 66 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 67 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 68 | output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 69 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 70 | output = pad_input(output_unpad, indices, bsz, q_len) 71 | 72 | return self.o_proj(output), None, past_key_value 73 | 74 | 75 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 76 | # requires the attention mask to be the same as the key_padding_mask 77 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 78 | # [bsz, seq_len] 79 | return attention_mask 80 | 81 | 82 | def replace_llama_attn_with_flash_attn(): 83 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 84 | if cuda_major < 8: 85 | warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593") 86 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 87 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 88 | -------------------------------------------------------------------------------- /llava/train/llava_trainer_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | from llava.train.llava_trainer import LLaVATrainer 5 | 6 | 7 | class LLaVAEvalTrainer(LLaVATrainer): 8 | def evaluate(self, evaluate_args): 9 | cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \ 10 | --model {evaluate_args.model} \ 11 | --model_args {evaluate_args.model_args} \ 12 | --tasks {evaluate_args.task_names} \ 13 | --batch_size {evaluate_args.batch_size} \ 14 | --log_samples_suffix {evaluate_args.log_samples_suffix} \ 15 | --output_path {evaluate_args.output_path}" 16 | if evaluate_args.limit: 17 | cmd += f" --limit {evaluate_args.limit}" 18 | if evaluate_args.num_fewshot: 19 | cmd += f" --num_fewshot {evaluate_args.num_fewshot}" 20 | if evaluate_args.gen_kwargs != "": 21 | cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}" 22 | if evaluate_args.log_samples: 23 | cmd += f" --log_samples" 24 | else: 25 | assert False, "Please log samples so that the result can be parsed" 26 | results = subprocess.run([cmd], shell=True, capture_output=True, text=True) 27 | try: 28 | result_file_index_start = results.stdout.index("Saved samples to ") 29 | result_file_index_end = results.stdout.index(f".json") 30 | result_file_index_start += len("Saved samples to ") 31 | file = results.stdout[result_file_index_start:result_file_index_end] 32 | except: 33 | result_file_index_start = results.stderr.index("Saved samples to ") 34 | result_file_index_end = results.stderr.index(f".json") 35 | result_file_index_start += len("Saved samples to ") 36 | file = results.stderr[result_file_index_start:result_file_index_end] 37 | file = file.split("/")[:-1] 38 | file = "/".join(file) + "/results.json" 39 | with open(file, "r") as f: 40 | lmms_eval_results = json.load(f) 41 | result_dict = {} 42 | tasks_list = 
evaluate_args.task_names.split(",") 43 | for task in tasks_list: 44 | task_results = lmms_eval_results["results"][task] 45 | for k, v in task_results.items(): 46 | if k != "alias" and "stderr" not in k: 47 | metric = k.split(",")[0] 48 | result_dict[f"{task}_{metric}"] = v 49 | return result_dict 50 | 51 | """def evaluate(self, evaluate_args): 52 | initialize_tasks() 53 | tasks_list = evaluate_args.task_names.split(",") 54 | result_dict = {} 55 | results = evaluator.simple_evaluate( 56 | model=evaluate_args.model, 57 | model_args=evaluate_args.model_args, 58 | tasks=tasks_list, 59 | num_fewshot=evaluate_args.num_fewshot, 60 | batch_size=evaluate_args.batch_size, 61 | device=evaluate_args.device, 62 | limit=evaluate_args.limit, 63 | check_integrity=evaluate_args.check_integrity, 64 | show_task_to_terminal=evaluate_args.show_task_to_terminal, 65 | log_samples=evaluate_args.log_samples, 66 | gen_kwargs=evaluate_args.gen_kwargs, 67 | cli_args=evaluate_args, 68 | ) 69 | for task in tasks_list: 70 | task_results = results["results"][task] 71 | for k,v in task_results.items(): 72 | if k != "alias" and "stderr" not in k: 73 | metric = k.split(",")[0] 74 | result_dict[f"{task}_{metric}"] = v 75 | 76 | return result_dict""" 77 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train import train 2 | 3 | if __name__ == "__main__": 4 | train() 5 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | import numpy as np 7 | 8 | import requests 9 | 10 | from llava.constants import LOGDIR 11 | 12 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 13 | moderation_msg = "I am sorry. Your input may violate our content moderation guidelines. Please avoid using harmful or offensive content." 14 | 15 | handler = None 16 | 17 | import torch.distributed as dist 18 | 19 | try: 20 | import av 21 | from decord import VideoReader, cpu 22 | except ImportError: 23 | print("Please install pyav to use video processing functions.") 24 | 25 | def process_video_with_decord(video_file, data_args): 26 | vr = VideoReader(video_file, ctx=cpu(0), num_threads=1) 27 | total_frame_num = len(vr) 28 | video_time = total_frame_num / vr.get_avg_fps() 29 | avg_fps = round(vr.get_avg_fps() / data_args.video_fps) 30 | frame_idx = [i for i in range(0, total_frame_num, avg_fps)] 31 | frame_time = [i/avg_fps for i in frame_idx] 32 | 33 | 34 | if data_args.frames_upbound > 0: 35 | if len(frame_idx) > data_args.frames_upbound or data_args.force_sample: 36 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int) 37 | frame_idx = uniform_sampled_frames.tolist() 38 | frame_time = [i/vr.get_avg_fps() for i in frame_idx] 39 | 40 | video = vr.get_batch(frame_idx).asnumpy() 41 | frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) 42 | 43 | num_frames_to_sample = num_frames = len(frame_idx) 44 | # https://github.com/dmlc/decord/issues/208 45 | vr.seek(0) 46 | return video, video_time, frame_time, num_frames_to_sample 47 | 48 | def process_video_with_pyav(video_file, data_args): 49 | container = av.open(video_file) 50 | # !!! This is the only difference. 
Using auto threading 51 | container.streams.video[0].thread_type = "AUTO" 52 | 53 | video_frames = [] 54 | for packet in container.demux(): 55 | if packet.stream.type == 'video': 56 | for frame in packet.decode(): 57 | video_frames.append(frame) 58 | total_frame_num = len(video_frames) 59 | video_time = video_frames[-1].time 60 | avg_fps = round(total_frame_num / video_time / data_args.video_fps) 61 | frame_idx = [i for i in range(0, total_frame_num, avg_fps)] 62 | 63 | if data_args.frames_upbound > 0: 64 | if len(frame_idx) > data_args.frames_upbound: 65 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int) 66 | frame_idx = uniform_sampled_frames.tolist() 67 | 68 | 69 | frames = [video_frames[i] for i in frame_idx] 70 | return np.stack([x.to_ndarray(format="rgb24") for x in frames]) 71 | 72 | 73 | def rank0_print(*args): 74 | if dist.is_initialized(): 75 | if dist.get_rank() == 0: 76 | print(f"Rank {dist.get_rank()}: ", *args) 77 | else: 78 | print(*args) 79 | 80 | 81 | def rank_print(*args): 82 | if dist.is_initialized(): 83 | print(f"Rank {dist.get_rank()}: ", *args) 84 | else: 85 | print(*args) 86 | 87 | def build_logger(logger_name, logger_filename): 88 | global handler 89 | 90 | formatter = logging.Formatter( 91 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 92 | datefmt="%Y-%m-%d %H:%M:%S", 93 | ) 94 | 95 | # Set the format of root handlers 96 | if not logging.getLogger().handlers: 97 | logging.basicConfig(level=logging.INFO) 98 | logging.getLogger().handlers[0].setFormatter(formatter) 99 | 100 | # Redirect stdout and stderr to loggers 101 | stdout_logger = logging.getLogger("stdout") 102 | stdout_logger.setLevel(logging.INFO) 103 | sl = StreamToLogger(stdout_logger, logging.INFO) 104 | sys.stdout = sl 105 | 106 | stderr_logger = logging.getLogger("stderr") 107 | stderr_logger.setLevel(logging.ERROR) 108 | sl = StreamToLogger(stderr_logger, logging.ERROR) 109 | sys.stderr = sl 110 | 111 | # Get logger 112 | logger = logging.getLogger(logger_name) 113 | logger.setLevel(logging.INFO) 114 | 115 | # Add a file handler for all loggers 116 | if handler is None: 117 | os.makedirs(LOGDIR, exist_ok=True) 118 | filename = os.path.join(LOGDIR, logger_filename) 119 | handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True) 120 | handler.setFormatter(formatter) 121 | 122 | for name, item in logging.root.manager.loggerDict.items(): 123 | if isinstance(item, logging.Logger): 124 | item.addHandler(handler) 125 | 126 | return logger 127 | 128 | 129 | class StreamToLogger(object): 130 | """ 131 | Fake file-like stream object that redirects writes to a logger instance. 132 | """ 133 | 134 | def __init__(self, logger, log_level=logging.INFO): 135 | self.terminal = sys.stdout 136 | self.logger = logger 137 | self.log_level = log_level 138 | self.linebuf = "" 139 | 140 | def __getattr__(self, attr): 141 | return getattr(self.terminal, attr) 142 | 143 | def write(self, buf): 144 | temp_linebuf = self.linebuf + buf 145 | self.linebuf = "" 146 | for line in temp_linebuf.splitlines(True): 147 | # From the io.TextIOWrapper docs: 148 | # On output, if newline is None, any '\n' characters written 149 | # are translated to the system default line separator. 150 | # By default sys.stdout.write() expects '\n' newlines and then 151 | # translates them so this is still cross platform. 
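# Only complete lines (ending in "\n") are forwarded to the logger below; a trailing
# partial line is buffered in self.linebuf until a later write() or flush() completes it.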
152 | if line[-1] == "\n": 153 | self.logger.log(self.log_level, line.rstrip()) 154 | else: 155 | self.linebuf += line 156 | 157 | def flush(self): 158 | if self.linebuf != "": 159 | self.logger.log(self.log_level, self.linebuf.rstrip()) 160 | self.linebuf = "" 161 | 162 | 163 | def disable_torch_init(): 164 | """ 165 | Disable the redundant torch default initialization to accelerate model creation. 166 | """ 167 | import torch 168 | 169 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 170 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 171 | 172 | 173 | def violates_moderation(text): 174 | """ 175 | Check whether the text violates OpenAI moderation API. 176 | """ 177 | url = "https://api.openai.com/v1/moderations" 178 | headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 179 | text = text.replace("\n", "") 180 | data = "{" + '"input": ' + f'"{text}"' + "}" 181 | data = data.encode("utf-8") 182 | try: 183 | ret = requests.post(url, headers=headers, data=data, timeout=5) 184 | flagged = ret.json()["results"][0]["flagged"] 185 | except requests.exceptions.RequestException as e: 186 | print(f"######################### Moderation Error: {e} #########################") 187 | flagged = False 188 | except KeyError as e: 189 | print(f"######################### Moderation Error: {e} #########################") 190 | flagged = False 191 | 192 | return flagged 193 | 194 | 195 | def pretty_print_semaphore(semaphore): 196 | if semaphore is None: 197 | return "None" 198 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 199 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 240 3 | 4 | [build-system] 5 | requires = ["setuptools>=61.0"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "proxyv" 10 | version = "1.0.0" 11 | readme = "README.md" 12 | requires-python = ">=3.8" 13 | classifiers = [ 14 | "Programming Language :: Python :: 3", 15 | "License :: OSI Approved :: Apache Software License", 16 | ] 17 | 18 | [project.optional-dependencies] 19 | standalone = [ 20 | "shortuuid", 21 | "httpx==0.24.0", 22 | "einops", 23 | "ftfy", 24 | ] 25 | 26 | 27 | train = [ 28 | "proxyv[standalone]", 29 | "numpy==1.26.4", 30 | "open_clip_torch", 31 | "fastapi", 32 | "markdown2[all]", 33 | "requests", 34 | "sentencepiece", 35 | "torch==2.1.2", 36 | "torchvision==0.16.2", 37 | "uvicorn==0.32.0", 38 | "wandb", 39 | "deepspeed==0.14.2", 40 | "peft==0.4.0", 41 | "accelerate==0.34.2", 42 | "tokenizers==0.19.1", 43 | "transformers==4.41.2", 44 | "bitsandbytes==0.41.0", 45 | "scikit-learn==1.2.2", 46 | "sentencepiece~=0.1.99", 47 | "einops==0.6.1", 48 | "einops-exts==0.0.4", 49 | "gradio_client==0.2.9", 50 | "urllib3<=2.0.0", 51 | "datasets==2.16.1", 52 | "pydantic==2.10.6", 53 | "timm", 54 | "triton==2.1.0", 55 | "hf_transfer", 56 | "opencv-python", 57 | "av", 58 | "decord", 59 | "tyro", 60 | "scipy", 61 | ] 62 | 63 | [tool.setuptools.packages.find] 64 | include = ["llava*"] 65 | exclude = [ 66 | "assets*", 67 | "benchmark*", 68 | "docs", 69 | "dist*", 70 | "playground*", 71 | "scripts*", 72 | "train_scripts", 73 | "tests*", 74 | "checkpoints*", 75 | "project_checkpoints*", 76 | "debug_checkpoints*", 77 | "mlx_configs*", 78 | "wandb*", 79 | "notebooks*", 80 | ] 81 | 82 | [tool.wheel] 83 | exclude = [ 84 | "assets*", 85 | 
"benchmark*", 86 | "docs", 87 | "dist*", 88 | "playground*", 89 | "scripts*", 90 | "train_scripts", 91 | "tests*", 92 | "checkpoints*", 93 | "project_checkpoints*", 94 | "debug_checkpoints*", 95 | "mlx_configs*", 96 | "wandb*", 97 | "notebooks*", 98 | ] 99 | -------------------------------------------------------------------------------- /scripts/finetune/finetune_vicuna7b_baseline.sh: -------------------------------------------------------------------------------- 1 | !/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="lmsys/vicuna-7b-v1.5" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | RANK=${RANK:-0} # srun env node rank 24 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 25 | MASTER_PORT=${MASTER_PORT:-10086} 26 | 27 | PROMPT_VERSION="v1" 28 | 29 | BASE_RUN_NAME="proxyv_vicuna7b_pretrain_baseline" 30 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 31 | 32 | RUN_NAME="proxyv_vicuna7b_finetune_baseline" 33 | echo "RUN_NAME: ${RUN_NAME}" 34 | 35 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 36 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 37 | 38 | LOG_DIR=checkpoints/${RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # 1 = 8 43 | PER_DEVICE_BATCH_SIZE=4 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --pretrain_mm_mlp_adapter="./checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin" \ 55 | --mm_tunable_parts="mm_mlp_adapter,mm_language_model" \ 56 | --vision_tower ${VISION_MODEL_VERSION} \ 57 | --mm_projector_type mlp2x_gelu \ 58 | --mm_vision_select_layer -2 \ 59 | --mm_use_im_start_end False \ 60 | --mm_use_im_patch_token False \ 61 | --group_by_modality_length True \ 62 | --max_num_image_crops 5 \ 63 | --per_crop_token_len 576 \ 64 | --proxyv False \ 65 | --image_aspect_ratio anyres \ 66 | --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ 67 | --bf16 True \ 68 | --run_name $RUN_NAME \ 69 | --output_dir "./checkpoints/${RUN_NAME}" \ 70 | --num_train_epochs 1 \ 71 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 72 | --per_device_eval_batch_size 4 \ 73 | --gradient_accumulation_steps $GRADIENT_ACC \ 74 | --evaluation_strategy "no" \ 75 | --save_strategy "steps" \ 76 | --save_steps 500 \ 77 | --save_total_limit 1 \ 78 | --learning_rate 4e-5 \ 79 | --weight_decay 0. 
\ 80 | --warmup_ratio 0.03 \ 81 | --lr_scheduler_type "cosine" \ 82 | --logging_steps 1 \ 83 | --tf32 True \ 84 | --model_max_length 4096 \ 85 | --gradient_checkpointing True \ 86 | --dataloader_num_workers 16 \ 87 | --lazy_preprocess True \ 88 | --report_to wandb \ 89 | --run_name $RUN_NAME \ 90 | --torch_compile True \ 91 | --torch_compile_backend "inductor" \ 92 | --dataloader_drop_last True \ 93 | --attn_implementation sdpa \ 94 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/finetune/finetune_vicuna7b_proxyv.sh: -------------------------------------------------------------------------------- 1 | !/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="lmsys/vicuna-7b-v1.5" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | RANK=${RANK:-0} # srun env node rank 24 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 25 | MASTER_PORT=${MASTER_PORT:-10086} 26 | 27 | PROMPT_VERSION="v1" 28 | 29 | BASE_RUN_NAME="proxyv_vicuna7b_pretrain_proxyv_layer12_lr3e4" 30 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 31 | 32 | RUN_NAME="proxyv_vicuna7b_finetune_proxyv_layer12" 33 | echo "RUN_NAME: ${RUN_NAME}" 34 | 35 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 36 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 37 | 38 | LOG_DIR=checkpoints/${RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # 1 = 8 43 | PER_DEVICE_BATCH_SIZE=4 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --pretrain_mm_mlp_adapter="./checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin" \ 55 | --mm_tunable_parts="mm_mlp_adapter,mm_language_model,mm_vision_mlp" \ 56 | --vision_tower ${VISION_MODEL_VERSION} \ 57 | --mm_projector_type mlp2x_gelu \ 58 | --mm_vision_select_layer -2 \ 59 | --mm_use_im_start_end False \ 60 | --mm_use_im_patch_token False \ 61 | --group_by_modality_length True \ 62 | --max_num_image_crops 5 \ 63 | --per_crop_token_len 576 \ 64 | --proxyv_reduce_factor 4 \ 65 | --proxyv True \ 66 | --proxyv_start_layer 12 \ 67 | --image_aspect_ratio anyres \ 68 | --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ 69 | --mm_patch_merge_type spatial_unpad \ 70 | --bf16 True \ 71 | --run_name $RUN_NAME \ 72 | --output_dir "./checkpoints/${RUN_NAME}" \ 73 | --num_train_epochs 1 \ 74 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 75 | --per_device_eval_batch_size 4 \ 76 | --gradient_accumulation_steps $GRADIENT_ACC \ 77 | --evaluation_strategy "no" \ 78 | --save_strategy "steps" \ 79 | --save_steps 100 \ 80 | 
--save_total_limit 1 \ 81 | --learning_rate 4e-5 \ 82 | --weight_decay 0. \ 83 | --warmup_ratio 0.03 \ 84 | --lr_scheduler_type "cosine" \ 85 | --logging_steps 1 \ 86 | --tf32 True \ 87 | --model_max_length 4096 \ 88 | --gradient_checkpointing True \ 89 | --dataloader_num_workers 16 \ 90 | --lazy_preprocess True \ 91 | --report_to wandb \ 92 | --run_name $RUN_NAME \ 93 | --torch_compile True \ 94 | --torch_compile_backend "inductor" \ 95 | --dataloader_drop_last True \ 96 | --attn_implementation sdpa \ 97 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_llama8b_baseline.sh: -------------------------------------------------------------------------------- 1 | !/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="meta-llama/Meta-Llama-3-8B-Instruct" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=llava_llama_3 34 | 35 | BASE_RUN_NAME="proxyv_llama3_8b_pretrain_baseline" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # 1 = 8 43 | PER_DEVICE_BATCH_SIZE=8 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter" \ 55 | --mm_vision_select_layer -2 \ 56 | --max_num_image_crops 1 \ 57 | --per_crop_token_len 576 \ 58 | --proxyv False \ 59 | --mm_projector_type mlp2x_gelu \ 60 | --mm_use_im_start_end False \ 61 | --mm_use_im_patch_token False \ 62 | --image_aspect_ratio pad \ 63 | --bf16 True \ 64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 65 | --num_train_epochs 1 \ 66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 67 | --per_device_eval_batch_size 4 \ 68 | --gradient_accumulation_steps $GRADIENT_ACC \ 69 | --evaluation_strategy "no" \ 70 | --save_strategy "steps" \ 71 | --save_total_limit 1 \ 72 | --save_steps 1000 \ 73 | --learning_rate 1e-3 \ 74 | --weight_decay 0. 
\ 75 | --warmup_ratio 0.03 \ 76 | --lr_scheduler_type "cosine" \ 77 | --logging_steps 1 \ 78 | --tf32 True \ 79 | --model_max_length 2048 \ 80 | --gradient_checkpointing True \ 81 | --dataloader_num_workers 4 \ 82 | --lazy_preprocess True \ 83 | --report_to wandb \ 84 | --run_name $BASE_RUN_NAME \ 85 | --attn_implementation sdpa \ 86 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_llama8b_proxyv.sh: -------------------------------------------------------------------------------- 1 | !/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="meta-llama/Meta-Llama-3-8B-Instruct" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=llava_llama_3 34 | 35 | BASE_RUN_NAME="proxyv_llama3_8b_pretrain_proxyv_layer16_lr3e4" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # 1 = 8 43 | PER_DEVICE_BATCH_SIZE=8 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \ 55 | --mm_vision_mlp_lr 3e-4 \ 56 | --mm_vision_select_layer -2 \ 57 | --max_num_image_crops 1 \ 58 | --per_crop_token_len 576 \ 59 | --proxyv_reduce_factor 4 \ 60 | --proxyv True \ 61 | --proxyv_start_layer 16 \ 62 | --mm_projector_type mlp2x_gelu \ 63 | --mm_use_im_start_end False \ 64 | --mm_use_im_patch_token False \ 65 | --image_aspect_ratio pad \ 66 | --bf16 True \ 67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 68 | --num_train_epochs 1 \ 69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 70 | --per_device_eval_batch_size 4 \ 71 | --gradient_accumulation_steps $GRADIENT_ACC \ 72 | --evaluation_strategy "no" \ 73 | --save_strategy "steps" \ 74 | --save_total_limit 1 \ 75 | --save_steps 1000 \ 76 | --learning_rate 1e-3 \ 77 | --weight_decay 0. 
\ 78 | --warmup_ratio 0.03 \ 79 | --lr_scheduler_type "cosine" \ 80 | --logging_steps 1 \ 81 | --tf32 True \ 82 | --model_max_length 2048 \ 83 | --gradient_checkpointing True \ 84 | --dataloader_num_workers 4 \ 85 | --lazy_preprocess True \ 86 | --report_to wandb \ 87 | --run_name $BASE_RUN_NAME \ 88 | --attn_implementation sdpa \ 89 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_phi3b_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="microsoft/Phi-3-mini-4k-instruct" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=phi3_instruct 34 | 35 | BASE_RUN_NAME="proxyv_phi3_3b_pretrain_baseline" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # global batch size = PER_DEVICE_BATCH_SIZE * NUM_GPUS * WORLD_SIZE * GRADIENT_ACC 43 | PER_DEVICE_BATCH_SIZE=16 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter" \ 55 | --mm_vision_select_layer -2 \ 56 | --max_num_image_crops 1 \ 57 | --per_crop_token_len 576 \ 58 | --proxyv False \ 59 | --mm_projector_type mlp2x_gelu \ 60 | --mm_use_im_start_end False \ 61 | --mm_use_im_patch_token False \ 62 | --image_aspect_ratio pad \ 63 | --bf16 True \ 64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 65 | --num_train_epochs 1 \ 66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 67 | --per_device_eval_batch_size 4 \ 68 | --gradient_accumulation_steps $GRADIENT_ACC \ 69 | --evaluation_strategy "no" \ 70 | --save_strategy "steps" \ 71 | --save_total_limit 1 \ 72 | --save_steps 1000 \ 73 | --learning_rate 1e-3 \ 74 | --weight_decay 0.
\ 75 | --warmup_ratio 0.03 \ 76 | --lr_scheduler_type "cosine" \ 77 | --logging_steps 1 \ 78 | --tf32 True \ 79 | --model_max_length 2048 \ 80 | --gradient_checkpointing True \ 81 | --dataloader_num_workers 4 \ 82 | --lazy_preprocess True \ 83 | --report_to wandb \ 84 | --run_name $BASE_RUN_NAME \ 85 | --attn_implementation sdpa \ 86 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_phi3b_proxyv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="microsoft/Phi-3-mini-4k-instruct" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=phi3_instruct 34 | 35 | BASE_RUN_NAME="proxyv_phi3_3b_pretrain_proxyv_layer16_lr5e4" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # global batch size = PER_DEVICE_BATCH_SIZE * NUM_GPUS * WORLD_SIZE * GRADIENT_ACC 43 | PER_DEVICE_BATCH_SIZE=16 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \ 55 | --mm_vision_mlp_lr 5e-4 \ 56 | --mm_vision_select_layer -2 \ 57 | --max_num_image_crops 1 \ 58 | --per_crop_token_len 576 \ 59 | --proxyv_reduce_factor 4 \ 60 | --proxyv True \ 61 | --proxyv_start_layer 16 \ 62 | --mm_projector_type mlp2x_gelu \ 63 | --mm_use_im_start_end False \ 64 | --mm_use_im_patch_token False \ 65 | --image_aspect_ratio pad \ 66 | --bf16 True \ 67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 68 | --num_train_epochs 1 \ 69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 70 | --per_device_eval_batch_size 4 \ 71 | --gradient_accumulation_steps $GRADIENT_ACC \ 72 | --evaluation_strategy "no" \ 73 | --save_strategy "steps" \ 74 | --save_total_limit 1 \ 75 | --save_steps 1000 \ 76 | --learning_rate 1e-3 \ 77 | --weight_decay 0.
\ 78 | --warmup_ratio 0.03 \ 79 | --lr_scheduler_type "cosine" \ 80 | --logging_steps 1 \ 81 | --tf32 True \ 82 | --model_max_length 2048 \ 83 | --gradient_checkpointing True \ 84 | --dataloader_num_workers 4 \ 85 | --lazy_preprocess True \ 86 | --report_to wandb \ 87 | --run_name $BASE_RUN_NAME \ 88 | --attn_implementation sdpa \ 89 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_qwen7b_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="Qwen/Qwen2-7B-Instruct" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=qwen_1_5 34 | 35 | BASE_RUN_NAME="proxyv_qwen2_7b_pretrain_baseline" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # global batch size = PER_DEVICE_BATCH_SIZE * NUM_GPUS * WORLD_SIZE * GRADIENT_ACC 43 | PER_DEVICE_BATCH_SIZE=16 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter" \ 55 | --mm_vision_select_layer -2 \ 56 | --max_num_image_crops 1 \ 57 | --per_crop_token_len 576 \ 58 | --proxyv False \ 59 | --mm_projector_type mlp2x_gelu \ 60 | --mm_use_im_start_end False \ 61 | --mm_use_im_patch_token False \ 62 | --image_aspect_ratio pad \ 63 | --bf16 True \ 64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 65 | --num_train_epochs 1 \ 66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 67 | --per_device_eval_batch_size 4 \ 68 | --gradient_accumulation_steps $GRADIENT_ACC \ 69 | --evaluation_strategy "no" \ 70 | --save_strategy "steps" \ 71 | --save_total_limit 1 \ 72 | --save_steps 1000 \ 73 | --learning_rate 1e-3 \ 74 | --weight_decay 0.
\ 75 | --warmup_ratio 0.03 \ 76 | --lr_scheduler_type "cosine" \ 77 | --logging_steps 1 \ 78 | --tf32 True \ 79 | --model_max_length 2048 \ 80 | --gradient_checkpointing True \ 81 | --dataloader_num_workers 4 \ 82 | --lazy_preprocess True \ 83 | --report_to wandb \ 84 | --run_name $BASE_RUN_NAME \ 85 | --attn_implementation sdpa \ 86 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_qwen7b_proxyv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="Qwen/Qwen2-7B-Instruct" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=qwen_1_5 34 | 35 | BASE_RUN_NAME="proxyv_qwen2_7b_pretrain_proxyv_layer16_lr8e4" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # global batch size = PER_DEVICE_BATCH_SIZE * NUM_GPUS * WORLD_SIZE * GRADIENT_ACC 43 | PER_DEVICE_BATCH_SIZE=16 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \ 55 | --mm_vision_mlp_lr 8e-4 \ 56 | --mm_vision_select_layer -2 \ 57 | --max_num_image_crops 1 \ 58 | --per_crop_token_len 576 \ 59 | --proxyv_reduce_factor 4 \ 60 | --proxyv True \ 61 | --proxyv_start_layer 16 \ 62 | --mm_projector_type mlp2x_gelu \ 63 | --mm_use_im_start_end False \ 64 | --mm_use_im_patch_token False \ 65 | --image_aspect_ratio pad \ 66 | --bf16 True \ 67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 68 | --num_train_epochs 1 \ 69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 70 | --per_device_eval_batch_size 4 \ 71 | --gradient_accumulation_steps $GRADIENT_ACC \ 72 | --evaluation_strategy "no" \ 73 | --save_strategy "steps" \ 74 | --save_total_limit 1 \ 75 | --save_steps 1000 \ 76 | --learning_rate 1e-3 \ 77 | --weight_decay 0.
\ 78 | --warmup_ratio 0.03 \ 79 | --lr_scheduler_type "cosine" \ 80 | --logging_steps 1 \ 81 | --tf32 True \ 82 | --model_max_length 2048 \ 83 | --gradient_checkpointing True \ 84 | --dataloader_num_workers 4 \ 85 | --lazy_preprocess True \ 86 | --report_to wandb \ 87 | --run_name $BASE_RUN_NAME \ 88 | --attn_implementation sdpa \ 89 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_vicuna13b_baselin.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="lmsys/vicuna-13b-v1.5" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=v1 34 | 35 | BASE_RUN_NAME="proxyv_vicuna13b_pretrain_baseline" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # global batch size = PER_DEVICE_BATCH_SIZE * NUM_GPUS * WORLD_SIZE * GRADIENT_ACC 43 | PER_DEVICE_BATCH_SIZE=8 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter" \ 55 | --mm_vision_select_layer -2 \ 56 | --max_num_image_crops 1 \ 57 | --per_crop_token_len 576 \ 58 | --proxyv False \ 59 | --mm_projector_type mlp2x_gelu \ 60 | --mm_use_im_start_end False \ 61 | --mm_use_im_patch_token False \ 62 | --image_aspect_ratio pad \ 63 | --bf16 True \ 64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 65 | --num_train_epochs 1 \ 66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 67 | --per_device_eval_batch_size 4 \ 68 | --gradient_accumulation_steps $GRADIENT_ACC \ 69 | --evaluation_strategy "no" \ 70 | --save_strategy "steps" \ 71 | --save_total_limit 1 \ 72 | --save_steps 1000 \ 73 | --learning_rate 1e-3 \ 74 | --weight_decay 0.
\ 75 | --warmup_ratio 0.03 \ 76 | --lr_scheduler_type "cosine" \ 77 | --logging_steps 1 \ 78 | --tf32 True \ 79 | --model_max_length 2048 \ 80 | --gradient_checkpointing True \ 81 | --dataloader_num_workers 4 \ 82 | --lazy_preprocess True \ 83 | --report_to wandb \ 84 | --run_name $BASE_RUN_NAME \ 85 | --attn_implementation sdpa \ 86 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_vicuna13b_proxyv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="lmsys/vicuna-13b-v1.5" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=v1 34 | 35 | BASE_RUN_NAME="proxyv_vicuna13b_pretrain_proxyv_layer16_lr3e4" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # global batch size = PER_DEVICE_BATCH_SIZE * NUM_GPUS * WORLD_SIZE * GRADIENT_ACC 43 | PER_DEVICE_BATCH_SIZE=8 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \ 55 | --mm_vision_mlp_lr 3e-4 \ 56 | --mm_vision_select_layer -2 \ 57 | --max_num_image_crops 1 \ 58 | --per_crop_token_len 576 \ 59 | --proxyv_reduce_factor 4 \ 60 | --proxyv True \ 61 | --proxyv_start_layer 16 \ 62 | --mm_projector_type mlp2x_gelu \ 63 | --mm_use_im_start_end False \ 64 | --mm_use_im_patch_token False \ 65 | --image_aspect_ratio pad \ 66 | --bf16 True \ 67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 68 | --num_train_epochs 1 \ 69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 70 | --per_device_eval_batch_size 4 \ 71 | --gradient_accumulation_steps $GRADIENT_ACC \ 72 | --evaluation_strategy "no" \ 73 | --save_strategy "steps" \ 74 | --save_total_limit 1 \ 75 | --save_steps 1000 \ 76 | --learning_rate 1e-3 \ 77 | --weight_decay 0.
\ 78 | --warmup_ratio 0.03 \ 79 | --lr_scheduler_type "cosine" \ 80 | --logging_steps 1 \ 81 | --tf32 True \ 82 | --model_max_length 2048 \ 83 | --gradient_checkpointing True \ 84 | --dataloader_num_workers 4 \ 85 | --lazy_preprocess True \ 86 | --report_to wandb \ 87 | --run_name $BASE_RUN_NAME \ 88 | --attn_implementation sdpa \ 89 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_vicuna7b_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="lmsys/vicuna-7b-v1.5" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=v1 34 | 35 | BASE_RUN_NAME="proxyv_vicuna7b_pretrain_baseline" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # global batch size = PER_DEVICE_BATCH_SIZE * NUM_GPUS * WORLD_SIZE * GRADIENT_ACC 43 | PER_DEVICE_BATCH_SIZE=16 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter" \ 55 | --mm_vision_select_layer -2 \ 56 | --max_num_image_crops 1 \ 57 | --per_crop_token_len 576 \ 58 | --proxyv False \ 59 | --mm_projector_type mlp2x_gelu \ 60 | --mm_use_im_start_end False \ 61 | --mm_use_im_patch_token False \ 62 | --image_aspect_ratio pad \ 63 | --bf16 True \ 64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 65 | --num_train_epochs 1 \ 66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 67 | --per_device_eval_batch_size 4 \ 68 | --gradient_accumulation_steps $GRADIENT_ACC \ 69 | --evaluation_strategy "no" \ 70 | --save_strategy "steps" \ 71 | --save_total_limit 1 \ 72 | --save_steps 1000 \ 73 | --learning_rate 1e-3 \ 74 | --weight_decay 0.
\ 75 | --warmup_ratio 0.03 \ 76 | --lr_scheduler_type "cosine" \ 77 | --logging_steps 1 \ 78 | --tf32 True \ 79 | --model_max_length 2048 \ 80 | --gradient_checkpointing True \ 81 | --dataloader_num_workers 4 \ 82 | --lazy_preprocess True \ 83 | --report_to wandb \ 84 | --run_name $BASE_RUN_NAME \ 85 | --attn_implementation sdpa \ 86 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/pretrain/pretrain_vicuna7b_proxyv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | T=`date +%Y%m%d_%H%M%S` 4 | export OMP_NUM_THREADS=8 5 | export NCCL_IB_DISABLE=0 6 | export NCCL_IB_GID_INDEX=3 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | export WANDB_RESUME="allow" 10 | 11 | export WANDB_API_KEY="" 12 | export WANDB_PROJECT="proxyv" 13 | 14 | LLM_VERSION="lmsys/vicuna-7b-v1.5" 15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" 16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" 17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 18 | 19 | DATA_PATH="TODO" 20 | IMAGE_FOLDER="TODO" 21 | 22 | NUM_GPUS=8 23 | NNODES=1 24 | RANK=${RANK:-0} # srun env node rank 25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} 26 | MASTER_PORT=${MASTER_PORT:-10086} 27 | 28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num 29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}" 30 | 31 | ############### Pretrain ################ 32 | 33 | PROMPT_VERSION=v1 34 | 35 | BASE_RUN_NAME="proxyv_vicuna7b_pretrain_proxyv_layer12_lr3e4" 36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" 37 | 38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs 39 | mkdir -p ${LOG_DIR} 40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log 41 | 42 | BATCH_SIZE=512 # global batch size = PER_DEVICE_BATCH_SIZE * NUM_GPUS * WORLD_SIZE * GRADIENT_ACC 43 | PER_DEVICE_BATCH_SIZE=16 44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE)) 45 | 46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \ 47 | ./llava/train/train_mem.py \ 48 | --deepspeed scripts/zero3.json \ 49 | --model_name_or_path ${LLM_VERSION} \ 50 | --version ${PROMPT_VERSION} \ 51 | --data_path ${DATA_PATH} \ 52 | --image_folder ${IMAGE_FOLDER} \ 53 | --vision_tower ${VISION_MODEL_VERSION} \ 54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \ 55 | --mm_vision_mlp_lr 3e-4 \ 56 | --mm_vision_select_layer -2 \ 57 | --max_num_image_crops 1 \ 58 | --per_crop_token_len 576 \ 59 | --proxyv_reduce_factor 4 \ 60 | --proxyv True \ 61 | --proxyv_start_layer 12 \ 62 | --mm_projector_type mlp2x_gelu \ 63 | --mm_use_im_start_end False \ 64 | --mm_use_im_patch_token False \ 65 | --image_aspect_ratio pad \ 66 | --bf16 True \ 67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \ 68 | --num_train_epochs 1 \ 69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \ 70 | --per_device_eval_batch_size 4 \ 71 | --gradient_accumulation_steps $GRADIENT_ACC \ 72 | --evaluation_strategy "no" \ 73 | --save_strategy "steps" \ 74 | --save_total_limit 1 \ 75 | --save_steps 1000 \ 76 | --learning_rate 1e-3 \ 77 | --weight_decay 0.
\ 78 | --warmup_ratio 0.03 \ 79 | --lr_scheduler_type "cosine" \ 80 | --logging_steps 1 \ 81 | --tf32 True \ 82 | --model_max_length 2048 \ 83 | --gradient_checkpointing True \ 84 | --dataloader_num_workers 4 \ 85 | --lazy_preprocess True \ 86 | --report_to wandb \ 87 | --run_name $BASE_RUN_NAME \ 88 | --attn_implementation sdpa \ 89 | 2>&1 | tee ${LOG_FILE} -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero2_fused_adamw.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": true, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "offload_optimizer": { 19 | "device": "cpu", 20 | "pin_memory": true 21 | }, 22 | "offload_param": { 23 | "device": "cpu", 24 | "pin_memory": true 25 | }, 26 | "overlap_comm": true, 27 | "contiguous_gradients": true, 28 | "sub_group_size": 1e9, 29 | "reduce_bucket_size": "auto" 30 | } 31 | } 
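Note on how these DeepSpeed configs interact with the pretrain scripts above: the "auto" fields (train_batch_size, train_micro_batch_size_per_gpu, gradient_accumulation_steps, and the optimizer settings) are resolved at launch from the flags the scripts pass to train_mem.py, assuming the standard HuggingFace Trainer DeepSpeed integration implied by the --deepspeed flag. A minimal bash sketch (not a file in this repository) of the batch-size arithmetic every script shares, worked with the vicuna7b settings:

    BATCH_SIZE=512            # target global batch size
    PER_DEVICE_BATCH_SIZE=16  # micro batch per GPU
    NUM_GPUS=8                # GPUs per node
    WORLD_SIZE=1              # number of nodes
    GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
    echo "gradient_accumulation_steps=${GRADIENT_ACC}"  # prints 4, since 16 * 8 * 1 * 4 = 512

With those values DeepSpeed resolves train_micro_batch_size_per_gpu=16, gradient_accumulation_steps=4, and train_batch_size=512 for the run.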
-------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "steps_per_print": 1e5, 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /scripts/zero3pp.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "none", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "none", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "zero_quantized_weights": true, 36 | "zero_hpz_partition_size": 16, 37 | "zero_quantized_gradients": true, 38 | 
"sub_group_size": 1e9, 39 | "reduce_bucket_size": "auto", 40 | "stage3_prefetch_bucket_size": "auto", 41 | "stage3_param_persistence_threshold": "auto", 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | "stage3_gather_16bit_weights_on_model_save": true 45 | }, 46 | 47 | "gradient_accumulation_steps": "auto", 48 | "gradient_clipping": "auto", 49 | "steps_per_print": 100, 50 | "train_batch_size": "auto", 51 | "train_micro_batch_size_per_gpu": "auto", 52 | "wall_clock_breakdown": false 53 | } --------------------------------------------------------------------------------