├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── pipeline.png
├── demo.py
├── docs
│   ├── .DS_Store
│   ├── index.html
│   ├── sitemap.xml
│   └── static
│       ├── .DS_Store
│       ├── css
│       │   ├── bulma-carousel.min.css
│       │   ├── bulma-slider.min.css
│       │   ├── bulma.css.map.txt
│       │   ├── bulma.min.css
│       │   ├── fontawesome.all.min.css
│       │   └── index.css
│       ├── js
│       │   ├── bulma-carousel.js
│       │   ├── bulma-carousel.min.js
│       │   ├── bulma-slider.js
│       │   ├── bulma-slider.min.js
│       │   ├── fontawesome.all.min.js
│       │   └── index.js
│       └── proxyv
│           ├── mask_ratio_revised.png
│           ├── non_spatial_proxyv.png
│           ├── proxyv.png
│           ├── results.png
│           ├── table1.png
│           ├── table2.png
│           └── table3.png
├── llava
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── apply_delta.py
│   │   ├── builder.py
│   │   ├── consolidate.py
│   │   ├── language_model
│   │   │   ├── llava_llama.py
│   │   │   ├── llava_phi3.py
│   │   │   └── llava_qwen.py
│   │   ├── llava_arch.py
│   │   ├── make_delta.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   ├── clip_encoder.py
│   │   │   ├── dev_eva_clip
│   │   │   │   ├── eva_clip
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   │   │   │   ├── constants.py
│   │   │   │   │   ├── eva_vit_model.py
│   │   │   │   │   ├── factory.py
│   │   │   │   │   ├── hf_configs.py
│   │   │   │   │   ├── hf_model.py
│   │   │   │   │   ├── loss.py
│   │   │   │   │   ├── model.py
│   │   │   │   │   ├── model_configs
│   │   │   │   │   │   ├── EVA-CLIP-18B.json
│   │   │   │   │   │   ├── EVA-CLIP-8B-plus.json
│   │   │   │   │   │   ├── EVA-CLIP-8B.json
│   │   │   │   │   │   ├── EVA01-CLIP-B-16.json
│   │   │   │   │   │   ├── EVA01-CLIP-g-14-plus.json
│   │   │   │   │   │   ├── EVA01-CLIP-g-14.json
│   │   │   │   │   │   ├── EVA02-CLIP-B-16.json
│   │   │   │   │   │   ├── EVA02-CLIP-L-14-336.json
│   │   │   │   │   │   ├── EVA02-CLIP-L-14.json
│   │   │   │   │   │   ├── EVA02-CLIP-bigE-14-plus.json
│   │   │   │   │   │   ├── EVA02-CLIP-bigE-14.json
│   │   │   │   │   │   ├── Internal-EVA02-CLIP-10B-14-448.json
│   │   │   │   │   │   └── Internal-EVA02-CLIP-10B-14.json
│   │   │   │   │   ├── modified_resnet.py
│   │   │   │   │   ├── openai.py
│   │   │   │   │   ├── pretrained.py
│   │   │   │   │   ├── rope.py
│   │   │   │   │   ├── timm_model.py
│   │   │   │   │   ├── tokenizer.py
│   │   │   │   │   ├── transform.py
│   │   │   │   │   ├── transformer.py
│   │   │   │   │   └── utils.py
│   │   │   │   └── eva_vit.py
│   │   │   ├── eva_clip
│   │   │   │   ├── eva_clip_encoder.py
│   │   │   │   ├── eva_clip_processors.py
│   │   │   │   ├── eva_vit.py
│   │   │   │   ├── factory.py
│   │   │   │   └── model_configs
│   │   │   │       ├── EVA-CLIP-18B.json
│   │   │   │       ├── EVA-CLIP-8B-plus.json
│   │   │   │       ├── EVA-CLIP-8B.json
│   │   │   │       ├── EVA01-CLIP-B-16.json
│   │   │   │       ├── EVA01-CLIP-g-14-plus.json
│   │   │   │       ├── EVA01-CLIP-g-14.json
│   │   │   │       ├── EVA02-CLIP-B-16.json
│   │   │   │       ├── EVA02-CLIP-L-14-336.json
│   │   │   │       ├── EVA02-CLIP-L-14.json
│   │   │   │       ├── EVA02-CLIP-bigE-14-plus.json
│   │   │   │       ├── EVA02-CLIP-bigE-14.json
│   │   │   │       ├── Internal-EVA02-CLIP-10B-14-448.json
│   │   │   │       └── Internal-EVA02-CLIP-10B-14.json
│   │   │   ├── hf_vision.py
│   │   │   ├── imagebind.py
│   │   │   ├── open_clip_encoder.py
│   │   │   └── siglip_encoder.py
│   │   ├── multimodal_projector
│   │   │   ├── builder.py
│   │   │   └── pooler_projector.py
│   │   ├── multimodal_resampler
│   │   │   ├── builder.py
│   │   │   ├── masked_drop.py
│   │   │   ├── perceiver.py
│   │   │   ├── qformer.py
│   │   │   └── spatial_pool.py
│   │   ├── utils.py
│   │   └── vision_mlp.py
│   ├── train
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── llava_trainer.py
│   │   ├── llava_trainer_eval.py
│   │   ├── train.py
│   │   └── train_mem.py
│   └── utils.py
├── pyproject.toml
└── scripts
    ├── finetune
    │   ├── finetune_vicuna7b_baseline.sh
    │   └── finetune_vicuna7b_proxyv.sh
    ├── pretrain
    │   ├── pretrain_llama8b_baseline.sh
    │   ├── pretrain_llama8b_proxyv.sh
    │   ├── pretrain_phi3b_baseline.sh
    │   ├── pretrain_phi3b_proxyv.sh
    │   ├── pretrain_qwen7b_baseline.sh
    │   ├── pretrain_qwen7b_proxyv.sh
    │   ├── pretrain_vicuna13b_baselin.sh
    │   ├── pretrain_vicuna13b_proxyv.sh
    │   ├── pretrain_vicuna7b_baseline.sh
    │   └── pretrain_vicuna7b_proxyv.sh
    ├── zero2.json
    ├── zero2_fused_adamw.json
    ├── zero2_offload.json
    ├── zero3.json
    ├── zero3_offload.json
    └── zero3pp.json
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProxyV: Streamline Without Sacrifice - Squeeze out Computation Redundancy in LMM
2 |
3 | [Paper](https://arxiv.org/abs/2505.15816)
4 | [Project Page](https://penghao-wu.github.io/ProxyV/)
5 | [Model](https://huggingface.co/craigwu/proxyv_vicuna_7b_layer12)
6 |
7 | ![ProxyV pipeline](assets/pipeline.png)
8 |
9 | ## Contents:
10 | 1. [Getting Started](#getting-started)
11 | 2. [Image Encoding Scheme](#image-encoding-scheme)
12 | 3. [Training](#training)
13 | 4. [Evaluation](#evaluation)
14 | 5. [License](#license)
15 | 6. [Citation](#citation)
16 | 7. [Acknowledgement](#acknowledgement)
17 |
18 | ## Getting Started
19 |
20 | ### Installation
21 | ```
22 | conda create -n proxyv python=3.10 -y
23 | conda activate proxyv
24 | pip install --upgrade pip # Enable PEP 660 support.
25 | pip install -e ".[train]"
26 | ```
27 |
28 | ### Training Dataset
29 | For the pre-training stage, we use the 1.2M ShareGPT4V data, which can be downloaded at this [link](https://huggingface.co/datasets/Lin-Chen/ShareGPT4V).
30 | For the fine-tuning stage, we use the public LLaVA-NeXT data, which can be downloaded at this [link](https://huggingface.co/datasets/lmms-lab/LLaVA-NeXT-Data).
31 |
32 | ## Image Encoding Scheme
33 | In our current implementation, we adopt the [AnyRes](https://github.com/LLaVA-VL/LLaVA-NeXT) strategy. The image features within each crop are flattened in raster order and concatenated crop by crop, similar to the [UniRes](https://github.com/EvolvingLMMs-Lab/LongVA) strategy. We also append a newline separator token after each crop.
34 | To process the vision tokens more conveniently, we pack tokens in the **\[vision tokens; proxy tokens; newline separator tokens; text tokens\]** order, and modify the position_ids and attention_masks accordingly to preserve their original relative order.
35 |
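
To make this concrete, here is a minimal, self-contained sketch of the position-id remapping idea. It is **not** the repository's implementation, and the assumed original interleaving of vision, newline, and proxy tokens is hypothetical; it only illustrates how a sequence packed as **\[vision; proxy; newline; text\]** can keep positional encodings that follow the original token order.

```python
import torch

# Illustrative token counts; the real numbers depend on the AnyRes crops.
n_vision, n_proxy, n_newline, n_text = 8, 2, 1, 5
hidden = 16

# Embeddings packed in [vision; proxy; newline; text] order.
packed = torch.cat([
    torch.randn(n_vision, hidden),   # vision tokens
    torch.randn(n_proxy, hidden),    # proxy tokens
    torch.randn(n_newline, hidden),  # newline separator tokens
    torch.randn(n_text, hidden),     # text tokens
], dim=0)

# Hypothetical original (un-packed) order, written as indices into the
# packed sequence: vision, then newline separators, then proxy, then text.
original_order = torch.cat([
    torch.arange(0, n_vision),                                         # vision
    torch.arange(n_vision + n_proxy, n_vision + n_proxy + n_newline),  # newline
    torch.arange(n_vision, n_vision + n_proxy),                        # proxy
    torch.arange(n_vision + n_proxy + n_newline, packed.size(0)),      # text
])

# position_ids[i] = position of packed token i in the original order, so
# rotary/absolute position encodings still reflect the original ordering.
position_ids = torch.empty(packed.size(0), dtype=torch.long)
position_ids[original_order] = torch.arange(packed.size(0))
position_ids = position_ids.unsqueeze(0)  # (1, seq_len), aligned with `packed`

print(position_ids)
```

The attention masks are adjusted in the same spirit, so that attention still respects the original relative order; that part is not reproduced in this sketch.
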
36 | ## Training
37 | The pre-training scripts can be found within the ``scripts/pretrain`` folder, and fine-tuning example scripts are provided under the ``scripts/finetune`` folder.
38 | To enable ProxyV, set ``--proxyv`` to ``true`` in the script and set ``--proxyv_start_layer`` to the desired layer index.
39 |
40 | ## Evaluation
41 | The Vicuna-1.5-7B ProxyV layer-12 model studied in the paper is provided at this [link](https://huggingface.co/craigwu/proxyv_vicuna_7b_layer12).
42 | A simple inference example script is provided at ``demo.py``.
43 | All benchmark evaluations can be directly conducted using [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) with `--model` set to `llava`.
44 |
45 | ## License
46 |
47 | This project is under the Apache-2.0 license. See [LICENSE](LICENSE) for details.
48 |
49 | ## Citation
50 | Please consider citing our paper if you find this project helpful for your research:
51 |
52 | ```bibtex
53 | @article{ProxyV,
54 | author = {Wu, Penghao and Lu, Lewei and Liu, Ziwei},
55 | title = {Streamline Without Sacrifice - Squeeze out Computation Redundancy in LMM},
56 | journal={arXiv preprint arXiv:2505.15816},
57 | year={2025}}
58 | ```
59 |
60 | ## Acknowledgement
61 | - This work is built upon [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT).
62 |
63 |
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/assets/pipeline.png
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 |
4 |
5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
6 | from llava.conversation import conv_templates
7 | from llava.model.builder import load_pretrained_model
8 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
9 |
10 |
11 |
12 | def process(image, question, tokenizer, image_processor, model_config):
13 | qs = question
14 | if model_config.mm_use_im_start_end:
15 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
16 | else:
17 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
18 |
19 | conv = conv_templates[conv_mode].copy()
20 | conv.append_message(conv.roles[0], qs)
21 | conv.append_message(conv.roles[1], None)
22 | prompt = conv.get_prompt()
23 | print(prompt)
24 |
25 | image_size = [image.size]
26 | image_tensor = process_images([image], image_processor, model_config).to(torch.float16)
27 | # image_tensor = process_images([image], image_processor, model_config).to(torch.bfloat16)
28 |
29 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
30 |
31 | return input_ids, image_tensor, image_size, prompt
32 |
33 |
34 | conv_mode = "v1"
35 | temperature = 0.0
36 | model_path = "craigwu/proxyv_vicuna_7b_layer12"
37 | model_name = get_model_name_from_path(model_path)
38 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, multimodal=True,attn_implementation='sdpa')
39 |
40 | while True:
41 | image_path = input("image path: ")
42 | image = Image.open(image_path).convert('RGB')
43 | question = input("question: ")
44 |
45 | input_ids, image_tensor, image_sizes, prompt = process(image, question, tokenizer, image_processor, model.config)
46 | input_ids = input_ids.to(device='cuda', non_blocking=True)
47 |
48 | with torch.inference_mode():
49 | output_ids = model.generate(
50 | input_ids,
51 | images=image_tensor,
52 | image_sizes=image_sizes,
53 | do_sample=True if temperature > 0 else False,
54 | temperature=temperature,
55 | num_beams=1,
56 | max_new_tokens=512,
57 | use_cache=True)
58 |
59 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
60 |
61 | print(outputs)
--------------------------------------------------------------------------------
/docs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/.DS_Store
--------------------------------------------------------------------------------
/docs/sitemap.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
3 |   <url>
4 |     <loc>https://amap-ml.github.io/UniVG-R1-page/</loc>
5 |     <lastmod>2025-05-20T08:20:38+00:00</lastmod>
6 |   </url>
7 | </urlset>
--------------------------------------------------------------------------------
/docs/static/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/.DS_Store
--------------------------------------------------------------------------------
/docs/static/css/bulma-carousel.min.css:
--------------------------------------------------------------------------------
1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center 
center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10}
--------------------------------------------------------------------------------
/docs/static/css/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Noto Sans', sans-serif;
3 | }
4 | .content-flex-row {
5 | display: flex;
6 | align-items: flex-start;
7 | justify-content: center;
8 | gap: 32px; /* horizontal spacing between columns; adjust as needed */
9 | margin-bottom: 2rem;
10 | }
11 |
12 | .footer .icon-link {
13 | font-size: 25px;
14 | color: #000;
15 | }
16 |
17 | .link-block a {
18 | margin-top: 5px;
19 | margin-bottom: 5px;
20 | }
21 |
22 | .dnerf {
23 | font-variant: small-caps;
24 | }
25 |
26 |
27 | .teaser .hero-body {
28 | padding-top: 0;
29 | padding-bottom: 3rem;
30 | }
31 |
32 | .teaser {
33 | font-family: 'Google Sans', sans-serif;
34 | }
35 |
36 |
37 | .publication-title {
38 | }
39 |
40 | .publication-banner {
41 | max-height: parent;
42 |
43 | }
44 |
45 | .publication-banner video {
46 | position: relative;
47 | left: auto;
48 | top: auto;
49 | transform: none;
50 | object-fit: fit;
51 | }
52 |
53 | .publication-header .hero-body {
54 | }
55 |
56 | .publication-title {
57 | font-family: 'Google Sans', sans-serif;
58 | }
59 |
60 | .publication-authors {
61 | font-family: 'Google Sans', sans-serif;
62 | }
63 |
64 | .publication-venue {
65 | color: #555;
66 | width: fit-content;
67 | font-weight: bold;
68 | }
69 |
70 | .publication-awards {
71 | color: #ff3860;
72 | width: fit-content;
73 | font-weight: bolder;
74 | }
75 |
76 | .publication-authors {
77 | }
78 |
79 | .publication-authors a {
80 | color: hsl(204, 86%, 53%) !important;
81 | }
82 |
83 | .publication-authors a:hover {
84 | text-decoration: underline;
85 | }
86 |
87 | .author-block {
88 | display: inline-block;
89 | }
90 |
91 | .publication-banner img {
92 | }
93 |
94 | .publication-authors {
95 | /*color: #4286f4;*/
96 | }
97 |
98 | .publication-video {
99 | position: relative;
100 | width: 100%;
101 | height: 0;
102 | padding-bottom: 56.25%;
103 |
104 | overflow: hidden;
105 | border-radius: 10px !important;
106 | }
107 |
108 | .publication-video iframe {
109 | position: absolute;
110 | top: 0;
111 | left: 0;
112 | width: 100%;
113 | height: 100%;
114 | }
115 |
116 | .publication-body img {
117 | }
118 |
119 | .results-carousel {
120 | overflow: hidden;
121 | }
122 |
123 | .results-carousel .item {
124 | margin: 5px;
125 | overflow: hidden;
126 | border: 1px solid #bbb;
127 | border-radius: 10px;
128 | padding: 0;
129 | font-size: 0;
130 | }
131 |
132 | .results-carousel video {
133 | margin: 0;
134 | }
135 |
136 |
137 | .interpolation-panel {
138 | background: #f5f5f5;
139 | border-radius: 10px;
140 | }
141 |
142 | .interpolation-panel .interpolation-image {
143 | width: 100%;
144 | border-radius: 5px;
145 | }
146 |
147 | .interpolation-video-column {
148 | }
149 |
150 | .interpolation-panel .slider {
151 | margin: 0 !important;
152 | }
153 |
154 | .interpolation-panel .slider {
155 | margin: 0 !important;
156 | }
157 |
158 | #interpolation-image-wrapper {
159 | width: 100%;
160 | }
161 | #interpolation-image-wrapper img {
162 | border-radius: 5px;
163 | }
164 |
--------------------------------------------------------------------------------
/docs/static/js/bulma-slider.min.js:
--------------------------------------------------------------------------------
1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default});
--------------------------------------------------------------------------------
/docs/static/js/index.js:
--------------------------------------------------------------------------------
1 | window.HELP_IMPROVE_VIDEOJS = false;
2 |
3 | var INTERP_BASE = "./static/interpolation/stacked";
4 | var NUM_INTERP_FRAMES = 240;
5 |
6 | var interp_images = [];
7 | function preloadInterpolationImages() {
8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) {
9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg';
10 | interp_images[i] = new Image();
11 | interp_images[i].src = path;
12 | }
13 | }
14 |
15 | function setInterpolationImage(i) {
16 | var image = interp_images[i];
17 | image.ondragstart = function() { return false; };
18 | image.oncontextmenu = function() { return false; };
19 | $('#interpolation-image-wrapper').empty().append(image);
20 | }
21 |
22 |
23 | $(document).ready(function() {
24 | // Check for click events on the navbar burger icon
25 | $(".navbar-burger").click(function() {
26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu"
27 | $(".navbar-burger").toggleClass("is-active");
28 | $(".navbar-menu").toggleClass("is-active");
29 |
30 | });
31 |
32 | var options = {
33 | slidesToScroll: 1,
34 | slidesToShow: 3,
35 | loop: true,
36 | infinite: true,
37 | autoplay: false,
38 | autoplaySpeed: 3000,
39 | }
40 |
41 | // Initialize all div with carousel class
42 | var carousels = bulmaCarousel.attach('.carousel', options);
43 |
44 | // Loop on each carousel initialized
45 | for(var i = 0; i < carousels.length; i++) {
46 | // Add listener to event
47 | carousels[i].on('before:show', state => {
48 | console.log(state);
49 | });
50 | }
51 |
52 | // Access to bulmaCarousel instance of an element
53 | var element = document.querySelector('#my-element');
54 | if (element && element.bulmaCarousel) {
55 | // bulmaCarousel instance is available as element.bulmaCarousel
56 | element.bulmaCarousel.on('before-show', function(state) {
57 | console.log(state);
58 | });
59 | }
60 |
61 | /*var player = document.getElementById('interpolation-video');
62 | player.addEventListener('loadedmetadata', function() {
63 | $('#interpolation-slider').on('input', function(event) {
64 | console.log(this.value, player.duration);
65 | player.currentTime = player.duration / 100 * this.value;
66 | })
67 | }, false);*/
68 | preloadInterpolationImages();
69 |
70 | $('#interpolation-slider').on('input', function(event) {
71 | setInterpolationImage(this.value);
72 | });
73 | setInterpolationImage(0);
74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1);
75 |
76 | bulmaSlider.attach();
77 |
78 | })
79 |
--------------------------------------------------------------------------------
/docs/static/proxyv/mask_ratio_revised.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/mask_ratio_revised.png
--------------------------------------------------------------------------------
/docs/static/proxyv/non_spatial_proxyv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/non_spatial_proxyv.png
--------------------------------------------------------------------------------
/docs/static/proxyv/proxyv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/proxyv.png
--------------------------------------------------------------------------------
/docs/static/proxyv/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/results.png
--------------------------------------------------------------------------------
/docs/static/proxyv/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/table1.png
--------------------------------------------------------------------------------
/docs/static/proxyv/table2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/table2.png
--------------------------------------------------------------------------------
/docs/static/proxyv/table3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/docs/static/proxyv/table3.png
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
--------------------------------------------------------------------------------
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | AVAILABLE_MODELS = {
4 | "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig",
5 | "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig",
6 | "llava_phi3": "LlavaPhiForCausalLM, LlavaPhiConfig",
7 | # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig",
8 | # Add other models as needed
9 | }
10 |
11 | for model_name, model_classes in AVAILABLE_MODELS.items():
12 | exec(f"from .language_model.{model_name} import {model_classes}")
13 |
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 |
6 | import argparse
7 |
8 | import torch
9 | from tqdm import tqdm
10 | from transformers import AutoTokenizer, AutoModelForCausalLM
11 | from llava import LlavaLlamaForCausalLM
12 |
13 |
14 | def apply_delta(base_model_path, target_model_path, delta_path):
15 | print("Loading base model")
16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
31 | bparam = base.state_dict()[name]
32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
33 |
34 | print("Saving target model")
35 | delta.save_pretrained(target_model_path)
36 | delta_tokenizer.save_pretrained(target_model_path)
37 |
38 |
39 | if __name__ == "__main__":
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument("--base-model-path", type=str, required=True)
42 | parser.add_argument("--target-model-path", type=str, required=True)
43 | parser.add_argument("--delta-path", type=str, required=True)
44 |
45 | args = parser.parse_args()
46 |
47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
48 |
--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 |
6 | import argparse
7 |
8 | import torch
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model import *
11 | from llava.model.utils import auto_upgrade
12 |
13 |
14 | def consolidate_ckpt(src_path, dst_path):
15 | print("Loading model")
16 | auto_upgrade(src_path)
17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
19 | src_model.save_pretrained(dst_path)
20 | src_tokenizer.save_pretrained(dst_path)
21 |
22 |
23 | if __name__ == "__main__":
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--src", type=str, required=True)
26 | parser.add_argument("--dst", type=str, required=True)
27 |
28 | args = parser.parse_args()
29 |
30 | consolidate_ckpt(args.src, args.dst)
31 |
--------------------------------------------------------------------------------
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 |
6 | import argparse
7 |
8 | import torch
9 | from tqdm import tqdm
10 | from transformers import AutoTokenizer, AutoModelForCausalLM
11 | from llava.model.utils import auto_upgrade
12 |
13 |
14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
15 | print("Loading base model")
16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model"
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
31 | bparam = base.state_dict()[name]
32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 | from .imagebind import ImageBindWrapper
4 | from .open_clip_encoder import OpenCLIPVisionTower
5 | from .hf_vision import HFVisionTower
6 | from .siglip_encoder import SigLipVisionTower
7 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
8 |
9 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower
10 | # from .dev_eva_clip.eva_vit import EvaViTWrapper
11 |
12 |
13 | def build_vision_tower(vision_tower_cfg, **kwargs):
14 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
15 | is_absolute_path_exists = os.path.exists(vision_tower)
16 | use_s2 = getattr(vision_tower_cfg, "s2", False)
17 | if 'clip' in vision_tower or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
18 | if use_s2:
19 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
20 | else:
21 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
22 | elif "siglip" in vision_tower:
23 | return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
24 | elif vision_tower.startswith("hf:"):
25 | return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
26 | elif vision_tower in ["imagebind_huge"]:
27 | return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs)
28 | elif vision_tower.startswith("open_clip_hub"):
29 | return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
30 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower():
31 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
32 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]:
33 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs)
34 |
35 | raise ValueError(f"Unknown vision tower: {vision_tower}")
36 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from llava.utils import rank0_print
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 | try:
7 | from s2wrapper import forward as multiscale_forward
8 | except:
9 | pass
10 |
11 |
12 | class CLIPVisionTower(nn.Module):
13 | def __init__(self, vision_tower, args, delay_load=False):
14 | super().__init__()
15 |
16 | self.is_loaded = False
17 |
18 | self.vision_tower_name = vision_tower
19 | self.select_layer = args.mm_vision_select_layer
20 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
21 |
22 | if not delay_load:
23 | rank0_print(f"Loading vision tower: {vision_tower}")
24 | self.load_model()
25 | elif getattr(args, "unfreeze_mm_vision_tower", False):
26 | # TODO: better detector is needed.
27 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
28 | self.load_model()
29 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
30 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
31 | self.load_model()
32 | else:
33 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
34 |
35 | def load_model(self, device_map=None):
36 | if self.is_loaded:
37 | rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
38 | return
39 |
40 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
41 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
42 | self.vision_tower.requires_grad_(False)
43 |
44 | self.is_loaded = True
45 |
46 | def feature_select(self, image_forward_outs):
47 | select_feature_type = self.select_feature
48 |
49 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
50 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4
51 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
52 | select_feature_type = select_feature_type.replace("slicefour_", "")
53 | elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
54 | select_layers = [-2, -5, -8, -11, 6]
55 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
56 | select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
57 | else:
58 | image_features = image_forward_outs.hidden_states[self.select_layer]
59 |
60 | if select_feature_type == "patch":
61 | image_features = image_features[:, 1:]
62 | elif select_feature_type == "cls_patch":
63 | image_features = image_features
64 | else:
65 | raise ValueError(f"Unexpected select feature: {select_feature_type}")
66 | return image_features
67 |
68 | def forward(self, images):
69 | if type(images) is list:
70 | image_features = []
71 | for image in images:
72 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
73 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
74 | image_features.append(image_feature)
75 | else:
76 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
77 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
78 |
79 | return image_features
80 |
81 | @property
82 | def dummy_feature(self):
83 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
84 |
85 | @property
86 | def dtype(self):
87 | return self.vision_tower.dtype
88 |
89 | @property
90 | def device(self):
91 | return self.vision_tower.device
92 |
93 | @property
94 | def config(self):
95 | if self.is_loaded:
96 | return self.vision_tower.config
97 | else:
98 | return self.cfg_only
99 |
100 | @property
101 | def hidden_size(self):
102 | _hidden_size = self.config.hidden_size
103 | if "slicefour" in self.select_feature:
104 | _hidden_size *= 4
105 | if "slice_m25811_f6" in self.select_feature:
106 | _hidden_size *= 5
107 | return _hidden_size
108 |
109 | @property
110 | def num_patches_per_side(self):
111 | return self.config.image_size // self.config.patch_size
112 |
113 | @property
114 | def num_patches(self):
115 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2
116 | if "cls_patch" in self.select_feature:
117 | _num_patches += 1
118 | return _num_patches
119 |
120 | @property
121 | def image_size(self):
122 | return self.config.image_size
123 |
124 |
125 | class CLIPVisionTowerS2(CLIPVisionTower):
126 | def __init__(self, vision_tower, args, delay_load=False):
127 |
128 | self.s2_scales = getattr(args, "s2_scales", "336,672,1008")
129 | self.s2_scales = list(map(int, self.s2_scales.split(",")))
130 | self.s2_scales.sort()
131 | self.s2_split_size = self.s2_scales[0]
132 | self.s2_image_size = self.s2_scales[-1]
133 |
134 | super().__init__(vision_tower, args, delay_load)
135 |
136 | # change resize/crop size in preprocessing to the largest image size in s2_scale
137 | if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False):
138 | self.image_processor.size["shortest_edge"] = self.s2_image_size
139 | self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
140 |
141 | def load_model(self, device_map=None):
142 | if self.is_loaded:
143 | rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
144 | return
145 |
146 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
147 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
148 | self.vision_tower.requires_grad_(False)
149 |
150 | self.image_processor.size["shortest_edge"] = self.s2_image_size
151 | self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
152 |
153 | self.is_loaded = True
154 |
155 | def forward_feature(self, images):
156 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
157 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
158 | return image_features
159 |
160 | def forward(self, images):
161 | if type(images) is list:
162 | image_features = []
163 | for image in images:
164 | image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
165 | image_features.append(image_feature)
166 | else:
167 | image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
168 |
169 | return image_features
170 |
171 | @property
172 | def hidden_size(self):
173 | return self.config.hidden_size * len(self.s2_scales)
174 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4 | from .loss import ClipLoss
5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
6 | from .openai import load_openai_model, list_openai_models
7 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
8 | from .tokenizer import SimpleTokenizer, tokenize
9 | from .transform import image_transform
10 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/penghao-wu/ProxyV/42d6d97ba53acd0441a73403013b530f9d2710a0/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py:
--------------------------------------------------------------------------------
1 | # HF architecture dict:
2 | arch_dict = {
3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4 | "roberta": {
5 | "config_names": {
6 | "context_length": "max_position_embeddings",
7 | "vocab_size": "vocab_size",
8 | "width": "hidden_size",
9 | "heads": "num_attention_heads",
10 | "layers": "num_hidden_layers",
11 | "layer_attr": "layer",
12 | "token_embeddings_attr": "embeddings",
13 | },
14 | "pooler": "mean_pooler",
15 | },
16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 | "xlm-roberta": {
18 | "config_names": {
19 | "context_length": "max_position_embeddings",
20 | "vocab_size": "vocab_size",
21 | "width": "hidden_size",
22 | "heads": "num_attention_heads",
23 | "layers": "num_hidden_layers",
24 | "layer_attr": "layer",
25 | "token_embeddings_attr": "embeddings",
26 | },
27 | "pooler": "mean_pooler",
28 | },
29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 | "mt5": {
31 | "config_names": {
32 | # unlimited seqlen
33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 | "context_length": "",
36 | "vocab_size": "vocab_size",
37 | "width": "d_model",
38 | "heads": "num_heads",
39 | "layers": "num_layers",
40 | "layer_attr": "block",
41 | "token_embeddings_attr": "embed_tokens",
42 | },
43 | "pooler": "mean_pooler",
44 | },
45 | "bert": {
46 | "config_names": {
47 | "context_length": "max_position_embeddings",
48 | "vocab_size": "vocab_size",
49 | "width": "hidden_size",
50 | "heads": "num_attention_heads",
51 | "layers": "num_hidden_layers",
52 | "layer_attr": "layer",
53 | "token_embeddings_attr": "embeddings",
54 | },
55 | "pooler": "mean_pooler",
56 | },
57 | }
58 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 |
6 | try:
7 | import torch.distributed.nn
8 | from torch import distributed as dist
9 |
10 | has_distributed = True
11 | except ImportError:
12 | has_distributed = False
13 |
14 | try:
15 | import horovod.torch as hvd
16 | except ImportError:
17 | hvd = None
18 |
19 | from timm.loss import LabelSmoothingCrossEntropy
20 |
21 |
22 | def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False):
23 | assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support."
24 | if use_horovod:
25 | assert hvd is not None, "Please install horovod"
26 | if gather_with_grad:
27 | all_image_features = hvd.allgather(image_features)
28 | all_text_features = hvd.allgather(text_features)
29 | else:
30 | with torch.no_grad():
31 | all_image_features = hvd.allgather(image_features)
32 | all_text_features = hvd.allgather(text_features)
33 | if not local_loss:
34 | # ensure grads for local rank when all_* features don't have a gradient
35 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
36 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
37 | gathered_image_features[rank] = image_features
38 | gathered_text_features[rank] = text_features
39 | all_image_features = torch.cat(gathered_image_features, dim=0)
40 | all_text_features = torch.cat(gathered_text_features, dim=0)
41 | else:
42 | # We gather tensors from all gpus
43 | if gather_with_grad:
44 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
45 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
46 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
47 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
48 | else:
49 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
50 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
51 | dist.all_gather(gathered_image_features, image_features)
52 | dist.all_gather(gathered_text_features, text_features)
53 | if not local_loss:
54 | # ensure grads for local rank when all_* features don't have a gradient
55 | gathered_image_features[rank] = image_features
56 | gathered_text_features[rank] = text_features
57 | all_image_features = torch.cat(gathered_image_features, dim=0)
58 | all_text_features = torch.cat(gathered_text_features, dim=0)
59 |
60 | return all_image_features, all_text_features
61 |
62 |
63 | class ClipLoss(nn.Module):
64 |
65 | def __init__(
66 | self,
67 | local_loss=False,
68 | gather_with_grad=False,
69 | cache_labels=False,
70 | rank=0,
71 | world_size=1,
72 | use_horovod=False,
73 | smoothing=0.0,
74 | ):
75 | super().__init__()
76 | self.local_loss = local_loss
77 | self.gather_with_grad = gather_with_grad
78 | self.cache_labels = cache_labels
79 | self.rank = rank
80 | self.world_size = world_size
81 | self.use_horovod = use_horovod
82 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
83 |
84 | # cache state
85 | self.prev_num_logits = 0
86 | self.labels = {}
87 |
88 | def forward(self, image_features, text_features, logit_scale=1.0):
89 | device = image_features.device
90 | if self.world_size > 1:
91 | all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
92 |
93 | if self.local_loss:
94 | logits_per_image = logit_scale * image_features @ all_text_features.T
95 | logits_per_text = logit_scale * text_features @ all_image_features.T
96 | else:
97 | logits_per_image = logit_scale * all_image_features @ all_text_features.T
98 | logits_per_text = logits_per_image.T
99 | else:
100 | logits_per_image = logit_scale * image_features @ text_features.T
101 | logits_per_text = logit_scale * text_features @ image_features.T
102 | # calculated ground-truth and cache if enabled
103 | num_logits = logits_per_image.shape[0]
104 | if self.prev_num_logits != num_logits or device not in self.labels:
105 | labels = torch.arange(num_logits, device=device, dtype=torch.long)
106 | if self.world_size > 1 and self.local_loss:
107 | labels = labels + num_logits * self.rank
108 | if self.cache_labels:
109 | self.labels[device] = labels
110 | self.prev_num_logits = num_logits
111 | else:
112 | labels = self.labels[device]
113 |
114 | if self.label_smoothing_cross_entropy:
115 | total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2
116 | else:
117 | total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2
118 |
119 | acc = None
120 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
121 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
122 | acc = {"i2t": i2t_acc, "t2i": t2i_acc}
123 | return total_loss, acc
124 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1536,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 48,
6 | "width": 5120,
7 | "head_width": 128,
8 | "mlp_ratio": 5,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-18b-14-x",
11 | "drop_path_rate": 0,
12 | "qkv_bias": false,
13 | "xattn": true,
14 | "postnorm": true,
15 | "fusedLN": false,
16 | "use_rms_norm": true
17 | },
18 | "text_cfg": {
19 | "context_length": 77,
20 | "vocab_size": 49408,
21 | "width": 1280,
22 | "heads": 20,
23 | "layers": 32,
24 | "xattn": false,
25 | "fusedLN": false
26 | }
27 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 448,
5 | "layers": 32,
6 | "width": 4096,
7 | "head_width": 128,
8 | "mlp_ratio": 5,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-8b-14-plus-x",
11 | "drop_path_rate": 0,
12 | "qkv_bias": false,
13 | "xattn": true,
14 | "postnorm": false,
15 | "fusedLN": false,
16 | "use_rms_norm": true
17 | },
18 | "text_cfg": {
19 | "context_length": 77,
20 | "vocab_size": 49408,
21 | "width": 1280,
22 | "heads": 20,
23 | "layers": 32,
24 | "xattn": false,
25 | "fusedLN": false
26 | }
27 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 4096,
7 | "head_width": 128,
8 | "mlp_ratio": 5,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-8b-14-x",
11 | "drop_path_rate": 0,
12 | "qkv_bias": false,
13 | "xattn": true,
14 | "postnorm": false,
15 | "fusedLN": false,
16 | "use_rms_norm": true
17 | },
18 | "text_cfg": {
19 | "context_length": 77,
20 | "vocab_size": 49408,
21 | "width": 1280,
22 | "heads": 20,
23 | "layers": 32,
24 | "xattn": false,
25 | "fusedLN": false
26 | }
27 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "eva_model_name": "eva-clip-b-16",
9 | "ls_init_value": 0.1,
10 | "drop_path_rate": 0.0
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 1024,
19 | "heads": 16,
20 | "layers": 24,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0.4,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 768,
19 | "heads": 12,
20 | "layers": 12,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "head_width": 64,
8 | "patch_size": 16,
9 | "mlp_ratio": 2.6667,
10 | "eva_model_name": "eva-clip-b-16-X",
11 | "drop_path_rate": 0.0,
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 512,
24 | "heads": 8,
25 | "layers": 12,
26 | "xattn": true,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14-336",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1024,
20 | "heads": 16,
21 | "layers": 24,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 448,
5 | "layers": 77,
6 | "width": 2304,
7 | "head_width": 144,
8 | "mlp_ratio": 10.9722,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-10b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": false,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 77,
6 | "width": 2304,
7 | "head_width": 144,
8 | "mlp_ratio": 10.9722,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-10b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": false,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from .utils import freeze_batch_norm_2d
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.act1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.act2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.act3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))]))
37 |
38 | def forward(self, x: torch.Tensor):
39 | identity = x
40 |
41 | out = self.act1(self.bn1(self.conv1(x)))
42 | out = self.act2(self.bn2(self.conv2(out)))
43 | out = self.avgpool(out)
44 | out = self.bn3(self.conv3(out))
45 |
46 | if self.downsample is not None:
47 | identity = self.downsample(x)
48 |
49 | out += identity
50 | out = self.act3(out)
51 | return out
52 |
53 |
54 | class AttentionPool2d(nn.Module):
55 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
56 | super().__init__()
57 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
58 | self.k_proj = nn.Linear(embed_dim, embed_dim)
59 | self.q_proj = nn.Linear(embed_dim, embed_dim)
60 | self.v_proj = nn.Linear(embed_dim, embed_dim)
61 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
62 | self.num_heads = num_heads
63 |
64 | def forward(self, x):
65 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
66 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
67 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
68 | x, _ = F.multi_head_attention_forward(
69 | query=x,
70 | key=x,
71 | value=x,
72 | embed_dim_to_check=x.shape[-1],
73 | num_heads=self.num_heads,
74 | q_proj_weight=self.q_proj.weight,
75 | k_proj_weight=self.k_proj.weight,
76 | v_proj_weight=self.v_proj.weight,
77 | in_proj_weight=None,
78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79 | bias_k=None,
80 | bias_v=None,
81 | add_zero_attn=False,
82 | dropout_p=0.0,
83 | out_proj_weight=self.c_proj.weight,
84 | out_proj_bias=self.c_proj.bias,
85 | use_separate_proj_weight=True,
86 | training=self.training,
87 | need_weights=False,
88 | )
89 |
90 | return x[0]
91 |
92 |
93 | class ModifiedResNet(nn.Module):
94 | """
95 | A ResNet class that is similar to torchvision's but contains the following changes:
96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98 | - The final pooling layer is a QKV attention instead of an average pool
99 | """
100 |
101 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
102 | super().__init__()
103 | self.output_dim = output_dim
104 | self.image_size = image_size
105 |
106 | # the 3-layer stem
107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108 | self.bn1 = nn.BatchNorm2d(width // 2)
109 | self.act1 = nn.ReLU(inplace=True)
110 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
111 | self.bn2 = nn.BatchNorm2d(width // 2)
112 | self.act2 = nn.ReLU(inplace=True)
113 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
114 | self.bn3 = nn.BatchNorm2d(width)
115 | self.act3 = nn.ReLU(inplace=True)
116 | self.avgpool = nn.AvgPool2d(2)
117 |
118 | # residual layers
119 | self._inplanes = width # this is a *mutable* variable used during construction
120 | self.layer1 = self._make_layer(width, layers[0])
121 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
122 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
123 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
124 |
125 | embed_dim = width * 32 # the ResNet feature dimension
126 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
127 |
128 | self.init_parameters()
129 |
130 | def _make_layer(self, planes, blocks, stride=1):
131 | layers = [Bottleneck(self._inplanes, planes, stride)]
132 |
133 | self._inplanes = planes * Bottleneck.expansion
134 | for _ in range(1, blocks):
135 | layers.append(Bottleneck(self._inplanes, planes))
136 |
137 | return nn.Sequential(*layers)
138 |
139 | def init_parameters(self):
140 | if self.attnpool is not None:
141 | std = self.attnpool.c_proj.in_features**-0.5
142 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
143 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
144 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
145 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
146 |
147 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
148 | for name, param in resnet_block.named_parameters():
149 | if name.endswith("bn3.weight"):
150 | nn.init.zeros_(param)
151 |
152 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
153 | assert unlocked_groups == 0, "partial locking not currently supported for this model"
154 | for param in self.parameters():
155 | param.requires_grad = False
156 | if freeze_bn_stats:
157 | freeze_batch_norm_2d(self)
158 |
159 | @torch.jit.ignore
160 | def set_grad_checkpointing(self, enable=True):
161 | # FIXME support for non-transformer
162 | pass
163 |
164 | def stem(self, x):
165 | x = self.act1(self.bn1(self.conv1(x)))
166 | x = self.act2(self.bn2(self.conv2(x)))
167 | x = self.act3(self.bn3(self.conv3(x)))
168 | x = self.avgpool(x)
169 | return x
170 |
171 | def forward(self, x):
172 | x = self.stem(x)
173 | x = self.layer1(x)
174 | x = self.layer2(x)
175 | x = self.layer3(x)
176 | x = self.layer4(x)
177 | x = self.attnpool(x)
178 |
179 | return x
180 |
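A minimal usage sketch, assuming the RN50-style stage layout (3, 4, 6, 3); the attention pool returns one embedding of size output_dim per image.

```python
import torch

# width=64 yields a 2048-d trunk (64 * 32), pooled down to output_dim by AttentionPool2d
model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
with torch.no_grad():
    feats = model(torch.randn(2, 3, 224, 224))
print(feats.shape)  # torch.Size([2, 1024])
```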
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import List, Optional, Union
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_models_by_tag("openai")
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | precision: Optional[str] = None,
26 | device: Optional[Union[str, torch.device]] = None,
27 | jit: bool = True,
28 | cache_dir: Optional[str] = None,
29 | ):
30 | """Load a CLIP model
31 |
32 | Parameters
33 | ----------
34 | name : str
35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 | precision: str
37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | jit : bool
41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42 | cache_dir : Optional[str]
43 | The directory to cache the downloaded model weights
44 |
45 | Returns
46 | -------
47 | model : torch.nn.Module
48 | The CLIP model
49 | preprocess : Callable[[PIL.Image], torch.Tensor]
50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
51 | """
52 | if device is None:
53 | device = "cuda" if torch.cuda.is_available() else "cpu"
54 | if precision is None:
55 | precision = "fp32" if device == "cpu" else "fp16"
56 |
57 | if get_pretrained_url(name, "openai"):
58 | model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir)
59 | elif os.path.isfile(name):
60 | model_path = name
61 | else:
62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63 |
64 | try:
65 | # loading JIT archive
66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67 | state_dict = None
68 | except RuntimeError:
69 | # loading saved state dict
70 | if jit:
71 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72 | jit = False
73 | state_dict = torch.load(model_path, map_location="cpu")
74 |
75 | if not jit:
76 | # Build a non-jit model from the OpenAI jitted model state dict
77 | cast_dtype = get_cast_dtype(precision)
78 | try:
79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80 | except KeyError:
81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83 |
84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85 | model = model.to(device)
86 | if precision.startswith("amp") or precision == "fp32":
87 | model.float()
88 | elif precision == "bf16":
89 | convert_weights_to_lp(model, dtype=torch.bfloat16)
90 |
91 | return model
92 |
93 | # patch the device names
94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96 |
97 | def patch_device(module):
98 | try:
99 | graphs = [module.graph] if hasattr(module, "graph") else []
100 | except RuntimeError:
101 | graphs = []
102 |
103 | if hasattr(module, "forward1"):
104 | graphs.append(module.forward1.graph)
105 |
106 | for graph in graphs:
107 | for node in graph.findAllNodes("prim::Constant"):
108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109 | node.copyAttributes(device_node)
110 |
111 | model.apply(patch_device)
112 | patch_device(model.encode_image)
113 | patch_device(model.encode_text)
114 |
115 | # patch dtype to float32 (typically for CPU)
116 | if precision == "fp32":
117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119 | float_node = float_input.node()
120 |
121 | def patch_float(module):
122 | try:
123 | graphs = [module.graph] if hasattr(module, "graph") else []
124 | except RuntimeError:
125 | graphs = []
126 |
127 | if hasattr(module, "forward1"):
128 | graphs.append(module.forward1.graph)
129 |
130 | for graph in graphs:
131 | for node in graph.findAllNodes("aten::to"):
132 | inputs = list(node.inputs())
133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134 | if inputs[i].node()["value"] == 5:
135 | inputs[i].node().copyAttributes(float_node)
136 |
137 | model.apply(patch_float)
138 | patch_float(model.encode_image)
139 | patch_float(model.encode_text)
140 | model.float()
141 |
142 | # ensure image_size attr available at consistent location for both jit and non-jit
143 | model.visual.image_size = model.input_resolution.item()
144 | return model
145 |
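A hedged call sketch; "RN50" is assumed to be listed under the "openai" tag in pretrained.py, and the 224-px input size is assumed from that model's config.

```python
import torch

# hypothetical usage: load the non-JIT model on CPU in fp32 and encode a dummy image
model = load_openai_model("RN50", precision="fp32", device="cpu", jit=False)
with torch.no_grad():
    image_features = model.encode_image(torch.randn(1, 3, 224, 224))
```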
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 | import torch
3 | from torch import nn
4 | from einops import rearrange, repeat
5 | import logging
6 |
7 |
8 | def broadcat(tensors, dim=-1):
9 | num_tensors = len(tensors)
10 | shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
11 | assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
12 | shape_len = list(shape_lens)[0]
13 | dim = (dim + shape_len) if dim < 0 else dim
14 | dims = list(zip(*map(lambda t: list(t.shape), tensors)))
15 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
 16 |     assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatenation"
17 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
18 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
19 | expanded_dims.insert(dim, (dim, dims[dim]))
20 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
21 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
22 | return torch.cat(tensors, dim=dim)
23 |
24 |
25 | def rotate_half(x):
26 | x = rearrange(x, "... (d r) -> ... d r", r=2)
27 | x1, x2 = x.unbind(dim=-1)
28 | x = torch.stack((-x2, x1), dim=-1)
29 | return rearrange(x, "... d r -> ... (d r)")
30 |
31 |
32 | class VisionRotaryEmbedding(nn.Module):
33 | def __init__(
34 | self,
35 | dim,
36 | pt_seq_len,
37 | ft_seq_len=None,
38 | custom_freqs=None,
39 | freqs_for="lang",
40 | theta=10000,
41 | max_freq=10,
42 | num_freqs=1,
43 | ):
44 | super().__init__()
45 | if custom_freqs:
46 | freqs = custom_freqs
47 | elif freqs_for == "lang":
48 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
49 | elif freqs_for == "pixel":
50 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
51 | elif freqs_for == "constant":
52 | freqs = torch.ones(num_freqs).float()
53 | else:
54 | raise ValueError(f"unknown modality {freqs_for}")
55 |
56 | if ft_seq_len is None:
57 | ft_seq_len = pt_seq_len
58 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
59 |
60 | freqs_h = torch.einsum("..., f -> ... f", t, freqs)
61 | freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
62 |
63 | freqs_w = torch.einsum("..., f -> ... f", t, freqs)
64 | freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
65 |
66 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
67 |
68 | self.register_buffer("freqs_cos", freqs.cos())
69 | self.register_buffer("freqs_sin", freqs.sin())
70 |
71 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
72 |
73 | def forward(self, t, start_index=0):
74 | rot_dim = self.freqs_cos.shape[-1]
75 | end_index = start_index + rot_dim
76 | assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
77 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
78 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
79 |
80 | return torch.cat((t_left, t, t_right), dim=-1)
81 |
82 |
83 | class VisionRotaryEmbeddingFast(nn.Module):
84 | def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0):
85 | super().__init__()
86 | if custom_freqs:
87 | freqs = custom_freqs
88 | elif freqs_for == "lang":
89 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
90 | elif freqs_for == "pixel":
91 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
92 | elif freqs_for == "constant":
93 | freqs = torch.ones(num_freqs).float()
94 | else:
95 | raise ValueError(f"unknown modality {freqs_for}")
96 |
97 | if ft_seq_len is None:
98 | ft_seq_len = pt_seq_len
99 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
100 |
101 | freqs = torch.einsum("..., f -> ... f", t, freqs)
102 | freqs = repeat(freqs, "... n -> ... (n r)", r=2)
103 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
104 |
105 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
106 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
107 |
108 | self.patch_dropout = patch_dropout
109 |
110 | self.register_buffer("freqs_cos", freqs_cos)
111 | self.register_buffer("freqs_sin", freqs_sin)
112 |
113 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
114 |
115 | def forward(self, t, patch_indices_keep=None):
116 | if patch_indices_keep is not None:
117 | batch = t.size()[0]
118 | batch_indices = torch.arange(batch)
119 | batch_indices = batch_indices[..., None]
120 |
121 | freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
122 | freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
123 |
124 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
125 | freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
126 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
127 | freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
128 |
129 | return t * freqs_cos + rotate_half(t) * freqs_sin
130 |
131 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
132 |
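A shape-level sanity check for VisionRotaryEmbeddingFast (a sketch, not from the repo): with dim=16 and a 16x16 patch grid, the rotary buffers cover 256 positions and rotate 32 channels of each head.

```python
import torch

rope = VisionRotaryEmbeddingFast(dim=16, pt_seq_len=16, ft_seq_len=16)
q = torch.randn(2, 8, 16 * 16, 32)  # (batch, heads, patches, head_dim), head_dim = 2 * dim
q_rot = rope(q)
print(q_rot.shape)  # torch.Size([2, 8, 256, 32])
```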
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 |
6 | import logging
7 | from collections import OrderedDict
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 | try:
13 | import timm
14 | from timm.models.layers import Mlp, to_2tuple
15 |
16 | try:
17 | # old timm imports < 0.8.1
18 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
19 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
20 | except ImportError:
21 | # new timm imports >= 0.8.1
22 | from timm.layers import RotAttentionPool2d
23 | from timm.layers import AttentionPool2d as AbsAttentionPool2d
24 | except ImportError:
25 | timm = None
26 |
27 | from .utils import freeze_batch_norm_2d
28 |
29 |
30 | class TimmModel(nn.Module):
31 | """timm model adapter
32 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
33 | """
34 |
35 | def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False):
36 | super().__init__()
37 | if timm is None:
38 | raise RuntimeError("Please `pip install timm` to use timm models.")
39 |
40 | self.image_size = to_2tuple(image_size)
41 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
42 | feat_size = self.trunk.default_cfg.get("pool_size", None)
43 | feature_ndim = 1 if not feat_size else 2
44 | if pool in ("abs_attn", "rot_attn"):
45 | assert feature_ndim == 2
46 | # if attn pooling used, remove both classifier and default pool
47 | self.trunk.reset_classifier(0, global_pool="")
48 | else:
49 | # reset global pool if pool config set, otherwise leave as network default
50 | reset_kwargs = dict(global_pool=pool) if pool else {}
51 | self.trunk.reset_classifier(0, **reset_kwargs)
52 | prev_chs = self.trunk.num_features
53 |
54 | head_layers = OrderedDict()
55 | if pool == "abs_attn":
56 | head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
57 | prev_chs = embed_dim
58 | elif pool == "rot_attn":
59 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
60 | prev_chs = embed_dim
61 | else:
62 | assert proj, "projection layer needed if non-attention pooling is used."
63 |
64 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
65 | if proj == "linear":
66 | head_layers["drop"] = nn.Dropout(drop)
67 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
68 | elif proj == "mlp":
69 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
70 |
71 | self.head = nn.Sequential(head_layers)
72 |
73 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
74 | """lock modules
75 | Args:
76 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
77 | """
78 | if not unlocked_groups:
79 | # lock full model
80 | for param in self.trunk.parameters():
81 | param.requires_grad = False
82 | if freeze_bn_stats:
83 | freeze_batch_norm_2d(self.trunk)
84 | else:
85 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
86 | try:
87 | # FIXME import here until API stable and in an official release
88 | from timm.models.helpers import group_parameters, group_modules
89 | except ImportError:
90 | raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`")
91 | matcher = self.trunk.group_matcher()
92 | gparams = group_parameters(self.trunk, matcher)
93 | max_layer_id = max(gparams.keys())
94 | max_layer_id = max_layer_id - unlocked_groups
95 | for group_idx in range(max_layer_id + 1):
96 | group = gparams[group_idx]
97 | for param in group:
98 | self.trunk.get_parameter(param).requires_grad = False
99 | if freeze_bn_stats:
100 | gmodules = group_modules(self.trunk, matcher, reverse=True)
101 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
102 | freeze_batch_norm_2d(self.trunk, gmodules)
103 |
104 | @torch.jit.ignore
105 | def set_grad_checkpointing(self, enable=True):
106 | try:
107 | self.trunk.set_grad_checkpointing(enable)
108 | except Exception as e:
109 | logging.warning("grad checkpointing not supported for this timm image tower, continuing without...")
110 |
111 | def forward(self, x):
112 | x = self.trunk(x)
113 | x = self.head(x)
114 | return x
115 |
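A hedged sketch, assuming timm is installed and "resnet50" is available in its registry: the trunk is average-pooled and linearly projected to the requested embedding size.

```python
import torch

tower = TimmModel("resnet50", embed_dim=512, image_size=224, pool="avg", proj="linear")
with torch.no_grad():
    feats = tower(torch.randn(2, 3, 224, 224))
print(feats.shape)  # torch.Size([2, 512])
```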
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as F
6 |
7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop
8 |
9 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
10 |
11 |
12 | class ResizeMaxSize(nn.Module):
13 |
14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0):
15 | super().__init__()
16 | if not isinstance(max_size, int):
17 | raise TypeError(f"Size should be int. Got {type(max_size)}")
18 | self.max_size = max_size
19 | self.interpolation = interpolation
 20 |         self.fn = max if fn == "max" else min
21 | self.fill = fill
22 |
23 | def forward(self, img):
24 | if isinstance(img, torch.Tensor):
25 | height, width = img.shape[:2]
26 | else:
27 | width, height = img.size
28 | scale = self.max_size / float(max(height, width))
29 | if scale != 1.0:
30 | new_size = tuple(round(dim * scale) for dim in (height, width))
31 | img = F.resize(img, new_size, self.interpolation)
32 | pad_h = self.max_size - new_size[0]
33 | pad_w = self.max_size - new_size[1]
34 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
35 | return img
36 |
37 |
38 | def _convert_to_rgb(image):
39 | return image.convert("RGB")
40 |
41 |
42 | # class CatGen(nn.Module):
43 | # def __init__(self, num=4):
44 | # self.num = num
45 | # def mixgen_batch(image, text):
46 | # batch_size = image.shape[0]
47 | # index = np.random.permutation(batch_size)
48 |
49 | # cat_images = []
50 | # for i in range(batch_size):
51 | # # image mixup
52 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
53 | # # text concat
54 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
55 | # text = torch.stack(text)
56 | # return image, text
57 |
58 |
59 | def image_transform(
60 | image_size: int,
61 | is_train: bool,
62 | mean: Optional[Tuple[float, ...]] = None,
63 | std: Optional[Tuple[float, ...]] = None,
64 | resize_longest_max: bool = False,
65 | fill_color: int = 0,
66 | ):
67 | mean = mean or OPENAI_DATASET_MEAN
68 | if not isinstance(mean, (list, tuple)):
69 | mean = (mean,) * 3
70 |
71 | std = std or OPENAI_DATASET_STD
72 | if not isinstance(std, (list, tuple)):
73 | std = (std,) * 3
74 |
75 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
76 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
77 | image_size = image_size[0]
78 |
79 | normalize = Normalize(mean=mean, std=std)
80 | if is_train:
81 | return Compose(
82 | [
83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84 | _convert_to_rgb,
85 | ToTensor(),
86 | normalize,
87 | ]
88 | )
89 | else:
90 | if resize_longest_max:
91 | transforms = [ResizeMaxSize(image_size, fill=fill_color)]
92 | else:
93 | transforms = [
94 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
95 | CenterCrop(image_size),
96 | ]
97 | transforms.extend(
98 | [
99 | _convert_to_rgb,
100 | ToTensor(),
101 | normalize,
102 | ]
103 | )
104 | return Compose(transforms)
105 |
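A minimal evaluation-time usage sketch: resize, center-crop, convert to RGB, and normalize with the OpenAI statistics.

```python
from PIL import Image

preprocess = image_transform(image_size=224, is_train=False)
tensor = preprocess(Image.new("RGB", (640, 480)))  # -> torch.Size([3, 224, 224])
```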
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py:
--------------------------------------------------------------------------------
1 | # Based on EVA, BEIT, timm and DeiT code bases
2 | # https://github.com/baaivision/EVA
3 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4 | # https://github.com/microsoft/unilm/tree/master/beit
5 | # https://github.com/facebookresearch/deit/
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------'
8 | # not tested yet
9 | import math
10 | from transformers import CLIPImageProcessor
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.utils.checkpoint as checkpoint
16 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
17 | from .eva_clip import create_model_and_transforms, get_model_config
18 | import torch
19 | import torchvision
20 | import time
21 |
22 | from llava.utils import rank0_print
23 |
24 |
25 | class EvaViTWrapper(nn.Module):
26 | def __init__(self, vision_tower, args, delay_load=False):
27 | super().__init__()
28 |
29 | self.is_loaded = False
30 | self.vision_tower_name = vision_tower
31 | self.pretrained = args.vision_tower_pretrained
32 | self.args = args
33 |
34 | self.select_layer = args.mm_vision_select_layer
35 | if self.select_layer < -1:
36 | self.select_layer += 1
37 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
38 |
39 | self.model_config = get_model_config(self.vision_tower_name)
40 |
41 | if not delay_load:
42 | rank0_print(f"Loading vision tower: {vision_tower}")
43 | self.load_model()
44 | elif getattr(args, "unfreeze_mm_vision_tower", False):
45 | # TODO: better detector is needed.
46 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
47 | self.load_model()
48 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
49 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
50 | self.load_model()
51 |
52 | def load_model(self):
53 | rank0_print(f"Loading: {self.vision_tower_name}")
54 | rank0_print(f"Pretrained: {self.pretrained}")
55 | time_start = time.time()
56 | model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16")
57 | time_end = time.time()
58 | rank0_print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s")
59 | self.device = next(model.parameters()).device
60 | self.dtype = next(model.parameters()).dtype
61 | if self.device.type != "meta":
62 | model = model.to("cuda")
63 | self.vision_tower = model.visual
64 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
65 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
66 | self.resize_transform_size = resize_transform.size
67 | self.image_processor = CLIPImageProcessor.from_pretrained(
68 | "openai/clip-vit-large-patch14",
69 | crop_size=resize_transform.size,
70 | size={"shortest_edge": resize_transform.size},
71 | image_mean=list(normalize_transform.mean),
72 | image_std=list(normalize_transform.std),
73 | )
74 | rank0_print(f"Loaded image processor: {self.image_processor}")
75 | self.vision_tower.requires_grad_(False)
76 | self.is_loaded = True
77 |
78 | def feature_select(self, image_features):
79 | select_feature_type = self.select_feature
80 |
81 | # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
82 | # select_every_k_layer = len(image_features) // 4
83 | # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1)
84 | # select_feature_type = select_feature_type.replace("slicefour_", "")
85 | # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
86 | # select_layers = [-1, -4, -7, -10, 6]
87 | # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1)
88 | # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
89 | # else:
90 | # image_features = image_features[self.select_layer]
91 |
92 | if select_feature_type == "patch":
93 | image_features = image_features[:, 1:]
94 | elif select_feature_type == "cls_patch":
95 | image_features = image_features
96 | else:
97 | raise ValueError(f"Unexpected select feature: {select_feature_type}")
98 | return image_features
99 |
100 | def train(self, mode=True):
101 | self.training = mode
102 |
103 | if self.is_loaded:
104 | self.vision_tower.eval()
105 |
106 | def forward(self, images):
107 | if type(images) is list:
108 | image_features = []
109 | for image in images:
110 |                 image_feature = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True)
111 |                 image_feature = self.feature_select(image_feature).to(self.dtype)
112 |                 image_features.append(image_feature)
113 | else:
114 | image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True)
115 | image_features = self.feature_select(image_features).to(self.dtype)
116 |
117 | return image_features
118 |
119 | @property
120 | def dummy_feature(self):
121 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
122 |
123 | @property
124 | def hidden_size(self):
125 | return self.model_config["vision_cfg"]["width"]
126 |
127 | @property
128 | def num_patches(self):
129 | return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2
130 |
131 | @property
132 | def num_patches_per_side(self):
133 | return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
134 |
135 | @property
136 | def config(self):
137 | return self.model_config
138 |
139 | @property
140 | def image_size(self):
141 | return self.model_config["vision_cfg"]["image_size"]
142 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .eva_clip_processors import EvaClipImageTrainProcessor
5 | from .eva_vit import EVAEncoderWrapper
6 | from .factory import list_models, add_model_config, get_model_config
7 |
8 | from llava.utils import rank0_print
9 |
10 |
11 | class EvaClipVisionTower(nn.Module):
12 | def __init__(self, vision_tower, args, delay_load=False):
13 | super().__init__()
14 |
15 | self.is_loaded = False
16 | self.vision_tower_name = vision_tower
17 | self.vision_tower_pretrained = args.vision_tower_pretrained
18 | self.config = get_model_config(vision_tower)
19 |
20 | if not delay_load:
21 | rank0_print(f"Loading EVA ViT: {self.vision_tower_name}")
22 | self.load_model()
23 | elif getattr(args, "unfreeze_mm_vision_tower", False):
24 | # TODO: better detector is needed.
25 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
26 | self.load_model()
27 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
28 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
29 | self.load_model()
30 | else:
31 | self.cfg_only = self.config
32 |
33 | def load_model(self, device_map=None):
34 | rank0_print(f"Pretrained: {self.vision_tower_pretrained}")
35 | self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"])
36 | self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config)
37 | rank0_print(f"Loaded image processor: {self.image_processor}")
38 | self.vision_tower.requires_grad_(False)
39 | self.is_loaded = True
40 |
41 | def forward(self, images):
42 | if type(images) is list:
43 | image_features = []
44 | for image in images:
45 | image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype)
46 | image_features.append(image_feature)
47 | else:
48 | image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype)
49 |
50 | return image_features
51 |
52 | @property
53 | def dtype(self):
54 | return self.vision_tower.dtype
55 |
56 | @property
57 | def device(self):
58 | return self.vision_tower.device
59 |
60 | @property
61 | def hidden_size(self):
62 | return self.config["vision_cfg"]["width"]
63 |
64 | @property
65 | def num_patches(self):
66 | return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2
67 |
68 | @property
69 | def num_patches_per_side(self):
70 | return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]
71 |
72 | @property
73 | def image_size(self):
74 | return self.config["vision_cfg"]["image_size"]
75 |
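A hedged construction sketch; the args namespace and checkpoint path below are placeholders, and load_model only succeeds when vision_tower_pretrained points at a real EVA-CLIP checkpoint.

```python
from types import SimpleNamespace

# hypothetical args; replace the path with an actual EVA-CLIP checkpoint before running
args = SimpleNamespace(vision_tower_pretrained="/path/to/eva_clip_checkpoint.pt")
tower = EvaClipVisionTower("EVA02-CLIP-L-14-336", args)
print(tower.hidden_size, tower.num_patches)  # 1024, 576 for this config
```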
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py:
--------------------------------------------------------------------------------
1 | """
2 | # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3 | """
4 |
5 | from torchvision import transforms
6 | from torchvision.transforms.functional import InterpolationMode
7 | from transformers.image_processing_utils import BatchFeature
8 | from PIL import Image
9 | from transformers.image_transforms import convert_to_rgb
10 |
11 |
12 | class BaseProcessor:
13 | def __init__(self):
14 | self.transform = lambda x: x
15 | return
16 |
17 | def __call__(self, item):
18 | return self.transform(item)
19 |
20 |
21 | class EvaClipImageBaseProcessor(BaseProcessor):
22 | def __init__(self, mean=None, std=None):
23 | self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
24 | self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
25 |
26 | self.normalize = transforms.Normalize(self.mean, self.std)
27 |
28 | @property
29 | def image_mean(self):
30 | return self.mean
31 |
32 |
33 | class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
34 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
35 | super().__init__(mean=mean, std=std)
36 |
37 | self.transform = transforms.Compose(
38 | [
39 | convert_to_rgb,
40 | transforms.Resize(
41 | image_size,
42 | interpolation=InterpolationMode.BICUBIC,
43 | ),
44 | transforms.CenterCrop(image_size),
45 | transforms.ToTensor(),
46 | self.normalize,
47 | ]
48 | )
49 |
50 | self.image_size = image_size
51 |
52 | def preprocess(self, images, return_tensors):
53 | if isinstance(images, Image.Image):
54 | images = [images]
55 | else:
56 | assert isinstance(images, list)
57 |
58 | transformed_images = [self.transform(image).numpy() for image in images]
59 | data = {"pixel_values": transformed_images}
60 |
61 | return BatchFeature(data=data, tensor_type=return_tensors)
62 |
63 | def __call__(self, item):
64 | return self.transform(item)
65 |
66 | @property
67 | def crop_size(self):
68 | return {"height": self.image_size, "width": self.image_size}
69 |
70 | @property
71 | def size(self):
72 | return {"shortest_edge": self.image_size}
73 |
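A minimal round-trip sketch: preprocess() mirrors the HuggingFace image-processor interface and returns a BatchFeature holding pixel_values.

```python
from PIL import Image

processor = EvaClipImageTrainProcessor(image_size=336)
batch = processor.preprocess([Image.new("RGB", (500, 375))], return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 336, 336])
```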
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/factory.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import pathlib
5 | import re
6 | from copy import deepcopy
7 | from pathlib import Path
8 | from typing import Optional, Tuple, Union, Dict, Any
9 | import torch
10 |
 11 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
 12 | _MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
13 |
14 |
15 | def _natural_key(string_):
16 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
17 |
18 |
19 | def _rescan_model_configs():
20 | global _MODEL_CONFIGS
21 |
22 | config_ext = (".json",)
23 | config_files = []
24 | for config_path in _MODEL_CONFIG_PATHS:
25 | if config_path.is_file() and config_path.suffix in config_ext:
26 | config_files.append(config_path)
27 | elif config_path.is_dir():
28 | for ext in config_ext:
29 | config_files.extend(config_path.glob(f"*{ext}"))
30 |
31 | for cf in config_files:
32 | with open(cf, "r", encoding="utf8") as f:
33 | model_cfg = json.load(f)
34 | if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
35 | _MODEL_CONFIGS[cf.stem] = model_cfg
36 |
37 | _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
38 |
39 |
40 | _rescan_model_configs() # initial populate of model config registry
41 |
42 |
43 | def list_models():
44 | """enumerate available model architectures based on config files"""
45 | return list(_MODEL_CONFIGS.keys())
46 |
47 |
48 | def add_model_config(path):
49 | """add model config path or file and update registry"""
50 | if not isinstance(path, Path):
51 | path = Path(path)
52 | _MODEL_CONFIG_PATHS.append(path)
53 | _rescan_model_configs()
54 |
55 |
56 | def get_model_config(model_name):
57 | if model_name in _MODEL_CONFIGS:
58 | return deepcopy(_MODEL_CONFIGS[model_name])
59 | else:
60 | return None
61 |
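A small usage sketch: the registry is populated from model_configs/*.json at import time, so configs can be looked up by their file stem.

```python
cfg = get_model_config("EVA02-CLIP-L-14-336")
num_patches = (cfg["vision_cfg"]["image_size"] // cfg["vision_cfg"]["patch_size"]) ** 2  # 24 * 24 = 576
print(list_models()[:3], num_patches)
```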
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1536,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 48,
6 | "width": 5120,
7 | "head_width": 128,
8 | "mlp_ratio": 5,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-18b-14-x",
11 | "drop_path_rate": 0,
12 | "qkv_bias": false,
13 | "xattn": true,
14 | "postnorm": true,
15 | "fusedLN": false,
16 | "use_rms_norm": true
17 | },
18 | "text_cfg": {
19 | "context_length": 77,
20 | "vocab_size": 49408,
21 | "width": 1280,
22 | "heads": 20,
23 | "layers": 32,
24 | "xattn": false,
25 | "fusedLN": false
26 | }
27 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 448,
5 | "layers": 32,
6 | "width": 4096,
7 | "head_width": 128,
8 | "mlp_ratio": 5,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-8b-14-plus-x",
11 | "drop_path_rate": 0,
12 | "qkv_bias": false,
13 | "xattn": true,
14 | "postnorm": false,
15 | "fusedLN": false,
16 | "use_rms_norm": true
17 | },
18 | "text_cfg": {
19 | "context_length": 77,
20 | "vocab_size": 49408,
21 | "width": 1280,
22 | "heads": 20,
23 | "layers": 32,
24 | "xattn": false,
25 | "fusedLN": false
26 | }
27 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1280,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 4096,
7 | "head_width": 128,
8 | "mlp_ratio": 5,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-8b-14-x",
11 | "drop_path_rate": 0,
12 | "qkv_bias": false,
13 | "xattn": true,
14 | "postnorm": false,
15 | "fusedLN": false,
16 | "use_rms_norm": true
17 | },
18 | "text_cfg": {
19 | "context_length": 77,
20 | "vocab_size": 49408,
21 | "width": 1280,
22 | "heads": 20,
23 | "layers": 32,
24 | "xattn": false,
25 | "fusedLN": false
26 | }
27 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "eva_model_name": "eva-clip-b-16",
9 | "ls_init_value": 0.1,
10 | "drop_path_rate": 0.0
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 1024,
19 | "heads": 16,
20 | "layers": 24,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0.4,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 768,
19 | "heads": 12,
20 | "layers": 12,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "head_width": 64,
8 | "patch_size": 16,
9 | "mlp_ratio": 2.6667,
10 | "eva_model_name": "eva-clip-b-16-X",
11 | "drop_path_rate": 0.0,
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 512,
24 | "heads": 8,
25 | "layers": 12,
26 | "xattn": true,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14-336",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1024,
20 | "heads": 16,
21 | "layers": 24,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 448,
5 | "layers": 77,
6 | "width": 2304,
7 | "head_width": 144,
8 | "mlp_ratio": 10.9722,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-10b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": false,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 77,
6 | "width": 2304,
7 | "head_width": 144,
8 | "mlp_ratio": 10.9722,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-10b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": false,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/hf_vision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor
5 | from llava.utils import rank0_print
6 |
7 |
8 | class HFVisionTower(nn.Module):
9 | def __init__(self, vision_tower, args, delay_load=False):
10 | super().__init__()
11 |
12 | self.is_loaded = False
13 |
14 | self.vision_tower_name = vision_tower.replace("hf:", "", 1)
15 | self.select_layer = args.mm_vision_select_layer
16 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
17 |
18 | if not delay_load:
19 | self.load_model()
20 | else:
21 | self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)
22 |
23 | def load_model(self):
24 | try:
25 | self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name)
26 | except Exception as e:
27 | if "448" in self.vision_tower_name:
28 | image_size = 448
29 | # use image processor with config
30 | self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size)
31 | else:
32 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
33 | rank0_print(f"Loaded image processor: {self.image_processor}")
34 | self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
35 | self.device = self.vision_tower.device
36 | self.dtype = self.vision_tower.dtype
37 | self.config = self.vision_tower.config
38 |
39 | if hasattr(self.vision_tower, "vision_model"):
40 | self.vision_tower = self.vision_tower.vision_model
41 | self.vision_tower.requires_grad_(False)
42 | # self.vision_tower.eval()
43 | self.is_loaded = True
44 |
45 | def feature_select(self, image_forward_outs):
46 | select_feature_type = self.select_feature
47 |
48 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
49 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4
50 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
51 | select_feature_type = select_feature_type.replace("slicefour_", "")
52 | else:
53 | image_features = image_forward_outs.hidden_states[self.select_layer]
54 |
55 | if select_feature_type == "patch":
56 | image_features = image_features[:, 1:]
57 | elif select_feature_type == "cls_patch":
58 | image_features = image_features
59 | else:
60 | raise ValueError(f"Unexpected select feature: {select_feature_type}")
61 | return image_features
62 |
63 | def forward(self, images):
64 | if type(images) is list:
65 | image_features = []
66 | for image in images:
67 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
68 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
69 | image_features.append(image_feature)
70 | else:
71 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
72 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
73 |
74 | return image_features
75 |
76 | @property
77 | def dummy_feature(self):
78 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
79 |
80 | # @property
81 | # def dtype(self):
82 | # return self.vision_tower.dtype
83 |
84 | # @property
85 | # def device(self):
86 | # return self.vision_tower.device
87 |
88 | @property
89 | def hidden_size(self):
90 | try:
91 | _hidden_size = self.config.hidden_size
92 | except AttributeError:
93 | _hidden_size = self.config.vision_config.hidden_size
94 | if "slicefour" in self.select_feature:
95 | _hidden_size *= 4
96 | return _hidden_size
97 |
98 | @property
99 | def num_patches(self):
100 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2
101 | if "cls_patch" in self.select_feature:
102 | _num_patches += 1
103 | return _num_patches
104 |
105 | @property
106 | def num_patches_per_side(self):
107 | return self.config.image_size // self.config.patch_size
108 |
109 | @property
110 | def image_size(self):
111 | return self.config.image_size
112 |
--------------------------------------------------------------------------------
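Note: a minimal sketch (not part of the repository) of how `HFVisionTower.feature_select` above resolves `mm_vision_select_layer` and the "patch" feature type; the `DummyOutputs` container and tensor sizes are illustrative assumptions standing in for a real HuggingFace model output:

    import torch

    class DummyOutputs:
        # Hypothetical stand-in for the HuggingFace output object consumed by feature_select.
        def __init__(self, hidden_states):
            self.hidden_states = hidden_states

    # 25 hidden states (embeddings + 24 blocks), one image, 577 tokens (CLS + 24x24 patches).
    outs = DummyOutputs([torch.randn(1, 577, 1024) for _ in range(25)])

    select_layer = -2             # mm_vision_select_layer: second-to-last layer
    feats = outs.hidden_states[select_layer]
    patch_feats = feats[:, 1:]    # "patch" drops the CLS token -> shape (1, 576, 1024)
    print(patch_feats.shape)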
/llava/model/multimodal_encoder/imagebind.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPImageProcessor
5 |
6 | try:
7 | from imagebind.models import imagebind_model
8 | from imagebind.models.imagebind_model import ModalityType
9 | from imagebind.data import load_and_transform_audio_data
10 | except ImportError:
11 | pass
12 |
13 |
14 | class ImageBindWrapper(nn.Module):
15 | def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False):
16 | super().__init__()
17 |
18 | self.is_loaded = False
19 |
20 | self.vision_tower_name = vision_tower
21 | self.select_layer = select_layer
22 | self.select_feature = select_feature
23 |
24 | if not delay_load:
25 | self.load_model()
26 |
27 | def load_model(self):
28 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
29 | self.vision_tower = imagebind_model.imagebind_huge(pretrained=True)
30 | for p in self.vision_tower.parameters():
31 | p.requires_grad = False
32 | self.vision_tower.eval()
33 | self.is_loaded = True
34 |
35 | def train(self, mode=True):
36 | self.training = mode
37 |
38 | if self.is_loaded:
39 | self.vision_tower.eval()
40 |
41 | @torch.no_grad()
42 | def forward(self, x):
43 | if type(x) == dict:
44 | if x["audios"] is not None:
45 | inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()}
46 | embeddings = self.vision_tower(inputs)
47 | audio_embedding = embeddings[ModalityType.AUDIO]
48 | return audio_embedding.unsqueeze(1)
49 | else:
50 | inputs = {ModalityType.VISION: x.to(dtype=self.dtype)}
51 | embeddings = self.vision_tower(inputs)
52 | vision_embedding = embeddings[ModalityType.VISION]
53 | if vision_embedding.ndim == 2:
54 | return vision_embedding.unsqueeze(1)
55 | if vision_embedding.shape[1] == 257:
56 | return vision_embedding[:, 1:]
57 | raise ValueError(f"Unexpected shape: {vision_embedding.shape}")
58 |
59 | @property
60 | def dummy_feature(self):
61 | return torch.zeros(1, 1024, device=self.device, dtype=self.dtype)
62 |
63 | @property
64 | def dtype(self):
65 | return self.vision_tower.modality_preprocessors.vision.cls_token.dtype
66 |
67 | @property
68 | def device(self):
69 | return self.vision_tower.modality_preprocessors.vision.cls_token.device
70 |
71 | @property
72 | def hidden_size(self):
73 | return 1024
74 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/open_clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import CLIPImageProcessor
4 | from llava.utils import rank0_print
5 |
6 | try:
7 | import open_clip
8 | import torchvision
9 | from open_clip.transformer import _expand_token
10 | except ImportError:
11 | print("OpenCLIP not installed")
12 | open_clip = None
13 |
14 | HIDDEN_SIZE_DICT = {
15 | "ViT-H-14-378-quickgelu": 1280,
16 | }
17 |
18 |
19 | class OpenCLIPVisionTower(nn.Module):
20 | def __init__(self, vision_tower, args, delay_load=False):
21 | super().__init__()
22 |
23 | self.is_loaded = False
24 | self.model_name = vision_tower.replace("open_clip_hub:", "")
25 | self.pretrained = args.vision_tower_pretrained
26 | self.select_layer = args.mm_vision_select_layer
27 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
28 |
29 | if not delay_load:
30 | rank0_print(f"Loading vision tower: {vision_tower}")
31 | self.load_model()
32 | elif getattr(args, "unfreeze_mm_vision_tower", False):
33 | # TODO: better detector is needed.
34 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
35 | self.load_model()
36 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
37 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
38 | self.load_model()
39 |
40 | def load_model(self, device_map="auto"):
41 | rank0_print(f"Loading OpenCLIP model: {self.model_name}")
42 | rank0_print(f"Pretrained: {self.pretrained}")
43 | vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda")
44 |
45 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
46 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
47 | self.resize_transform_size = resize_transform.size # 224 or 384
48 | self.patch_size = vision_tower.visual.conv1.kernel_size[0] # 14 or 16
49 |
50 | self.image_processor = CLIPImageProcessor.from_pretrained(
51 | "openai/clip-vit-large-patch14",
52 | crop_size=resize_transform.size,
53 | size={"shortest_edge": resize_transform.size},
54 | image_mean=list(normalize_transform.mean),
55 | image_std=list(normalize_transform.std),
56 | )
57 | rank0_print(f"Loaded image processor: {self.image_processor}")
58 | self.vision_tower = vision_tower.visual
59 | self.vision_tower.requires_grad_(False)
60 |
61 | self.is_loaded = True
62 |
63 | def feature_select(self, image_forward_outs):
64 | image_features = image_forward_outs[self.select_layer]
65 | if self.select_feature == "patch":
66 | image_features = image_features[:, 1:]
67 | elif self.select_feature == "cls_patch":
68 | image_features = image_features
69 | elif self.select_feature == "conv_flatten":
70 | image_features = image_features.flatten(2).transpose(1, 2)
71 | else:
72 | raise ValueError(f"Unexpected select feature: {self.select_feature}")
73 | return image_features
74 |
75 | def forward_visual(self, x, output_hidden_states=False):
76 | if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"):
77 | return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer))
78 | else:
79 |
80 | def forward_openclip(self, x: torch.Tensor):
81 | features = []
82 | x = self.conv1(x) # shape = [*, width, grid, grid]
83 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
84 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
85 |
86 | # class embeddings and positional embeddings
87 | x = torch.cat(
88 | [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
89 | dim=1,
90 | )
91 | # shape = [*, grid ** 2 + 1, width]
92 | x = x + self.positional_embedding.to(x.dtype)
93 |
94 | x = self.patch_dropout(x)
95 | x = self.ln_pre(x)
96 |
97 | x = x.permute(1, 0, 2) # NLD -> LND
98 | for r in self.transformer.resblocks:
99 | x = r(x, attn_mask=None)
100 | features.append(x)
101 | return features
102 |
103 | return forward_openclip(self.vision_tower, x)
104 |
105 | def forward(self, images):
106 | if type(images) is list:
107 | image_features = []
108 | for image in images:
109 | image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True)
110 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
111 | image_features.append(image_feature)
112 | else:
113 | image_forward_outs = self.forward_visual(images.to(self.dtype), output_hidden_states=True)
114 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
115 |
116 | return image_features
117 |
118 | @property
119 | def dummy_feature(self):
120 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
121 |
122 | @property
123 | def dtype(self):
124 | if hasattr(self.vision_tower, "conv1"):
125 | return self.vision_tower.conv1.weight.dtype
126 | if hasattr(self.vision_tower, "trunk"):
127 | return self.vision_tower.trunk.patch_embed.proj.weight.dtype
128 | raise NotImplementedError
129 |
130 | @property
131 | def device(self):
132 | if hasattr(self.vision_tower, "conv1"):
133 | return self.vision_tower.conv1.weight.device
134 | if hasattr(self.vision_tower, "trunk"):
135 | return self.vision_tower.trunk.patch_embed.proj.weight.device
136 | raise NotImplementedError
137 |
138 | @property
139 | def config(self):
140 | return None
141 |
142 | @property
143 | def hidden_size(self):
144 | if self.model_name in HIDDEN_SIZE_DICT:
145 | return HIDDEN_SIZE_DICT[self.model_name]
146 | else:
147 | raise NotImplementedError
148 |
149 | @property
150 | def num_patches(self):
151 | image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0]
152 | _num_patches = (image_size // self.patch_size) ** 2
153 | if "cls_patch" in self.select_feature:
154 | _num_patches += 1
155 | return _num_patches
156 |
157 | @property
158 | def image_size(self):
159 | return self.resize_transform_size
160 |
161 | @property
162 | def num_patches_per_side(self):
163 | return self.resize_transform_size // self.patch_size
164 |
--------------------------------------------------------------------------------
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 | from .pooler_projector import PoolerProjector
6 |
7 |
8 | class IdentityMap(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def forward(self, x, *args, **kwargs):
13 | return x
14 |
15 | @property
16 | def config(self):
17 | return {"mm_projector_type": "identity"}
18 |
19 |
20 | class SimpleResBlock(nn.Module):
21 | def __init__(self, channels):
22 | super().__init__()
23 | self.pre_norm = nn.LayerNorm(channels)
24 |
25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
26 |
27 | def forward(self, x):
28 | x = self.pre_norm(x)
29 | return x + self.proj(x)
30 |
31 |
32 | def build_vision_projector(config, delay_load=False, **kwargs):
33 | projector_type = getattr(config, "mm_projector_type", "linear")
34 |
35 | if projector_type == "linear":
36 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
37 |
38 | if projector_type == "pooler":
39 | return PoolerProjector(config, kwargs["vision_cfg"])
40 |
41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
42 | if mlp_gelu_match:
43 | mlp_depth = int(mlp_gelu_match.group(1))
44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
45 | for _ in range(1, mlp_depth):
46 | modules.append(nn.GELU())
47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
48 | return nn.Sequential(*modules)
49 |
50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
51 | if mlp_gelu_resnet_match:
52 | mlp_depth = int(mlp_gelu_resnet_match.group(1))
53 | res_depth = int(mlp_gelu_resnet_match.group(2))
54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
55 | for _ in range(1, mlp_depth):
56 | modules.append(nn.GELU())
57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
58 | for _ in range(res_depth):
59 | modules.append(SimpleResBlock(config.hidden_size))
60 | return nn.Sequential(*modules)
61 |
62 | if projector_type == "identity":
63 | return IdentityMap()
64 |
65 | raise ValueError(f"Unknown projector type: {projector_type}")
66 |
--------------------------------------------------------------------------------
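Note: a hedged usage sketch (not part of the repository) showing how the `mlp2x_gelu` projector type is resolved by the regex in `build_vision_projector` above; the config object is a hypothetical `SimpleNamespace` with illustrative sizes, not the project's actual config class:

    import torch
    from types import SimpleNamespace
    from llava.model.multimodal_projector.builder import build_vision_projector

    # Illustrative sizes: CLIP-L features (1024) projected into a 4096-dim language model.
    cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
    projector = build_vision_projector(cfg)   # Linear(1024->4096) -> GELU -> Linear(4096->4096)
    print(projector(torch.randn(1, 576, 1024)).shape)   # torch.Size([1, 576, 4096])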
/llava/model/multimodal_projector/pooler_projector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import math
5 |
6 | from transformers.models.clip.modeling_clip import CLIPVisionModel
7 |
8 |
9 | class PoolerProjector(nn.Module):
10 | def __init__(self, config, vision_cfg):
11 | super().__init__()
12 | self._config = config
13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size
14 |
15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2)
16 |
17 | self.proj = nn.Sequential(
18 | nn.GELU(),
19 | nn.Linear(config.hidden_size, config.hidden_size),
20 | )
21 |
22 | def forward(self, x, *args, **kwargs):
23 | height = width = self.hw
24 | assert height * width == x.shape[1]
25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
26 | x = self.conv_pool(x)
27 | x = x.flatten(2).transpose(1, 2)
28 | x = self.proj(x)
29 | return x
30 |
31 | @property
32 | def config(self):
33 | return {"mm_projector_type": "pooler"}
34 |
--------------------------------------------------------------------------------
/llava/model/multimodal_resampler/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .masked_drop import MaskedDrop
4 | from .spatial_pool import SpatialPool
5 | from .perceiver import PerceiverResampler
6 | from .qformer import Qformer
7 |
8 |
9 | class IdentityMap(torch.nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def forward(self, x, *args, **kwargs):
14 | return x
15 |
16 | @property
17 | def config(self):
18 | return {"mm_resampler_type": None}
19 |
20 |
21 | def build_vision_resampler(model_args, delay_load=False, **kwargs):
22 | resampler_type = getattr(model_args, "mm_resampler_type", None)
23 | if resampler_type == "masked_drop":
24 | return MaskedDrop(model_args)
25 | elif resampler_type == "spatial_pool":
26 | return SpatialPool(model_args, **kwargs)
27 | elif resampler_type == "perceiver":
28 | return PerceiverResampler(model_args, **kwargs)
29 | elif resampler_type == "qformer":
30 | return Qformer(model_args, **kwargs)
31 | elif resampler_type is None:
32 | return IdentityMap()
33 |
34 | raise ValueError(f"Unknown resampler type: {resampler_type}")
35 |
--------------------------------------------------------------------------------
/llava/model/multimodal_resampler/masked_drop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import random
5 |
6 |
7 | class MaskedDrop(nn.Module):
8 | def __init__(self, model_args):
9 | super().__init__()
10 |
11 | self.mode = model_args.mm_mask_drop_mode
12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage
13 | self.ratio = model_args.mm_mask_drop_ratio
14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper
15 | self.ratio_lower = model_args.mm_mask_drop_ratio_lower
16 |
17 | def forward(self, image_features, *args, **kwargs):
18 |
19 | if not self.training:
20 | return image_features
21 |
22 | if self.skip_percentage > random.random():
23 | return image_features
24 |
25 | masked_features = []
26 |
27 | for image_feature in image_features:
28 | num_tokens = image_feature.shape[0]
29 | if self.mode == "fixed":
30 | num_keep = int(num_tokens * self.ratio)
31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
32 | elif self.mode == "range":
33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
35 | elif self.mode == "cls_only":
36 | masked_features.append(image_feature[0:1])
37 | else:
38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}")
39 |
40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]):
41 | masked_features = torch.stack(masked_features, dim=0)
42 |
43 | return masked_features
44 |
45 | @property
46 | def config(self):
47 | return {
48 | "mm_resampler_type": "masked_drop",
49 | "mm_mask_drop_mode": self.mode,
50 | "mm_mask_drop_skip_percentage": self.skip_percentage,
51 | "mm_mask_drop_ratio": self.ratio,
52 | "mm_mask_drop_ratio_upper": self.ratio_upper,
53 | "mm_mask_drop_ratio_lower": self.ratio_lower,
54 | }
55 |
56 | def random_masking(self, x, len_keep):
57 | """
58 | Perform per-sample random masking by per-sample shuffling.
59 | Per-sample shuffling is done by argsort random noise.
60 | x: [N, L, D], sequence
61 | """
62 | N, L, D = x.shape # batch, length, dim
63 |
64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
65 |
66 | # sort noise for each sample
67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
68 | ids_restore = torch.argsort(ids_shuffle, dim=1)
69 |
70 | # keep the first subset
71 | ids_keep = ids_shuffle[:, :len_keep]
72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
73 |
74 | # generate the binary mask: 0 is keep, 1 is remove
75 | mask = torch.ones([N, L], device=x.device)
76 | mask[:, :len_keep] = 0
77 | # unshuffle to get the binary mask
78 | mask = torch.gather(mask, dim=1, index=ids_restore)
79 |
80 | return x_masked, mask, ids_restore
81 |
--------------------------------------------------------------------------------
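Note: a small sketch (not part of the repository) of the `random_masking` routine above; the `model_args` namespace and tensor sizes are illustrative assumptions:

    import torch
    from types import SimpleNamespace
    from llava.model.multimodal_resampler.masked_drop import MaskedDrop

    args = SimpleNamespace(mm_mask_drop_mode="fixed", mm_mask_drop_skip_percentage=0.0,
                           mm_mask_drop_ratio=0.25, mm_mask_drop_ratio_upper=None,
                           mm_mask_drop_ratio_lower=None)
    drop = MaskedDrop(args)

    x = torch.randn(2, 576, 1024)                    # batch of 2 images, 576 tokens each
    x_masked, mask, ids_restore = drop.random_masking(x, len_keep=144)
    print(x_masked.shape, mask.shape)                # (2, 144, 1024) and (2, 576)
    print(int(mask[0].sum()))                        # 576 - 144 = 432 tokens flagged as removed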
/llava/model/multimodal_resampler/perceiver.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from https://github.com/lucidrains/flamingo-pytorch
3 | """
4 |
5 | import torch
6 | from einops import rearrange, repeat
7 |
8 | try:
9 | from einops_exts import rearrange_many
10 | except:
11 | pass
12 |
13 | from torch import einsum, nn
14 |
15 |
16 | def exists(val):
17 | return val is not None
18 |
19 |
20 | def FeedForward(dim, mult=4):
21 | inner_dim = int(dim * mult)
22 | return nn.Sequential(
23 | nn.LayerNorm(dim),
24 | nn.Linear(dim, inner_dim, bias=False),
25 | nn.GELU(),
26 | nn.Linear(inner_dim, dim, bias=False),
27 | )
28 |
29 |
30 | class PerceiverAttention(nn.Module):
31 | def __init__(self, *, dim, dim_head=64, heads=8):
32 | super().__init__()
33 | self.scale = dim_head**-0.5
34 | self.heads = heads
35 | inner_dim = dim_head * heads
36 |
37 | self.norm_media = nn.LayerNorm(dim)
38 | self.norm_latents = nn.LayerNorm(dim)
39 |
40 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
41 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
42 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
43 |
44 | def forward(self, x, latents):
45 | """
46 | Args:
47 | x (torch.Tensor): image features
48 | shape (b, T, n1, D)
49 | latent (torch.Tensor): latent features
50 | shape (b, T, n2, D)
51 | """
52 | x = self.norm_media(x)
53 | latents = self.norm_latents(latents)
54 |
55 | h = self.heads
56 |
57 | q = self.to_q(latents)
58 | kv_input = torch.cat((x, latents), dim=-2)
59 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
61 | q = q * self.scale
62 |
63 | # attention
64 | sim = einsum("... i d, ... j d -> ... i j", q, k)
65 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
66 | attn = sim.softmax(dim=-1)
67 |
68 | out = einsum("... i j, ... j d -> ... i d", attn, v)
69 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
70 | return self.to_out(out)
71 |
72 |
73 | class PerceiverResamplerModule(nn.Module):
74 | def __init__(
75 | self,
76 | *,
77 | dim,
78 | depth=6,
79 | dim_head=64,
80 | heads=8,
81 | num_latents=64,
82 | max_num_media=None,
83 | max_num_frames=None,
84 | ff_mult=4,
85 | ):
86 | super().__init__()
87 | self.latents = nn.Parameter(torch.randn(num_latents, dim))
88 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
89 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
90 |
91 | self.layers = nn.ModuleList([])
92 | for _ in range(depth):
93 | self.layers.append(
94 | nn.ModuleList(
95 | [
96 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
97 | FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(),
98 | ]
99 | )
100 | )
101 |
102 | self.norm = nn.LayerNorm(dim)
103 |
104 | def forward(self, x):
105 | """
106 | Args:
107 | x (torch.Tensor): image features
108 | shape (b, T, F, v, D)
109 | Returns:
110 | shape (b, T, n, D) where n is self.num_latents
111 | """
112 | b, T, F, v = x.shape[:4]
113 |
114 | # frame and media time embeddings
115 | if exists(self.frame_embs):
116 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
117 | x = x + frame_embs
118 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
119 | if exists(self.media_time_embs):
120 | x = x + self.media_time_embs[:T]
121 |
122 | # blocks
123 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
124 | for attn, ff in self.layers:
125 | latents = attn(x, latents) + latents
126 | latents = ff(latents) + latents
127 | return self.norm(latents)
128 |
129 |
130 | class PerceiverResampler(nn.Module):
131 | def __init__(self, model_args, vision_tower):
132 | super().__init__()
133 |
134 | self.depth = model_args.mm_perceiver_depth
135 | self.num_latents = model_args.mm_perceiver_latents
136 | self.ff_mult = model_args.mm_perceiver_ff_mult
137 | self.pretrained = model_args.mm_perceiver_pretrained
138 |
139 | self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult)
140 |
141 | if self.pretrained is not None:
142 | self.load_state_dict(torch.load(self.pretrained))
143 |
144 | def forward(self, image_features, *args, **kwargs):
145 | return self.perceiver(image_features[:, None, None]).squeeze(1)
146 |
147 | @property
148 | def config(self):
149 | return {
150 | "mm_resampler_type": "perceiver",
151 | "mm_perceiver_depth": self.depth,
152 | "mm_perceiver_latents": self.num_latents,
153 | "mm_perceiver_ff_mult": self.ff_mult,
154 | "mm_perceiver_pretrained": self.pretrained,
155 | }
156 |
--------------------------------------------------------------------------------
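Note: a minimal sketch (not part of the repository) of the Perceiver resampler above compressing a token sequence into a fixed number of latents; sizes are illustrative, and `einops` plus `einops_exts` must be installed:

    import torch
    from llava.model.multimodal_resampler.perceiver import PerceiverResamplerModule

    # 576 image tokens compressed into 64 learned latents (illustrative sizes).
    resampler = PerceiverResamplerModule(dim=1024, depth=2, num_latents=64)
    x = torch.randn(1, 1, 1, 576, 1024)     # (b, T, F, v, D) as expected by forward()
    latents = resampler(x)
    print(latents.shape)                    # torch.Size([1, 1, 64, 1024])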
/llava/model/multimodal_resampler/spatial_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | class SpatialPool(nn.Module):
7 | def __init__(self, model_args, vision_tower):
8 | super().__init__()
9 |
10 | self.mode = model_args.mm_spatial_pool_mode
11 | self.stride = model_args.mm_spatial_pool_stride
12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size)
13 |
14 | if self.mode == "average":
15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
16 | elif self.mode == "max":
17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
18 | elif self.mode == "conv":
19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
20 | else:
21 | raise ValueError(f"Unknown pooling mode: {self.pool}.")
22 |
23 | def forward(self, image_features, images, *args, **kwargs):
24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
25 | ori_H = int(ori_W * images.shape[2] // images.shape[3])
26 |
27 | B, _, F = image_features.shape
28 |
29 | image_features_spatial = image_features.view(B, ori_H, ori_W, F).permute(0, 3, 1, 2)
30 | image_features_spatial_pool = self.pool(image_features_spatial)
31 |
32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
33 |
34 | @property
35 | def config(self):
36 | return {
37 | "mm_resampler_type": "spatial_pool",
38 | "mm_spatial_pool_stride": self.stride,
39 | "mm_spatial_pool_mode": self.mode,
40 | "mm_spatial_pool_out_channels": self.out_channels,
41 | }
42 |
43 | @property
44 | def hidden_size(self):
45 | return self.out_channels
46 |
--------------------------------------------------------------------------------
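Note: a short usage sketch (not part of the repository) of `SpatialPool` above; with stride-2 average pooling, a 24x24 patch grid (576 tokens) is reduced to 12x12 (144 tokens). The `model_args` and `vision_tower` stand-ins are hypothetical:

    import torch
    from types import SimpleNamespace
    from llava.model.multimodal_resampler.spatial_pool import SpatialPool

    vision_tower = SimpleNamespace(hidden_size=1024)   # stand-in for the real vision tower
    args = SimpleNamespace(mm_spatial_pool_mode="average", mm_spatial_pool_stride=2)
    pool = SpatialPool(args, vision_tower)

    image_features = torch.randn(1, 576, 1024)         # 24x24 patch grid
    images = torch.randn(1, 3, 336, 336)               # square input image
    print(pool(image_features, images).shape)          # torch.Size([1, 144, 1024])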
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if "llava" in config and "llava" not in cfg.model_type:
7 | assert cfg.model_type == "llama"
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = "LlavaLlamaForCausalLM"
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/llava/model/vision_mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from transformers.models.llama.modeling_llama import LlamaRMSNorm
4 |
5 | import numpy as np
6 |
7 | class ProxyInitializer(nn.Module):
8 | def __init__(self, d_model, d_hidden, n):
9 | """
10 | d_model: embedding dimension of the full tokens; d_hidden: dimension of the learned proxy queries and projected keys; n: number of proxy tokens
11 | """
12 | super().__init__()
13 | self.Wk = nn.Linear(d_model, d_hidden, bias=False)
14 |
15 | self.proxyv_tokens = nn.Parameter(torch.randn((n, d_hidden))/torch.sqrt(torch.tensor(d_hidden*4)))
16 |
17 | def forward(self, full_tokens, compress_reduce_factor, single_crop_len):
18 | """
19 | full_tokens: Tensor of shape (batch_size, N, d_model)
20 | (the learned proxy tokens self.proxyv_tokens, of shape (n, d_hidden), are expanded to (batch_size, n, d_hidden) internally)
21 |
22 | Returns:
23 | proxyv_out: updated proxy tokens (batch_size, n, d_model)
24 | attn_T: attention matrix from full -> proxy tokens (batch_size, N, n)
25 | """
26 | proxyv_tokens = self.proxyv_tokens.unsqueeze(0).repeat(full_tokens.shape[0], 1, 1)
27 | Q = proxyv_tokens
28 | K = self.Wk(full_tokens)
29 | V = full_tokens
30 |
31 | d_model = Q.shape[-1]
32 | attn_logits = torch.bmm(Q, K.transpose(1, 2)) / (d_model ** 0.5)
33 |
34 | attn = nn.functional.softmax(attn_logits, dim=-1)
35 |
36 | attn_logits_T = attn_logits.transpose(1, 2)
37 | attn_T = nn.functional.softmax(attn_logits_T, dim=-1)
38 |
39 | # Update the proxy tokens by pooling full tokens
40 | proxyv_out = torch.bmm(attn, V)
41 | return proxyv_out, attn_T
42 |
43 | def splat_proxyv_tokens(proxyv_tokens, attn):
44 | return torch.bmm(attn, proxyv_tokens)
45 |
46 |
47 | class VisionMLP(nn.Module):
48 | def __init__(self, config, intermediate_size):
49 | super().__init__()
50 | self.context_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
51 | self.input_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
52 | self.proj = nn.Sequential(
53 | nn.Linear(intermediate_size*2, intermediate_size, bias=False),
54 | nn.SiLU(),
55 | nn.Linear(intermediate_size, config.hidden_size, bias=False)
56 | )
57 | self.layernorm_post = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
58 |
59 | def forward(self, image_full, image_compress, compress_reduce_factor, per_crop_token_len=576, learn_proxy=False, proxyv_attn=None):
60 | side_len_full = int(per_crop_token_len**0.5)
61 | side_len_compress = side_len_full // compress_reduce_factor
62 |
63 | num_image_crops = image_full.shape[1]//per_crop_token_len
64 | bs = image_full.shape[0]
65 |
66 | if learn_proxy:
67 | image_full = image_full.view(bs*num_image_crops, side_len_full*side_len_full, -1)
68 | image_compress = image_compress.view(bs*num_image_crops, side_len_compress*side_len_compress, -1)
69 | image_compress = splat_proxyv_tokens(image_compress, proxyv_attn)
70 | image_compress = self.context_proj(image_compress)
71 | else:
72 | image_full = image_full.view(bs*num_image_crops, side_len_full, side_len_full, -1)
73 | image_compress = image_compress.view(bs*num_image_crops, side_len_compress, side_len_compress, -1)
74 | image_compress = self.context_proj(image_compress)
75 | image_compress = image_compress.repeat_interleave(compress_reduce_factor, 1).repeat_interleave(compress_reduce_factor, 2)
76 | residual = image_full
77 | image_full = self.input_proj(image_full)
78 | image_full = torch.cat([image_full, image_compress], -1)
79 | image_full = self.layernorm_post(self.proj(image_full) + residual)
80 |
81 | image_full = image_full.view(bs, num_image_crops*side_len_full*side_len_full, -1)
82 |
83 | return image_full
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
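Note: a hedged sketch (not part of the repository) of the proxy-token pooling and splat operations defined in vision_mlp.py above: `ProxyInitializer` pools N full vision tokens into n proxy tokens and returns the full-to-proxy attention that `splat_proxyv_tokens` later uses to scatter the proxy tokens back onto the full token grid. Sizes are illustrative:

    import torch
    from llava.model.vision_mlp import ProxyInitializer, splat_proxyv_tokens

    # Illustrative sizes: 576 full tokens pooled into 36 proxy tokens.
    init = ProxyInitializer(d_model=4096, d_hidden=1024, n=36)
    full_tokens = torch.randn(2, 576, 4096)

    proxy_tokens, attn_T = init(full_tokens, compress_reduce_factor=4, single_crop_len=576)
    print(proxy_tokens.shape)   # torch.Size([2, 36, 4096])  pooled proxy tokens
    print(attn_T.shape)         # torch.Size([2, 576, 36])   full -> proxy attention

    # Scatter proxy information back to the full token resolution.
    restored = splat_proxyv_tokens(proxy_tokens, attn_T)
    print(restored.shape)       # torch.Size([2, 576, 4096])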
/llava/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | padding_mask: Optional[torch.Tensor] = None,
25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
26 | if output_attentions:
27 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.")
28 |
29 | bsz, q_len, _ = hidden_states.size()
30 |
31 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
32 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
33 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim)
34 |
35 | kv_seq_len = key_states.shape[-2]
36 | if past_key_value is not None:
37 | kv_seq_len += past_key_value[0].shape[-2]
38 |
39 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
40 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
41 |
42 | if past_key_value is not None:
43 | # reuse k, v
44 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
45 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
46 |
47 | past_key_value = (key_states, value_states) if use_cache else None
48 |
49 | # repeat k/v heads if n_kv_heads < n_heads
50 | key_states = repeat_kv(key_states, self.num_key_value_groups)
51 | value_states = repeat_kv(value_states, self.num_key_value_groups)
52 |
53 | # Transform the data into the format required by flash attention
54 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
55 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
56 | key_padding_mask = attention_mask
57 |
58 | if key_padding_mask is None:
59 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
60 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
61 | max_s = q_len
62 | output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
63 | output = output.view(bsz, q_len, -1)
64 | else:
65 | qkv = qkv.reshape(bsz, q_len, -1)
66 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
67 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
68 | output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
69 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
70 | output = pad_input(output_unpad, indices, bsz, q_len)
71 |
72 | return self.o_proj(output), None, past_key_value
73 |
74 |
75 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
76 | # requires the attention mask to be the same as the key_padding_mask
77 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
78 | # [bsz, seq_len]
79 | return attention_mask
80 |
81 |
82 | def replace_llama_attn_with_flash_attn():
83 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
84 | if cuda_major < 8:
85 | warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593")
86 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
87 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
88 |
--------------------------------------------------------------------------------
/llava/train/llava_trainer_eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import subprocess
3 |
4 | from llava.train.llava_trainer import LLaVATrainer
5 |
6 |
7 | class LLaVAEvalTrainer(LLaVATrainer):
8 | def evaluate(self, evaluate_args):
9 | cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \
10 | --model {evaluate_args.model} \
11 | --model_args {evaluate_args.model_args} \
12 | --tasks {evaluate_args.task_names} \
13 | --batch_size {evaluate_args.batch_size} \
14 | --log_samples_suffix {evaluate_args.log_samples_suffix} \
15 | --output_path {evaluate_args.output_path}"
16 | if evaluate_args.limit:
17 | cmd += f" --limit {evaluate_args.limit}"
18 | if evaluate_args.num_fewshot:
19 | cmd += f" --num_fewshot {evaluate_args.num_fewshot}"
20 | if evaluate_args.gen_kwargs != "":
21 | cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}"
22 | if evaluate_args.log_samples:
23 | cmd += f" --log_samples"
24 | else:
25 | assert False, "Please log samples so that the result can be parsed"
26 | results = subprocess.run([cmd], shell=True, capture_output=True, text=True)
27 | try:
28 | result_file_index_start = results.stdout.index("Saved samples to ")
29 | result_file_index_end = results.stdout.index(f".json")
30 | result_file_index_start += len("Saved samples to ")
31 | file = results.stdout[result_file_index_start:result_file_index_end]
32 | except ValueError:
33 | result_file_index_start = results.stderr.index("Saved samples to ")
34 | result_file_index_end = results.stderr.index(f".json")
35 | result_file_index_start += len("Saved samples to ")
36 | file = results.stderr[result_file_index_start:result_file_index_end]
37 | file = file.split("/")[:-1]
38 | file = "/".join(file) + "/results.json"
39 | with open(file, "r") as f:
40 | lmms_eval_results = json.load(f)
41 | result_dict = {}
42 | tasks_list = evaluate_args.task_names.split(",")
43 | for task in tasks_list:
44 | task_results = lmms_eval_results["results"][task]
45 | for k, v in task_results.items():
46 | if k != "alias" and "stderr" not in k:
47 | metric = k.split(",")[0]
48 | result_dict[f"{task}_{metric}"] = v
49 | return result_dict
50 |
51 | """def evaluate(self, evaluate_args):
52 | initialize_tasks()
53 | tasks_list = evaluate_args.task_names.split(",")
54 | result_dict = {}
55 | results = evaluator.simple_evaluate(
56 | model=evaluate_args.model,
57 | model_args=evaluate_args.model_args,
58 | tasks=tasks_list,
59 | num_fewshot=evaluate_args.num_fewshot,
60 | batch_size=evaluate_args.batch_size,
61 | device=evaluate_args.device,
62 | limit=evaluate_args.limit,
63 | check_integrity=evaluate_args.check_integrity,
64 | show_task_to_terminal=evaluate_args.show_task_to_terminal,
65 | log_samples=evaluate_args.log_samples,
66 | gen_kwargs=evaluate_args.gen_kwargs,
67 | cli_args=evaluate_args,
68 | )
69 | for task in tasks_list:
70 | task_results = results["results"][task]
71 | for k,v in task_results.items():
72 | if k != "alias" and "stderr" not in k:
73 | metric = k.split(",")[0]
74 | result_dict[f"{task}_{metric}"] = v
75 |
76 | return result_dict"""
77 |
--------------------------------------------------------------------------------
/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | from llava.train.train import train
2 |
3 | if __name__ == "__main__":
4 | train()
5 |
--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 | import numpy as np
7 |
8 | import requests
9 |
10 | from llava.constants import LOGDIR
11 |
12 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
13 | moderation_msg = "I am sorry. Your input may violate our content moderation guidelines. Please avoid using harmful or offensive content."
14 |
15 | handler = None
16 |
17 | import torch.distributed as dist
18 |
19 | try:
20 | import av
21 | from decord import VideoReader, cpu
22 | except ImportError:
23 | print("Please install pyav to use video processing functions.")
24 |
25 | def process_video_with_decord(video_file, data_args):
26 | vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
27 | total_frame_num = len(vr)
28 | video_time = total_frame_num / vr.get_avg_fps()
29 | avg_fps = round(vr.get_avg_fps() / data_args.video_fps)
30 | frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
31 | frame_time = [i/avg_fps for i in frame_idx]
32 |
33 |
34 | if data_args.frames_upbound > 0:
35 | if len(frame_idx) > data_args.frames_upbound or data_args.force_sample:
36 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int)
37 | frame_idx = uniform_sampled_frames.tolist()
38 | frame_time = [i/vr.get_avg_fps() for i in frame_idx]
39 |
40 | video = vr.get_batch(frame_idx).asnumpy()
41 | frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
42 |
43 | num_frames_to_sample = num_frames = len(frame_idx)
44 | # https://github.com/dmlc/decord/issues/208
45 | vr.seek(0)
46 | return video, video_time, frame_time, num_frames_to_sample
47 |
48 | def process_video_with_pyav(video_file, data_args):
49 | container = av.open(video_file)
50 | # !!! This is the only difference. Using auto threading
51 | container.streams.video[0].thread_type = "AUTO"
52 |
53 | video_frames = []
54 | for packet in container.demux():
55 | if packet.stream.type == 'video':
56 | for frame in packet.decode():
57 | video_frames.append(frame)
58 | total_frame_num = len(video_frames)
59 | video_time = video_frames[-1].time
60 | avg_fps = round(total_frame_num / video_time / data_args.video_fps)
61 | frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
62 |
63 | if data_args.frames_upbound > 0:
64 | if len(frame_idx) > data_args.frames_upbound:
65 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int)
66 | frame_idx = uniform_sampled_frames.tolist()
67 |
68 |
69 | frames = [video_frames[i] for i in frame_idx]
70 | return np.stack([x.to_ndarray(format="rgb24") for x in frames])
71 |
72 |
73 | def rank0_print(*args):
74 | if dist.is_initialized():
75 | if dist.get_rank() == 0:
76 | print(f"Rank {dist.get_rank()}: ", *args)
77 | else:
78 | print(*args)
79 |
80 |
81 | def rank_print(*args):
82 | if dist.is_initialized():
83 | print(f"Rank {dist.get_rank()}: ", *args)
84 | else:
85 | print(*args)
86 |
87 | def build_logger(logger_name, logger_filename):
88 | global handler
89 |
90 | formatter = logging.Formatter(
91 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
92 | datefmt="%Y-%m-%d %H:%M:%S",
93 | )
94 |
95 | # Set the format of root handlers
96 | if not logging.getLogger().handlers:
97 | logging.basicConfig(level=logging.INFO)
98 | logging.getLogger().handlers[0].setFormatter(formatter)
99 |
100 | # Redirect stdout and stderr to loggers
101 | stdout_logger = logging.getLogger("stdout")
102 | stdout_logger.setLevel(logging.INFO)
103 | sl = StreamToLogger(stdout_logger, logging.INFO)
104 | sys.stdout = sl
105 |
106 | stderr_logger = logging.getLogger("stderr")
107 | stderr_logger.setLevel(logging.ERROR)
108 | sl = StreamToLogger(stderr_logger, logging.ERROR)
109 | sys.stderr = sl
110 |
111 | # Get logger
112 | logger = logging.getLogger(logger_name)
113 | logger.setLevel(logging.INFO)
114 |
115 | # Add a file handler for all loggers
116 | if handler is None:
117 | os.makedirs(LOGDIR, exist_ok=True)
118 | filename = os.path.join(LOGDIR, logger_filename)
119 | handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", utc=True)
120 | handler.setFormatter(formatter)
121 |
122 | for name, item in logging.root.manager.loggerDict.items():
123 | if isinstance(item, logging.Logger):
124 | item.addHandler(handler)
125 |
126 | return logger
127 |
128 |
129 | class StreamToLogger(object):
130 | """
131 | Fake file-like stream object that redirects writes to a logger instance.
132 | """
133 |
134 | def __init__(self, logger, log_level=logging.INFO):
135 | self.terminal = sys.stdout
136 | self.logger = logger
137 | self.log_level = log_level
138 | self.linebuf = ""
139 |
140 | def __getattr__(self, attr):
141 | return getattr(self.terminal, attr)
142 |
143 | def write(self, buf):
144 | temp_linebuf = self.linebuf + buf
145 | self.linebuf = ""
146 | for line in temp_linebuf.splitlines(True):
147 | # From the io.TextIOWrapper docs:
148 | # On output, if newline is None, any '\n' characters written
149 | # are translated to the system default line separator.
150 | # By default sys.stdout.write() expects '\n' newlines and then
151 | # translates them so this is still cross platform.
152 | if line[-1] == "\n":
153 | self.logger.log(self.log_level, line.rstrip())
154 | else:
155 | self.linebuf += line
156 |
157 | def flush(self):
158 | if self.linebuf != "":
159 | self.logger.log(self.log_level, self.linebuf.rstrip())
160 | self.linebuf = ""
161 |
162 |
163 | def disable_torch_init():
164 | """
165 | Disable the redundant torch default initialization to accelerate model creation.
166 | """
167 | import torch
168 |
169 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
170 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
171 |
172 |
173 | def violates_moderation(text):
174 | """
175 | Check whether the text violates OpenAI moderation API.
176 | """
177 | url = "https://api.openai.com/v1/moderations"
178 | headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
179 | text = text.replace("\n", "")
180 | data = "{" + '"input": ' + f'"{text}"' + "}"
181 | data = data.encode("utf-8")
182 | try:
183 | ret = requests.post(url, headers=headers, data=data, timeout=5)
184 | flagged = ret.json()["results"][0]["flagged"]
185 | except requests.exceptions.RequestException as e:
186 | print(f"######################### Moderation Error: {e} #########################")
187 | flagged = False
188 | except KeyError as e:
189 | print(f"######################### Moderation Error: {e} #########################")
190 | flagged = False
191 |
192 | return flagged
193 |
194 |
195 | def pretty_print_semaphore(semaphore):
196 | if semaphore is None:
197 | return "None"
198 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
199 |
--------------------------------------------------------------------------------
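Note: a minimal sketch (not part of the repository) of calling `process_video_with_decord` above; `data_args` is a hypothetical namespace and "sample.mp4" a placeholder path. Frames are sampled at roughly `video_fps` per second and uniformly re-sampled if they exceed `frames_upbound`:

    from types import SimpleNamespace
    from llava.utils import process_video_with_decord

    data_args = SimpleNamespace(video_fps=1, frames_upbound=32, force_sample=False)
    video, video_time, frame_time, num_frames = process_video_with_decord("sample.mp4", data_args)
    print(video.shape, video_time, num_frames)   # (num_frames, H, W, 3), duration in seconds, frame count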
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 240
3 |
4 | [build-system]
5 | requires = ["setuptools>=61.0"]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [project]
9 | name = "proxyv"
10 | version = "1.0.0"
11 | readme = "README.md"
12 | requires-python = ">=3.8"
13 | classifiers = [
14 | "Programming Language :: Python :: 3",
15 | "License :: OSI Approved :: Apache Software License",
16 | ]
17 |
18 | [project.optional-dependencies]
19 | standalone = [
20 | "shortuuid",
21 | "httpx==0.24.0",
22 | "einops",
23 | "ftfy",
24 | ]
25 |
26 |
27 | train = [
28 | "proxyv[standalone]",
29 | "numpy==1.26.4",
30 | "open_clip_torch",
31 | "fastapi",
32 | "markdown2[all]",
33 | "requests",
34 | "sentencepiece",
35 | "torch==2.1.2",
36 | "torchvision==0.16.2",
37 | "uvicorn==0.32.0",
38 | "wandb",
39 | "deepspeed==0.14.2",
40 | "peft==0.4.0",
41 | "accelerate==0.34.2",
42 | "tokenizers==0.19.1",
43 | "transformers==4.41.2",
44 | "bitsandbytes==0.41.0",
45 | "scikit-learn==1.2.2",
46 | "sentencepiece~=0.1.99",
47 | "einops==0.6.1",
48 | "einops-exts==0.0.4",
49 | "gradio_client==0.2.9",
50 | "urllib3<=2.0.0",
51 | "datasets==2.16.1",
52 | "pydantic==2.10.6",
53 | "timm",
54 | "triton==2.1.0",
55 | "hf_transfer",
56 | "opencv-python",
57 | "av",
58 | "decord",
59 | "tyro",
60 | "scipy",
61 | ]
62 |
63 | [tool.setuptools.packages.find]
64 | include = ["llava*"]
65 | exclude = [
66 | "assets*",
67 | "benchmark*",
68 | "docs",
69 | "dist*",
70 | "playground*",
71 | "scripts*",
72 | "train_scripts",
73 | "tests*",
74 | "checkpoints*",
75 | "project_checkpoints*",
76 | "debug_checkpoints*",
77 | "mlx_configs*",
78 | "wandb*",
79 | "notebooks*",
80 | ]
81 |
82 | [tool.wheel]
83 | exclude = [
84 | "assets*",
85 | "benchmark*",
86 | "docs",
87 | "dist*",
88 | "playground*",
89 | "scripts*",
90 | "train_scripts",
91 | "tests*",
92 | "checkpoints*",
93 | "project_checkpoints*",
94 | "debug_checkpoints*",
95 | "mlx_configs*",
96 | "wandb*",
97 | "notebooks*",
98 | ]
99 |
--------------------------------------------------------------------------------
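Note: the `train` extra above pulls in the `standalone` extra plus the pinned training stack (torch 2.1.2, transformers 4.41.2, deepspeed 0.14.2, etc.); from a checkout of the repository it would typically be installed in editable mode with `pip install -e ".[train]"` (assuming a recent pip).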
/scripts/finetune/finetune_vicuna7b_baseline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="lmsys/vicuna-7b-v1.5"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | RANK=${RANK:-0} # srun env node rank
24 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
25 | MASTER_PORT=${MASTER_PORT:-10086}
26 |
27 | PROMPT_VERSION="v1"
28 |
29 | BASE_RUN_NAME="proxyv_vicuna7b_pretrain_baseline"
30 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
31 |
32 | RUN_NAME="proxyv_vicuna7b_finetune_baseline"
33 | echo "RUN_NAME: ${RUN_NAME}"
34 |
35 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
36 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
37 |
38 | LOG_DIR=checkpoints/${RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # 1 = 8
43 | PER_DEVICE_BATCH_SIZE=4
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --pretrain_mm_mlp_adapter="./checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin" \
55 | --mm_tunable_parts="mm_mlp_adapter,mm_language_model" \
56 | --vision_tower ${VISION_MODEL_VERSION} \
57 | --mm_projector_type mlp2x_gelu \
58 | --mm_vision_select_layer -2 \
59 | --mm_use_im_start_end False \
60 | --mm_use_im_patch_token False \
61 | --group_by_modality_length True \
62 | --max_num_image_crops 5 \
63 | --per_crop_token_len 576 \
64 | --proxyv False \
65 | --image_aspect_ratio anyres \
66 | --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \
67 | --bf16 True \
68 | --run_name $RUN_NAME \
69 | --output_dir "./checkpoints/${RUN_NAME}" \
70 | --num_train_epochs 1 \
71 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
72 | --per_device_eval_batch_size 4 \
73 | --gradient_accumulation_steps $GRADIENT_ACC \
74 | --evaluation_strategy "no" \
75 | --save_strategy "steps" \
76 | --save_steps 500 \
77 | --save_total_limit 1 \
78 | --learning_rate 4e-5 \
79 | --weight_decay 0. \
80 | --warmup_ratio 0.03 \
81 | --lr_scheduler_type "cosine" \
82 | --logging_steps 1 \
83 | --tf32 True \
84 | --model_max_length 4096 \
85 | --gradient_checkpointing True \
86 | --dataloader_num_workers 16 \
87 | --lazy_preprocess True \
88 | --report_to wandb \
89 | --run_name $RUN_NAME \
90 | --torch_compile True \
91 | --torch_compile_backend "inductor" \
92 | --dataloader_drop_last True \
93 | --attn_implementation sdpa \
94 | 2>&1 | tee ${LOG_FILE}
--------------------------------------------------------------------------------
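Note: with the values above (BATCH_SIZE=512, PER_DEVICE_BATCH_SIZE=4, NUM_GPUS=8, WORLD_SIZE=1), GRADIENT_ACC evaluates to 512 / (4 * 8 * 1) = 16, so each optimizer step accumulates 16 micro-batches per GPU to reach the 512-sample global batch; the same arithmetic applies to the other training scripts below.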
/scripts/finetune/finetune_vicuna7b_proxyv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="lmsys/vicuna-7b-v1.5"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | RANK=${RANK:-0} # srun env node rank
24 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
25 | MASTER_PORT=${MASTER_PORT:-10086}
26 |
27 | PROMPT_VERSION="v1"
28 |
29 | BASE_RUN_NAME="proxyv_vicuna7b_pretrain_proxyv_layer12_lr3e4"
30 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
31 |
32 | RUN_NAME="proxyv_vicuna7b_finetune_proxyv_layer12"
33 | echo "RUN_NAME: ${RUN_NAME}"
34 |
35 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
36 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
37 |
38 | LOG_DIR=checkpoints/${RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # 1 = 8
43 | PER_DEVICE_BATCH_SIZE=4
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --pretrain_mm_mlp_adapter="./checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin" \
55 | --mm_tunable_parts="mm_mlp_adapter,mm_language_model,mm_vision_mlp" \
56 | --vision_tower ${VISION_MODEL_VERSION} \
57 | --mm_projector_type mlp2x_gelu \
58 | --mm_vision_select_layer -2 \
59 | --mm_use_im_start_end False \
60 | --mm_use_im_patch_token False \
61 | --group_by_modality_length True \
62 | --max_num_image_crops 5 \
63 | --per_crop_token_len 576 \
64 | --proxyv_reduce_factor 4 \
65 | --proxyv True \
66 | --proxyv_start_layer 12 \
67 | --image_aspect_ratio anyres \
68 | --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \
69 | --mm_patch_merge_type spatial_unpad \
70 | --bf16 True \
71 | --run_name $RUN_NAME \
72 | --output_dir "./checkpoints/${RUN_NAME}" \
73 | --num_train_epochs 1 \
74 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
75 | --per_device_eval_batch_size 4 \
76 | --gradient_accumulation_steps $GRADIENT_ACC \
77 | --evaluation_strategy "no" \
78 | --save_strategy "steps" \
79 | --save_steps 100 \
80 | --save_total_limit 1 \
81 | --learning_rate 4e-5 \
82 | --weight_decay 0. \
83 | --warmup_ratio 0.03 \
84 | --lr_scheduler_type "cosine" \
85 | --logging_steps 1 \
86 | --tf32 True \
87 | --model_max_length 4096 \
88 | --gradient_checkpointing True \
89 | --dataloader_num_workers 16 \
90 | --lazy_preprocess True \
91 | --report_to wandb \
92 | --run_name $RUN_NAME \
93 | --torch_compile True \
94 | --torch_compile_backend "inductor" \
95 | --dataloader_drop_last True \
96 | --attn_implementation sdpa \
97 | 2>&1 | tee ${LOG_FILE}
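
The finetune stage derives gradient_accumulation_steps from a fixed global batch size and expects the projector written by the matching pretrain run (BASE_RUN_NAME must point at an existing mm_projector.bin). A small pre-flight sketch with the values copied from finetune_vicuna7b_proxyv.sh above; adjust them if you edit the script:

    # values copied from the script above
    BASE_RUN_NAME="proxyv_vicuna7b_pretrain_proxyv_layer12_lr3e4"
    PROJECTOR="./checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin"
    [ -f "${PROJECTOR}" ] || echo "missing ${PROJECTOR} -- run pretrain_vicuna7b_proxyv.sh first"

    BATCH_SIZE=512; PER_DEVICE_BATCH_SIZE=4; NUM_GPUS=8; WORLD_SIZE=1
    echo "gradient_accumulation_steps = $((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))"   # 16 on one 8-GPU node
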
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_llama8b_baseline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="meta-llama/Meta-Llama-3-8B-Instruct"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=llava_llama_3
34 |
35 | BASE_RUN_NAME="proxyv_llama3_8b_pretrain_baseline"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=8
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter" \
55 | --mm_vision_select_layer -2 \
56 | --max_num_image_crops 1 \
57 | --per_crop_token_len 576 \
58 | --proxyv False \
59 | --mm_projector_type mlp2x_gelu \
60 | --mm_use_im_start_end False \
61 | --mm_use_im_patch_token False \
62 | --image_aspect_ratio pad \
63 | --bf16 True \
64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
65 | --num_train_epochs 1 \
66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
67 | --per_device_eval_batch_size 4 \
68 | --gradient_accumulation_steps $GRADIENT_ACC \
69 | --evaluation_strategy "no" \
70 | --save_strategy "steps" \
71 | --save_total_limit 1 \
72 | --save_steps 1000 \
73 | --learning_rate 1e-3 \
74 | --weight_decay 0. \
75 | --warmup_ratio 0.03 \
76 | --lr_scheduler_type "cosine" \
77 | --logging_steps 1 \
78 | --tf32 True \
79 | --model_max_length 2048 \
80 | --gradient_checkpointing True \
81 | --dataloader_num_workers 4 \
82 | --lazy_preprocess True \
83 | --report_to wandb \
84 | --run_name $BASE_RUN_NAME \
85 | --attn_implementation sdpa \
86 | 2>&1 | tee ${LOG_FILE}
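
All pretrain scripts read RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT from the environment (e.g. as set by srun) and fall back to single-node defaults. A hedged sketch of a manual two-node launch; the master hostname is a placeholder:

    # on both nodes (placeholder address shown)
    export WORLD_SIZE=2                   # number of nodes, passed to torchrun --nnodes
    export MASTER_ADDR=node0.example.com  # reachable address of the rank-0 node
    export MASTER_PORT=10086
    export RANK=0                         # 0 on the master node, 1 on the other node
    bash scripts/pretrain/pretrain_llama8b_baseline.sh
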
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_llama8b_proxyv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="meta-llama/Meta-Llama-3-8B-Instruct"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=llava_llama_3
34 |
35 | BASE_RUN_NAME="proxyv_llama3_8b_pretrain_proxyv_layer16_lr3e4"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=8
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \
55 | --mm_vision_mlp_lr 3e-4 \
56 | --mm_vision_select_layer -2 \
57 | --max_num_image_crops 1 \
58 | --per_crop_token_len 576 \
59 | --proxyv_reduce_factor 4 \
60 | --proxyv True \
61 | --proxyv_start_layer 16 \
62 | --mm_projector_type mlp2x_gelu \
63 | --mm_use_im_start_end False \
64 | --mm_use_im_patch_token False \
65 | --image_aspect_ratio pad \
66 | --bf16 True \
67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
68 | --num_train_epochs 1 \
69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
70 | --per_device_eval_batch_size 4 \
71 | --gradient_accumulation_steps $GRADIENT_ACC \
72 | --evaluation_strategy "no" \
73 | --save_strategy "steps" \
74 | --save_total_limit 1 \
75 | --save_steps 1000 \
76 | --learning_rate 1e-3 \
77 | --weight_decay 0. \
78 | --warmup_ratio 0.03 \
79 | --lr_scheduler_type "cosine" \
80 | --logging_steps 1 \
81 | --tf32 True \
82 | --model_max_length 2048 \
83 | --gradient_checkpointing True \
84 | --dataloader_num_workers 4 \
85 | --lazy_preprocess True \
86 | --report_to wandb \
87 | --run_name $BASE_RUN_NAME \
88 | --attn_implementation sdpa \
89 | 2>&1 | tee ${LOG_FILE}
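
The ProxyV pretrain variants differ from their baselines in --mm_tunable_parts (adding mm_vision_mlp), --mm_vision_mlp_lr, and the --proxyv* flags, and the run names encode the chosen start layer and vision-MLP learning rate. A hypothetical helper for keeping the name consistent with those two knobs, following nothing more than the naming convention used in these scripts:

    # hypothetical helper: rebuild BASE_RUN_NAME from the two ProxyV knobs
    PROXYV_START_LAYER=16
    MM_VISION_MLP_LR=3e-4
    LR_TAG="${MM_VISION_MLP_LR/-/}"   # 3e-4 -> 3e4
    BASE_RUN_NAME="proxyv_llama3_8b_pretrain_proxyv_layer${PROXYV_START_LAYER}_lr${LR_TAG}"
    echo "${BASE_RUN_NAME}"           # proxyv_llama3_8b_pretrain_proxyv_layer16_lr3e4
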
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_phi3b_baseline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="microsoft/Phi-3-mini-4k-instruct"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=phi3_instruct
34 |
35 | BASE_RUN_NAME="proxyv_phi3_3b_pretrain_baseline"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=16
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter" \
55 | --mm_vision_select_layer -2 \
56 | --max_num_image_crops 1 \
57 | --per_crop_token_len 576 \
58 | --proxyv False \
59 | --mm_projector_type mlp2x_gelu \
60 | --mm_use_im_start_end False \
61 | --mm_use_im_patch_token False \
62 | --image_aspect_ratio pad \
63 | --bf16 True \
64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
65 | --num_train_epochs 1 \
66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
67 | --per_device_eval_batch_size 4 \
68 | --gradient_accumulation_steps $GRADIENT_ACC \
69 | --evaluation_strategy "no" \
70 | --save_strategy "steps" \
71 | --save_total_limit 1 \
72 | --save_steps 1000 \
73 | --learning_rate 1e-3 \
74 | --weight_decay 0. \
75 | --warmup_ratio 0.03 \
76 | --lr_scheduler_type "cosine" \
77 | --logging_steps 1 \
78 | --tf32 True \
79 | --model_max_length 2048 \
80 | --gradient_checkpointing True \
81 | --dataloader_num_workers 4 \
82 | --lazy_preprocess True \
83 | --report_to wandb \
84 | --run_name $BASE_RUN_NAME \
85 | --attn_implementation sdpa \
86 | 2>&1 | tee ${LOG_FILE}
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_phi3b_proxyv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="microsoft/Phi-3-mini-4k-instruct"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=phi3_instruct
34 |
35 | BASE_RUN_NAME="proxyv_phi3_3b_pretrain_proxyv_layer16_lr5e4"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=16
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \
55 | --mm_vision_mlp_lr 5e-4 \
56 | --mm_vision_select_layer -2 \
57 | --max_num_image_crops 1 \
58 | --per_crop_token_len 576 \
59 | --proxyv_reduce_factor 4 \
60 | --proxyv True \
61 | --proxyv_start_layer 16 \
62 | --mm_projector_type mlp2x_gelu \
63 | --mm_use_im_start_end False \
64 | --mm_use_im_patch_token False \
65 | --image_aspect_ratio pad \
66 | --bf16 True \
67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
68 | --num_train_epochs 1 \
69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
70 | --per_device_eval_batch_size 4 \
71 | --gradient_accumulation_steps $GRADIENT_ACC \
72 | --evaluation_strategy "no" \
73 | --save_strategy "steps" \
74 | --save_total_limit 1 \
75 | --save_steps 1000 \
76 | --learning_rate 1e-3 \
77 | --weight_decay 0. \
78 | --warmup_ratio 0.03 \
79 | --lr_scheduler_type "cosine" \
80 | --logging_steps 1 \
81 | --tf32 True \
82 | --model_max_length 2048 \
83 | --gradient_checkpointing True \
84 | --dataloader_num_workers 4 \
85 | --lazy_preprocess True \
86 | --report_to wandb \
87 | --run_name $BASE_RUN_NAME \
88 | --attn_implementation sdpa \
89 | 2>&1 | tee ${LOG_FILE}
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_qwen7b_baseline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="Qwen/Qwen2-7B-Instruct"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=qwen_1_5
34 |
35 | BASE_RUN_NAME="proxyv_qwen2_7b_pretrain_baseline"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=16
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter" \
55 | --mm_vision_select_layer -2 \
56 | --max_num_image_crops 1 \
57 | --per_crop_token_len 576 \
58 | --proxyv False \
59 | --mm_projector_type mlp2x_gelu \
60 | --mm_use_im_start_end False \
61 | --mm_use_im_patch_token False \
62 | --image_aspect_ratio pad \
63 | --bf16 True \
64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
65 | --num_train_epochs 1 \
66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
67 | --per_device_eval_batch_size 4 \
68 | --gradient_accumulation_steps $GRADIENT_ACC \
69 | --evaluation_strategy "no" \
70 | --save_strategy "steps" \
71 | --save_total_limit 1 \
72 | --save_steps 1000 \
73 | --learning_rate 1e-3 \
74 | --weight_decay 0. \
75 | --warmup_ratio 0.03 \
76 | --lr_scheduler_type "cosine" \
77 | --logging_steps 1 \
78 | --tf32 True \
79 | --model_max_length 2048 \
80 | --gradient_checkpointing True \
81 | --dataloader_num_workers 4 \
82 | --lazy_preprocess True \
83 | --report_to wandb \
84 | --run_name $BASE_RUN_NAME \
85 | --attn_implementation sdpa \
86 | 2>&1 | tee ${LOG_FILE}
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_qwen7b_proxyv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="Qwen/Qwen2-7B-Instruct"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=qwen_1_5
34 |
35 | BASE_RUN_NAME="proxyv_qwen2_7b_pretrain_proxyv_layer16_lr8e4"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=16
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \
55 | --mm_vision_mlp_lr 8e-4 \
56 | --mm_vision_select_layer -2 \
57 | --max_num_image_crops 1 \
58 | --per_crop_token_len 576 \
59 | --proxyv_reduce_factor 4 \
60 | --proxyv True \
61 | --proxyv_start_layer 16 \
62 | --mm_projector_type mlp2x_gelu \
63 | --mm_use_im_start_end False \
64 | --mm_use_im_patch_token False \
65 | --image_aspect_ratio pad \
66 | --bf16 True \
67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
68 | --num_train_epochs 1 \
69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
70 | --per_device_eval_batch_size 4 \
71 | --gradient_accumulation_steps $GRADIENT_ACC \
72 | --evaluation_strategy "no" \
73 | --save_strategy "steps" \
74 | --save_total_limit 1 \
75 | --save_steps 1000 \
76 | --learning_rate 1e-3 \
77 | --weight_decay 0. \
78 | --warmup_ratio 0.03 \
79 | --lr_scheduler_type "cosine" \
80 | --logging_steps 1 \
81 | --tf32 True \
82 | --model_max_length 2048 \
83 | --gradient_checkpointing True \
84 | --dataloader_num_workers 4 \
85 | --lazy_preprocess True \
86 | --report_to wandb \
87 | --run_name $BASE_RUN_NAME \
88 | --attn_implementation sdpa \
89 | 2>&1 | tee ${LOG_FILE}
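
Since each baseline/ProxyV pair is otherwise identical, a plain diff shows exactly which flags the ProxyV run adds, for example:

    diff scripts/pretrain/pretrain_qwen7b_baseline.sh scripts/pretrain/pretrain_qwen7b_proxyv.sh
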
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_vicuna13b_baselin.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="lmsys/vicuna-13b-v1.5"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=v1
34 |
35 | BASE_RUN_NAME="proxyv_vicuna13b_pretrain_baseline"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=8
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter" \
55 | --mm_vision_select_layer -2 \
56 | --max_num_image_crops 1 \
57 | --per_crop_token_len 576 \
58 | --proxyv False \
59 | --mm_projector_type mlp2x_gelu \
60 | --mm_use_im_start_end False \
61 | --mm_use_im_patch_token False \
62 | --image_aspect_ratio pad \
63 | --bf16 True \
64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
65 | --num_train_epochs 1 \
66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
67 | --per_device_eval_batch_size 4 \
68 | --gradient_accumulation_steps $GRADIENT_ACC \
69 | --evaluation_strategy "no" \
70 | --save_strategy "steps" \
71 | --save_total_limit 1 \
72 | --save_steps 1000 \
73 | --learning_rate 1e-3 \
74 | --weight_decay 0. \
75 | --warmup_ratio 0.03 \
76 | --lr_scheduler_type "cosine" \
77 | --logging_steps 1 \
78 | --tf32 True \
79 | --model_max_length 2048 \
80 | --gradient_checkpointing True \
81 | --dataloader_num_workers 4 \
82 | --lazy_preprocess True \
83 | --report_to wandb \
84 | --run_name $BASE_RUN_NAME \
85 | --attn_implementation sdpa \
86 | 2>&1 | tee ${LOG_FILE}
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_vicuna13b_proxyv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="lmsys/vicuna-13b-v1.5"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=v1
34 |
35 | BASE_RUN_NAME="proxyv_vicuna13b_pretrain_proxyv_layer16_lr3e4"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=8
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \
55 | --mm_vision_mlp_lr 3e-4 \
56 | --mm_vision_select_layer -2 \
57 | --max_num_image_crops 1 \
58 | --per_crop_token_len 576 \
59 | --proxyv_reduce_factor 4 \
60 | --proxyv True \
61 | --proxyv_start_layer 16 \
62 | --mm_projector_type mlp2x_gelu \
63 | --mm_use_im_start_end False \
64 | --mm_use_im_patch_token False \
65 | --image_aspect_ratio pad \
66 | --bf16 True \
67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
68 | --num_train_epochs 1 \
69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
70 | --per_device_eval_batch_size 4 \
71 | --gradient_accumulation_steps $GRADIENT_ACC \
72 | --evaluation_strategy "no" \
73 | --save_strategy "steps" \
74 | --save_total_limit 1 \
75 | --save_steps 1000 \
76 | --learning_rate 1e-3 \
77 | --weight_decay 0. \
78 | --warmup_ratio 0.03 \
79 | --lr_scheduler_type "cosine" \
80 | --logging_steps 1 \
81 | --tf32 True \
82 | --model_max_length 2048 \
83 | --gradient_checkpointing True \
84 | --dataloader_num_workers 4 \
85 | --lazy_preprocess True \
86 | --report_to wandb \
87 | --run_name $BASE_RUN_NAME \
88 | --attn_implementation sdpa \
89 | 2>&1 | tee ${LOG_FILE}
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_vicuna7b_baseline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="lmsys/vicuna-7b-v1.5"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=v1
34 |
35 | BASE_RUN_NAME="proxyv_vicuna7b_pretrain_baseline"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=16
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter" \
55 | --mm_vision_select_layer -2 \
56 | --max_num_image_crops 1 \
57 | --per_crop_token_len 576 \
58 | --proxyv False \
59 | --mm_projector_type mlp2x_gelu \
60 | --mm_use_im_start_end False \
61 | --mm_use_im_patch_token False \
62 | --image_aspect_ratio pad \
63 | --bf16 True \
64 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
65 | --num_train_epochs 1 \
66 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
67 | --per_device_eval_batch_size 4 \
68 | --gradient_accumulation_steps $GRADIENT_ACC \
69 | --evaluation_strategy "no" \
70 | --save_strategy "steps" \
71 | --save_total_limit 1 \
72 | --save_steps 1000 \
73 | --learning_rate 1e-3 \
74 | --weight_decay 0. \
75 | --warmup_ratio 0.03 \
76 | --lr_scheduler_type "cosine" \
77 | --logging_steps 1 \
78 | --tf32 True \
79 | --model_max_length 2048 \
80 | --gradient_checkpointing True \
81 | --dataloader_num_workers 4 \
82 | --lazy_preprocess True \
83 | --report_to wandb \
84 | --run_name $BASE_RUN_NAME \
85 | --attn_implementation sdpa \
86 | 2>&1 | tee ${LOG_FILE}
--------------------------------------------------------------------------------
/scripts/pretrain/pretrain_vicuna7b_proxyv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 | T=`date +%Y%m%d_%H%M%S`
4 | export OMP_NUM_THREADS=8
5 | export NCCL_IB_DISABLE=0
6 | export NCCL_IB_GID_INDEX=3
7 | export NCCL_SOCKET_IFNAME=eth0
8 | export NCCL_DEBUG=INFO
9 | export WANDB_RESUME="allow"
10 |
11 | export WANDB_API_KEY=""
12 | export WANDB_PROJECT="proxyv"
13 |
14 | LLM_VERSION="lmsys/vicuna-7b-v1.5"
15 | LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
16 | VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336"
17 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
18 |
19 | DATA_PATH="TODO"
20 | IMAGE_FOLDER="TODO"
21 |
22 | NUM_GPUS=8
23 | NNODES=1
24 | RANK=${RANK:-0} # srun env node rank
25 | MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
26 | MASTER_PORT=${MASTER_PORT:-10086}
27 |
28 | WORLD_SIZE=${WORLD_SIZE:-1} # srun env node num
29 | echo "nnodes=${WORLD_SIZE}, node_rank=${RANK}"
30 |
31 | ############### Pretrain ################
32 |
33 | PROMPT_VERSION=v1
34 |
35 | BASE_RUN_NAME="proxyv_vicuna7b_pretrain_proxyv_layer12_lr3e4"
36 | echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
37 |
38 | LOG_DIR=checkpoints/projectors/${BASE_RUN_NAME}/logs
39 | mkdir -p ${LOG_DIR}
40 | LOG_FILE=${LOG_DIR}/node${RANK}_${T}.log
41 |
42 | BATCH_SIZE=512 # global batch size
43 | PER_DEVICE_BATCH_SIZE=16
44 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / NUM_GPUS / WORLD_SIZE))
45 |
46 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${WORLD_SIZE}" --node_rank="${RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
47 | ./llava/train/train_mem.py \
48 | --deepspeed scripts/zero3.json \
49 | --model_name_or_path ${LLM_VERSION} \
50 | --version ${PROMPT_VERSION} \
51 | --data_path ${DATA_PATH} \
52 | --image_folder ${IMAGE_FOLDER} \
53 | --vision_tower ${VISION_MODEL_VERSION} \
54 | --mm_tunable_parts="mm_mlp_adapter,mm_vision_mlp" \
55 | --mm_vision_mlp_lr 3e-4 \
56 | --mm_vision_select_layer -2 \
57 | --max_num_image_crops 1 \
58 | --per_crop_token_len 576 \
59 | --proxyv_reduce_factor 4 \
60 | --proxyv True \
61 | --proxyv_start_layer 12 \
62 | --mm_projector_type mlp2x_gelu \
63 | --mm_use_im_start_end False \
64 | --mm_use_im_patch_token False \
65 | --image_aspect_ratio pad \
66 | --bf16 True \
67 | --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
68 | --num_train_epochs 1 \
69 | --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
70 | --per_device_eval_batch_size 4 \
71 | --gradient_accumulation_steps $GRADIENT_ACC \
72 | --evaluation_strategy "no" \
73 | --save_strategy "steps" \
74 | --save_total_limit 1 \
75 | --save_steps 1000 \
76 | --learning_rate 1e-3 \
77 | --weight_decay 0. \
78 | --warmup_ratio 0.03 \
79 | --lr_scheduler_type "cosine" \
80 | --logging_steps 1 \
81 | --tf32 True \
82 | --model_max_length 2048 \
83 | --gradient_checkpointing True \
84 | --dataloader_num_workers 4 \
85 | --lazy_preprocess True \
86 | --report_to wandb \
87 | --run_name $BASE_RUN_NAME \
88 | --attn_implementation sdpa \
89 | 2>&1 | tee ${LOG_FILE}
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 2,
24 | "offload_optimizer": {
25 | "device": "none",
26 | "pin_memory": true
27 | },
28 | "allgather_partitions": true,
29 | "allgather_bucket_size": 2e8,
30 | "overlap_comm": false,
31 | "reduce_scatter": true,
32 | "reduce_bucket_size": 2e8,
33 | "contiguous_gradients": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/scripts/zero2_fused_adamw.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 2,
24 | "offload_optimizer": {
25 | "device": "none",
26 | "pin_memory": true
27 | },
28 | "allgather_partitions": true,
29 | "allgather_bucket_size": 2e8,
30 | "overlap_comm": true,
31 | "reduce_scatter": true,
32 | "reduce_bucket_size": 2e8,
33 | "contiguous_gradients": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/scripts/zero2_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "offload_optimizer": {
19 | "device": "cpu",
20 | "pin_memory": true
21 | },
22 | "offload_param": {
23 | "device": "cpu",
24 | "pin_memory": true
25 | },
26 | "overlap_comm": true,
27 | "contiguous_gradients": true,
28 | "sub_group_size": 1e9,
29 | "reduce_bucket_size": "auto"
30 | }
31 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 3,
16 | "offload_optimizer": {
17 | "device": "none",
18 | "pin_memory": true
19 | },
20 | "offload_param": {
21 | "device": "none",
22 | "pin_memory": true
23 | },
24 | "overlap_comm": true,
25 | "contiguous_gradients": true,
26 | "sub_group_size": 1e9,
27 | "reduce_bucket_size": "auto",
28 | "stage3_prefetch_bucket_size": "auto",
29 | "stage3_param_persistence_threshold": "auto",
30 | "stage3_max_live_parameters": 1e9,
31 | "stage3_max_reuse_distance": 1e9,
32 | "stage3_gather_16bit_weights_on_model_save": true
33 | },
34 |
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
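
All training scripts above pass --deepspeed scripts/zero3.json; the "auto" values in these configs are resolved by the HuggingFace Trainer's DeepSpeed integration from the matching training arguments (learning rate, batch sizes, gradient accumulation, bf16/fp16, gradient clipping), and zero3_offload.json is the CPU-offload counterpart for memory-constrained runs. A small sketch for sanity-checking the configs before a long run (assumes jq is installed):

    for f in scripts/zero*.json; do jq empty "$f" && echo "valid: $f"; done
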
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 3,
24 | "offload_optimizer": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "offload_param": {
29 | "device": "cpu",
30 | "pin_memory": true
31 | },
32 | "overlap_comm": true,
33 | "contiguous_gradients": true,
34 | "sub_group_size": 1e9,
35 | "reduce_bucket_size": "auto",
36 | "stage3_prefetch_bucket_size": "auto",
37 | "stage3_param_persistence_threshold": "auto",
38 | "stage3_max_live_parameters": 1e9,
39 | "stage3_max_reuse_distance": 1e9,
40 | "gather_16bit_weights_on_model_save": true
41 | },
42 | "gradient_accumulation_steps": "auto",
43 | "gradient_clipping": "auto",
44 | "train_batch_size": "auto",
45 | "train_micro_batch_size_per_gpu": "auto",
46 | "steps_per_print": 1e5,
47 | "wall_clock_breakdown": false
48 | }
--------------------------------------------------------------------------------
/scripts/zero3pp.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 |
23 | "zero_optimization": {
24 | "stage": 3,
25 | "offload_optimizer": {
26 | "device": "none",
27 | "pin_memory": true
28 | },
29 | "offload_param": {
30 | "device": "none",
31 | "pin_memory": true
32 | },
33 | "overlap_comm": true,
34 | "contiguous_gradients": true,
35 | "zero_quantized_weights": true,
36 | "zero_hpz_partition_size": 16,
37 | "zero_quantized_gradients": true,
38 | "sub_group_size": 1e9,
39 | "reduce_bucket_size": "auto",
40 | "stage3_prefetch_bucket_size": "auto",
41 | "stage3_param_persistence_threshold": "auto",
42 | "stage3_max_live_parameters": 1e9,
43 | "stage3_max_reuse_distance": 1e9,
44 | "stage3_gather_16bit_weights_on_model_save": true
45 | },
46 |
47 | "gradient_accumulation_steps": "auto",
48 | "gradient_clipping": "auto",
49 | "steps_per_print": 100,
50 | "train_batch_size": "auto",
51 | "train_micro_batch_size_per_gpu": "auto",
52 | "wall_clock_breakdown": false
53 | }
--------------------------------------------------------------------------------