├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_gemma.py │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ ├── llava_mixtral.py │ │ ├── llava_mpt.py │ │ ├── llava_qwen.py │ │ ├── llava_qwen_moe.py │ │ └── modeling_llama.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── dev_eva_clip │ │ │ ├── eva_clip │ │ │ │ ├── __init__.py │ │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ │ ├── constants.py │ │ │ │ ├── eva_vit_model.py │ │ │ │ ├── factory.py │ │ │ │ ├── hf_configs.py │ │ │ │ ├── hf_model.py │ │ │ │ ├── loss.py │ │ │ │ ├── model.py │ │ │ │ ├── model_configs │ │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ │ │ ├── modified_resnet.py │ │ │ │ ├── openai.py │ │ │ │ ├── pretrained.py │ │ │ │ ├── rope.py │ │ │ │ ├── timm_model.py │ │ │ │ ├── tokenizer.py │ │ │ │ ├── transform.py │ │ │ │ ├── transformer.py │ │ │ │ └── utils.py │ │ │ └── eva_vit.py │ │ ├── eva_clip │ │ │ ├── eva_clip_encoder.py │ │ │ ├── eva_clip_processors.py │ │ │ ├── eva_vit.py │ │ │ ├── factory.py │ │ │ └── model_configs │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ ├── hf_vision.py │ │ ├── imagebind.py │ │ ├── open_clip_encoder.py │ │ └── siglip_encoder.py │ ├── multimodal_projector │ │ ├── builder.py │ │ └── pooler_projector.py │ ├── multimodal_resampler │ │ ├── builder.py │ │ ├── masked_drop.py │ │ ├── perceiver.py │ │ ├── qformer.py │ │ └── spatial_pool.py │ └── utils.py └── utils.py ├── nodes.py └── requirements.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | # *.json 11 | # *.jsonl 12 | 13 | # Data 14 | !**/alpaca-data-conversation.json 15 | # Editor 16 | .idea 17 | *.swp 18 | .vscode 19 | 20 | # Other 21 | .DS_Store 22 | wandb 23 | output 24 | llavavid 25 | 26 | checkpoints 27 | project_checkpoints 28 | debug_checkpoints 29 | playground/data 30 | playground/cc3m_llava34b_cap 31 | ckpts* 32 | 33 | .ipynb_checkpoints 34 | chunyl_scripts 35 | *.ipynb 36 | 37 | # DevContainer 38 | !.devcontainer/* 39 | 40 | # Demo 41 | serve_images/ 42 | notebooks/ 43 | logs 44 | 
scripts/dist_* 45 | logs/ 46 | submissions/ 47 | cn_scripts/ 48 | internal_project_checkpoints/ 49 | work_dirs 50 | scripts/i18n/* 51 | playground/.nfs028b000000010add00000001 52 | HIP 53 | playground/.nfs028b0000017bff2c00000012 54 | scripts/qwen 55 | scripts/vicuna 56 | scripts/mistral 57 | scripts/baseline_rep 58 | scripts/cn_boli01_hl 59 | scripts/cn_boli01_lf 60 | scripts/cn_lf 61 | scripts/cn_lq 62 | scripts/cn_yg 63 | scripts/cn_yg_hao 64 | scripts/eva_encoder 65 | scripts/i18n 66 | scripts/i18n_higher_res 67 | scripts/multi-images 68 | scratchpad 69 | build/ 70 | playground/*.json 71 | mlx_configs/ 72 | data_processing/ 73 | # demo/ 74 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WORK IN PROGRESS 2 | 3 | Unsure of the exact dependencies; the original repo lists a huge number of them, but I didn't have to install a single new one in my environment and it worked.
4 | 5 | ![image](https://github.com/user-attachments/assets/fcf83b77-6a80-47e7-aa29-aeefb0431beb) 6 | 7 | Original repo: 8 | https://github.com/LLaVA-VL/LLaVA-NeXT 9 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | AVAILABLE_MODELS = { 4 | "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig", 5 | "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig", 6 | "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig", 7 | "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig", 8 | "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig", 9 | # Add other models as needed 10 | } 11 | 12 | for model_name, model_classes in AVAILABLE_MODELS.items(): 13 | try: 14 | exec(f"from .language_model.{model_name} import {model_classes}") 15 | except Exception as e: 16 | print(f"Failed to import {model_name} from llava.language_model.{model_name}.
Error: {e}") 17 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from ...llava import LlavaLlamaForCausalLM 12 | 13 | 14 | def apply_delta(base_model_path, target_model_path, delta_path): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 33 | 34 | print("Saving target model") 35 | delta.save_pretrained(target_model_path) 36 | delta_tokenizer.save_pretrained(target_model_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--base-model-path", type=str, required=True) 42 | parser.add_argument("--target-model-path", type=str, required=True) 43 | parser.add_argument("--delta-path", type=str, required=True) 44 | 45 | args = parser.parse_args() 46 | 47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 48 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from ..model import * 11 | from ..model.utils import auto_upgrade 12 | 13 | 14 | def consolidate_ckpt(src_path, dst_path): 15 | print("Loading model") 16 | auto_upgrade(src_path) 17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 19 | src_model.save_pretrained(dst_path) 20 | src_tokenizer.save_pretrained(dst_path) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--src", type=str, required=True) 26 | parser.add_argument("--dst", type=str, required=True) 27 | 28 | args = parser.parse_args() 29 | 30 | consolidate_ckpt(args.src, args.dst) 31 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_gemma.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaGemmaConfig(GemmaConfig): 31 | model_type = "llava_gemma" 32 | 33 | 34 | class LlavaGemmaModel(LlavaMetaModel, GemmaModel): 35 | config_class = LlavaGemmaConfig 36 | 37 | def __init__(self, config: GemmaConfig): 38 | super(LlavaGemmaModel, self).__init__(config) 39 | 40 | 41 | class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaGemmaConfig 43 | 44 | def __init__(self, config): 45 | super(GemmaForCausalLM, self).__init__(config) 46 | self.model = LlavaGemmaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | past_key_values: Optional[List[torch.FloatTensor]] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | labels: Optional[torch.LongTensor] = None, 64 | use_cache: Optional[bool] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | images: Optional[torch.FloatTensor] = None, 68 | image_sizes: Optional[List[List[int]]] = None, 69 | return_dict: Optional[bool] = None, 70 | cache_position: Optional[torch.LongTensor] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) 75 | 76 | return super().forward( 77 | input_ids=input_ids, 78 | attention_mask=attention_mask, 79 | position_ids=position_ids, 80 | past_key_values=past_key_values, 81 | inputs_embeds=inputs_embeds, 82 | labels=labels, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict, 87 | cache_position=cache_position, 88 | ) 89 | 90 | @torch.no_grad() 91 | def generate( 92 | self, 93 | inputs: Optional[torch.Tensor] = None, 94 | images: Optional[torch.Tensor] = None, 95 | image_sizes: Optional[torch.Tensor] = 
None, 96 | **kwargs, 97 | ) -> Union[GenerateOutput, torch.LongTensor]: 98 | position_ids = kwargs.pop("position_ids", None) 99 | attention_mask = kwargs.pop("attention_mask", None) 100 | if "inputs_embeds" in kwargs: 101 | raise NotImplementedError("`inputs_embeds` is not supported") 102 | 103 | if images is not None: 104 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) 105 | else: 106 | inputs_embeds = self.get_model().embed_tokens(inputs) 107 | 108 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 109 | 110 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 111 | images = kwargs.pop("images", None) 112 | image_sizes = kwargs.pop("image_sizes", None) 113 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 114 | if images is not None: 115 | inputs["images"] = images 116 | if image_sizes is not None: 117 | inputs["image_sizes"] = image_sizes 118 | return inputs 119 | 120 | 121 | AutoConfig.register("llava_gemma", LlavaGemmaConfig) 122 | AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM) 123 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
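# NOTE: every Llava*ForCausalLM wrapper in language_model/ follows the same recipe:
#   * subclass the stock Hugging Face backbone (LlamaForCausalLM here) together with LlavaMetaForCausalLM,
#   * let prepare_inputs_labels_for_multimodal() from llava/model/llava_arch.py replace the IMAGE_TOKEN_INDEX
#     placeholders in input_ids with projected vision-tower features before the backbone runs,
#   * and thread the extra `images` / `image_sizes` kwargs through forward(), generate() and
#     prepare_inputs_for_generation().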
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig 22 | 23 | from torch.nn import CrossEntropyLoss 24 | 25 | 26 | # , LlamaModel, LlamaForCausalLM, GenerationConfig 27 | # from .modeling_llama import LlamaModel, LlamaForCausalLM 28 | from transformers import LlamaModel, LlamaForCausalLM 29 | from transformers.modeling_outputs import CausalLMOutputWithPast 30 | from transformers.generation.utils import GenerateOutput 31 | 32 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 33 | 34 | 35 | class LlavaConfig(LlamaConfig): 36 | model_type = "llava_llama" 37 | temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna 38 | max_new_tokens: int = 1024 39 | do_sample: bool = False 40 | top_p: Optional[float] = None 41 | # rope_scaling: Optional[dict] = {} 42 | 43 | 44 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 45 | config_class = LlavaConfig 46 | 47 | def __init__(self, config: LlamaConfig): 48 | super(LlavaLlamaModel, self).__init__(config) 49 | 50 | 51 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 52 | config_class = LlavaConfig 53 | 54 | def __init__(self, config): 55 | LlamaForCausalLM.__init__(self, config) 56 | 57 | # configure default generation settings 58 | config.model_type = "llava_llama" 59 | # config.rope_scaling = None 60 | 61 | self.model = LlavaLlamaModel(config) 62 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 63 | # Initialize weights and apply final processing 64 | self.post_init() 65 | 66 | def get_model(self): 67 | return self.model 68 | 69 | def forward( 70 | self, 71 | input_ids: torch.LongTensor = None, 72 | attention_mask: Optional[torch.Tensor] = None, 73 | position_ids: Optional[torch.LongTensor] = None, 74 | past_key_values: Optional[List[torch.FloatTensor]] = None, 75 | inputs_embeds: Optional[torch.FloatTensor] = None, 76 | labels: Optional[torch.LongTensor] = None, 77 | use_cache: Optional[bool] = None, 78 | output_attentions: Optional[bool] = None, 79 | output_hidden_states: Optional[bool] = None, 80 | images: Optional[torch.FloatTensor] = None, 81 | image_sizes: Optional[List[List[int]]] = None, 82 | return_dict: Optional[bool] = None, 83 | modalities: Optional[List[str]] = ["image"], 84 | dpo_forward: Optional[bool] = None, 85 | cache_position=None, 86 | ) -> Union[Tuple, CausalLMOutputWithPast]: 87 | 88 | if inputs_embeds is None: 89 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) 90 | 91 | if dpo_forward: 92 | outputs = self.model( 93 | input_ids=input_ids, 94 | attention_mask=attention_mask, 95 | position_ids=position_ids, 96 | past_key_values=past_key_values, 97 | inputs_embeds=inputs_embeds, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict, 102 | ) 103 | 104 | hidden_states = outputs[0] 105 | logits = self.lm_head(hidden_states) 106 | return logits, labels 107 | 108 | else: 109 | return super().forward( 110 | input_ids=input_ids, 111 | attention_mask=attention_mask, 112 | position_ids=position_ids, 113 | past_key_values=past_key_values, 114 | inputs_embeds=inputs_embeds, 115 | labels=labels, 116 | use_cache=use_cache, 117 | output_attentions=output_attentions, 
118 | output_hidden_states=output_hidden_states, 119 | return_dict=return_dict, 120 | ) 121 | 122 | @torch.no_grad() 123 | def generate( 124 | self, 125 | inputs: Optional[torch.Tensor] = None, 126 | images: Optional[torch.Tensor] = None, 127 | image_sizes: Optional[torch.Tensor] = None, 128 | modalities: Optional[List[str]] = ["image"], 129 | **kwargs, 130 | ) -> Union[GenerateOutput, torch.LongTensor]: 131 | modalities = kwargs.pop("modalities", None) if "modalities" in kwargs and modalities is None else modalities 132 | position_ids = kwargs.pop("position_ids", None) 133 | attention_mask = kwargs.pop("attention_mask", None) 134 | if "inputs_embeds" in kwargs: 135 | raise NotImplementedError("`inputs_embeds` is not supported") 136 | 137 | if images is not None: 138 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 139 | else: 140 | inputs_embeds = self.get_model().embed_tokens(inputs) 141 | 142 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 145 | images = kwargs.pop("images", None) 146 | image_sizes = kwargs.pop("image_sizes", None) 147 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 148 | if images is not None: 149 | inputs["images"] = images 150 | if image_sizes is not None: 151 | inputs["image_sizes"] = image_sizes 152 | return inputs 153 | 154 | 155 | AutoConfig.register("llava_llama", LlavaConfig) 156 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 157 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMistralConfig(MistralConfig): 31 | model_type = "llava_mistral" 32 | temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna 33 | max_new_tokens: int = 1024 34 | do_sample: bool = False 35 | top_p: Optional[float] = None 36 | 37 | 38 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 39 | config_class = LlavaMistralConfig 40 | 41 | def __init__(self, config: MistralConfig): 42 | super(LlavaMistralModel, self).__init__(config) 43 | 44 | 45 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMistralConfig 47 | 48 | def __init__(self, config): 49 | super(MistralForCausalLM, self).__init__(config) 50 | 51 | config.model_type = "llava_mistral" 52 | config.rope_scaling = None 53 | 54 | self.model = LlavaMistralModel(config) 55 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 56 | # Initialize weights and apply final processing 57 | self.post_init() 58 | 59 | def get_model(self): 60 | return self.model 61 | 62 | def forward( 63 | self, 64 | input_ids: torch.LongTensor = None, 65 | attention_mask: Optional[torch.Tensor] = None, 66 | position_ids: Optional[torch.LongTensor] = None, 67 | past_key_values: Optional[List[torch.FloatTensor]] = None, 68 | inputs_embeds: Optional[torch.FloatTensor] = None, 69 | labels: Optional[torch.LongTensor] = None, 70 | use_cache: Optional[bool] = None, 71 | output_attentions: Optional[bool] = None, 72 | output_hidden_states: Optional[bool] = None, 73 | images: Optional[torch.FloatTensor] = None, 74 | image_sizes: Optional[List[List[int]]] = None, 75 | return_dict: Optional[bool] = None, 76 | cache_position=None, 77 | ) -> Union[Tuple, CausalLMOutputWithPast]: 78 | 79 | if inputs_embeds is None: 80 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) 81 | 82 | return super().forward( 83 | input_ids=input_ids, 84 | attention_mask=attention_mask, 85 | position_ids=position_ids, 86 | past_key_values=past_key_values, 87 | inputs_embeds=inputs_embeds, 88 | labels=labels, 89 | use_cache=use_cache, 90 | output_attentions=output_attentions, 91 | output_hidden_states=output_hidden_states, 92 | return_dict=return_dict, 93 | ) 94 | 95 | @torch.no_grad() 96 | def generate( 97 | self, 98 | inputs: Optional[torch.Tensor] = None, 99 | images: Optional[torch.Tensor] = None, 100 | image_sizes: Optional[torch.Tensor] = None, 101 | **kwargs, 102 | ) -> Union[GenerateOutput, torch.LongTensor]: 103 | position_ids = kwargs.pop("position_ids", None) 104 | attention_mask = kwargs.pop("attention_mask", None) 105 | if "inputs_embeds" in kwargs: 106 | raise NotImplementedError("`inputs_embeds` is not supported") 107 | 108 | if images is not None: 109 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, 
image_sizes=image_sizes) 110 | else: 111 | inputs_embeds = self.get_model().embed_tokens(inputs) 112 | 113 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 114 | 115 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 116 | images = kwargs.pop("images", None) 117 | image_sizes = kwargs.pop("image_sizes", None) 118 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 119 | if images is not None: 120 | inputs["images"] = images 121 | if image_sizes is not None: 122 | inputs["image_sizes"] = image_sizes 123 | return inputs 124 | 125 | 126 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 127 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 128 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mixtral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMixtralConfig(MixtralConfig): 31 | model_type = "llava_mixtral" 32 | 33 | 34 | class LlavaMixtralModel(LlavaMetaModel, MixtralModel): 35 | config_class = LlavaMixtralConfig 36 | 37 | def __init__(self, config: MixtralConfig): 38 | super(LlavaMixtralModel, self).__init__(config) 39 | 40 | 41 | class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaMixtralConfig 43 | 44 | def __init__(self, config): 45 | super(MixtralForCausalLM, self).__init__(config) 46 | 47 | config.model_type = "llava_mixtral" 48 | config.rope_scaling = None 49 | self.model = LlavaMixtralModel(config) 50 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: 
Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | modalities: Optional[List[str]] = ["image"], 72 | dpo_forward: Optional[bool] = None, 73 | cache_position=None, 74 | ) -> Union[Tuple, CausalLMOutputWithPast]: 75 | 76 | if inputs_embeds is None: 77 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) 78 | 79 | if dpo_forward: 80 | outputs = self.model( 81 | input_ids=input_ids, 82 | attention_mask=attention_mask, 83 | position_ids=position_ids, 84 | past_key_values=past_key_values, 85 | inputs_embeds=inputs_embeds, 86 | use_cache=use_cache, 87 | output_attentions=output_attentions, 88 | output_hidden_states=output_hidden_states, 89 | return_dict=return_dict, 90 | ) 91 | 92 | hidden_states = outputs[0] 93 | logits = self.lm_head(hidden_states) 94 | return logits, labels 95 | 96 | else: 97 | return super().forward( 98 | input_ids=input_ids, 99 | attention_mask=attention_mask, 100 | position_ids=position_ids, 101 | past_key_values=past_key_values, 102 | inputs_embeds=inputs_embeds, 103 | labels=labels, 104 | use_cache=use_cache, 105 | output_attentions=output_attentions, 106 | output_hidden_states=output_hidden_states, 107 | return_dict=return_dict, 108 | ) 109 | 110 | @torch.no_grad() 111 | def generate( 112 | self, 113 | inputs: Optional[torch.Tensor] = None, 114 | images: Optional[torch.Tensor] = None, 115 | image_sizes: Optional[torch.Tensor] = None, 116 | modalities: Optional[List[str]] = ["image"], 117 | **kwargs, 118 | ) -> Union[GenerateOutput, torch.LongTensor]: 119 | position_ids = kwargs.pop("position_ids", None) 120 | attention_mask = kwargs.pop("attention_mask", None) 121 | if "inputs_embeds" in kwargs: 122 | raise NotImplementedError("`inputs_embeds` is not supported") 123 | 124 | if images is not None: 125 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 126 | else: 127 | inputs_embeds = self.get_model().embed_tokens(inputs) 128 | 129 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 130 | 131 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 132 | images = kwargs.pop("images", None) 133 | image_sizes = kwargs.pop("image_sizes", None) 134 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 135 | if images is not None: 136 | inputs["images"] = images 137 | if image_sizes is not None: 138 | inputs["image_sizes"] = image_sizes 139 | return inputs 140 | 141 | 142 | AutoConfig.register("llava_mixtral", LlavaMixtralConfig) 143 | AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM) 144 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig 21 | from ...model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 22 | 23 | 24 | class LlavaMptConfig(MptConfig): 25 | model_type = "llava_mpt" 26 | 27 | 28 | class LlavaMptModel(LlavaMetaModel, MptModel): 29 | config_class = LlavaMptConfig 30 | 31 | def __init__(self, config: MptConfig): 32 | config.hidden_size = config.d_model 33 | super(LlavaMptModel, self).__init__(config) 34 | 35 | def embed_tokens(self, x): 36 | return self.wte(x) 37 | 38 | 39 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 40 | config_class = LlavaMptConfig 41 | supports_gradient_checkpointing = True 42 | 43 | def __init__(self, config): 44 | super(MptForCausalLM, self).__init__(config) 45 | 46 | config.model_type = "llava_mpt" 47 | config.rope_scaling = None 48 | self.generation_config = GenerationConfig( 49 | temperature=0.0, 50 | max_new_tokens=1024, 51 | do_sample=False, 52 | top_p=None, 53 | ) 54 | 55 | self.transformer = LlavaMptModel(config) 56 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 57 | 58 | # Initialize weights and apply final processing 59 | self.post_init() 60 | 61 | def get_model(self): 62 | return self.transformer 63 | 64 | def _set_gradient_checkpointing(self, module, value=False): 65 | if isinstance(module, LlavaMptModel): 66 | module.gradient_checkpointing = value 67 | 68 | def forward( 69 | self, 70 | input_ids: Optional[torch.LongTensor] = None, 71 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 72 | attention_mask: Optional[torch.Tensor] = None, 73 | inputs_embeds: Optional[torch.Tensor] = None, 74 | labels: Optional[torch.Tensor] = None, 75 | use_cache: Optional[bool] = None, 76 | output_attentions: Optional[bool] = None, 77 | output_hidden_states: Optional[bool] = None, 78 | return_dict: Optional[bool] = None, 79 | cache_position=None, 80 | images=None, 81 | ): 82 | 83 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 84 | 85 | return super().forward( 86 | input_ids, 87 | past_key_values=past_key_values, 88 | attention_mask=attention_mask, 89 | inputs_embeds=inputs_embeds, 90 | labels=labels, 91 | use_cache=use_cache, 92 | output_attentions=output_attentions, 93 | output_hidden_states=output_hidden_states, 94 | return_dict=return_dict, 95 | ) 96 | 97 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 98 | images = kwargs.pop("images", None) 99 | _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 100 | _inputs["images"] = images 101 | return _inputs 102 | 103 | 104 | AutoConfig.register("llava_mpt", LlavaMptConfig) 105 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 106 | 
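Each of the Llava*ForCausalLM wrappers (including the Qwen and Qwen-MoE variants below) exposes the same `generate(inputs, images=..., image_sizes=...)` interface, so one inference recipe covers all of them. The sketch below is illustrative only: it assumes the package is importable as `llava`, that `llava/model/builder.py` and `llava/mm_utils.py` keep their upstream LLaVA-NeXT signatures (`load_pretrained_model`, `process_images`, `tokenizer_image_token`), and that the checkpoint id and image path are placeholders.

```python
import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model  # builder.py appears in the tree above

# Placeholder checkpoint id -- substitute whatever LLaVA-NeXT weights you actually use.
model_path = "lmms-lab/llava-onevision-qwen2-7b-ov"

# Assumed upstream signature: returns (tokenizer, model, image_processor, context_len).
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, get_model_name_from_path(model_path))
model.eval()

image = Image.open("example.jpg").convert("RGB")
image_tensor = process_images([image], image_processor, model.config)
if isinstance(image_tensor, list):  # "anyres" preprocessing returns one tensor per image
    image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
else:
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

# DEFAULT_IMAGE_TOKEN marks where prepare_inputs_labels_for_multimodal() splices in the vision features.
# A real prompt would normally also be wrapped in a chat template from llava/conversation.py.
prompt = f"{DEFAULT_IMAGE_TOKEN}\nDescribe this image."
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image.size],
        do_sample=False,
        max_new_tokens=256,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

The ComfyUI node in nodes.py presumably drives the same call chain; the snippet is only meant to show how the `images` keyword reaches `prepare_inputs_labels_for_multimodal` in these wrappers.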
-------------------------------------------------------------------------------- /llava/model/language_model/llava_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Hao Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union, Dict 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import CrossEntropyLoss 20 | 21 | import transformers 22 | from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | # from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM 30 | 31 | # from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel 32 | # from .qwen.configuration_qwen import QWenConfig 33 | 34 | 35 | class LlavaQwenConfig(Qwen2Config): 36 | model_type = "llava_qwen" 37 | 38 | 39 | class LlavaQwenModel(LlavaMetaModel, Qwen2Model): 40 | config_class = LlavaQwenConfig 41 | 42 | def __init__(self, config: Qwen2Config): 43 | super(LlavaQwenModel, self).__init__(config) 44 | 45 | 46 | class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM): 47 | config_class = LlavaQwenConfig 48 | 49 | def __init__(self, config): 50 | # super(Qwen2ForCausalLM, self).__init__(config) 51 | Qwen2ForCausalLM.__init__(self, config) 52 | config.model_type = "llava_qwen" 53 | config.rope_scaling = None 54 | 55 | self.model = LlavaQwenModel(config) 56 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 57 | # Initialize weights and apply final processing 58 | self.post_init() 59 | 60 | def get_model(self): 61 | return self.model 62 | 63 | def forward( 64 | self, 65 | input_ids: torch.LongTensor = None, 66 | attention_mask: Optional[torch.Tensor] = None, 67 | position_ids: Optional[torch.LongTensor] = None, 68 | past_key_values: Optional[List[torch.FloatTensor]] = None, 69 | inputs_embeds: Optional[torch.FloatTensor] = None, 70 | labels: Optional[torch.LongTensor] = None, 71 | use_cache: Optional[bool] = None, 72 | output_attentions: Optional[bool] = None, 73 | output_hidden_states: Optional[bool] = None, 74 | images: Optional[torch.FloatTensor] = None, 75 | image_sizes: Optional[List[List[int]]] = None, 76 | return_dict: Optional[bool] = None, 77 | modalities: Optional[List[str]] = ["image"], 78 | dpo_forward: Optional[bool] = False, 79 | cache_position=None, 80 | ) -> Union[Tuple, CausalLMOutputWithPast]: 81 | 82 | if inputs_embeds is None: 83 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, 
images, modalities, image_sizes) 84 | 85 | if dpo_forward: 86 | outputs = self.model( 87 | input_ids=input_ids, 88 | attention_mask=attention_mask, 89 | position_ids=position_ids, 90 | past_key_values=past_key_values, 91 | inputs_embeds=inputs_embeds, 92 | use_cache=use_cache, 93 | output_attentions=output_attentions, 94 | output_hidden_states=output_hidden_states, 95 | return_dict=return_dict, 96 | ) 97 | 98 | hidden_states = outputs[0] 99 | logits = self.lm_head(hidden_states) 100 | return logits, labels 101 | 102 | else: 103 | return super().forward( 104 | input_ids=input_ids, 105 | attention_mask=attention_mask, 106 | position_ids=position_ids, 107 | past_key_values=past_key_values, 108 | inputs_embeds=inputs_embeds, 109 | labels=labels, 110 | use_cache=use_cache, 111 | output_attentions=output_attentions, 112 | output_hidden_states=output_hidden_states, 113 | return_dict=return_dict, 114 | ) 115 | 116 | @torch.no_grad() 117 | def generate( 118 | self, 119 | inputs: Optional[torch.Tensor] = None, 120 | images: Optional[torch.Tensor] = None, 121 | image_sizes: Optional[torch.Tensor] = None, 122 | modalities: Optional[List[str]] = ["image"], 123 | **kwargs, 124 | ) -> Union[GenerateOutput, torch.LongTensor]: 125 | position_ids = kwargs.pop("position_ids", None) 126 | attention_mask = kwargs.pop("attention_mask", None) 127 | if "inputs_embeds" in kwargs: 128 | raise NotImplementedError("`inputs_embeds` is not supported") 129 | 130 | if images is not None: 131 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 132 | else: 133 | inputs_embeds = self.get_model().embed_tokens(inputs) 134 | 135 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 136 | 137 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 138 | images = kwargs.pop("images", None) 139 | image_sizes = kwargs.pop("image_sizes", None) 140 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 141 | if images is not None: 142 | inputs["images"] = images 143 | if image_sizes is not None: 144 | inputs["image_sizes"] = image_sizes 145 | return inputs 146 | 147 | 148 | AutoConfig.register("llava_qwen", LlavaQwenConfig) 149 | AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM) 150 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_qwen_moe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Hao Zhang 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union, Dict 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import CrossEntropyLoss 20 | 21 | import transformers 22 | from transformers import AutoConfig, AutoModelForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | # from ...constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | from transformers import Qwen2MoeConfig, Qwen2MoeModel, Qwen2MoeForCausalLM 30 | 31 | # from .qwen.modeling_qwen import QWenLMHeadModel, QWenModel 32 | # from .qwen.configuration_qwen import QWenConfig 33 | 34 | 35 | class LlavaQwenMoeConfig(Qwen2MoeConfig): 36 | model_type = "llava_qwen_moe" 37 | 38 | 39 | class LlavaQwenMoeModel(LlavaMetaModel, Qwen2MoeModel): 40 | config_class = LlavaQwenMoeConfig 41 | 42 | def __init__(self, config: Qwen2MoeConfig): 43 | super(LlavaQwenMoeModel, self).__init__(config) 44 | 45 | 46 | class LlavaQwenMoeForCausalLM(Qwen2MoeForCausalLM, LlavaMetaForCausalLM): 47 | config_class = LlavaQwenMoeConfig 48 | 49 | def __init__(self, config): 50 | # super(Qwen2MoeForCausalLM, self).__init__(config) 51 | Qwen2MoeForCausalLM.__init__(self, config) 52 | config.model_type = "llava_qwen_moe" 53 | config.rope_scaling = None 54 | 55 | self.model = LlavaQwenMoeModel(config) 56 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 57 | # Initialize weights and apply final processing 58 | self.post_init() 59 | 60 | def get_model(self): 61 | return self.model 62 | 63 | def forward( 64 | self, 65 | input_ids: torch.LongTensor = None, 66 | attention_mask: Optional[torch.Tensor] = None, 67 | position_ids: Optional[torch.LongTensor] = None, 68 | past_key_values: Optional[List[torch.FloatTensor]] = None, 69 | inputs_embeds: Optional[torch.FloatTensor] = None, 70 | labels: Optional[torch.LongTensor] = None, 71 | use_cache: Optional[bool] = None, 72 | output_attentions: Optional[bool] = None, 73 | output_hidden_states: Optional[bool] = None, 74 | images: Optional[torch.FloatTensor] = None, 75 | image_sizes: Optional[List[List[int]]] = None, 76 | return_dict: Optional[bool] = None, 77 | modalities: Optional[List[str]] = ["image"], 78 | dpo_forward: Optional[bool] = False, 79 | cache_position=None, 80 | ) -> Union[Tuple, CausalLMOutputWithPast]: 81 | 82 | if inputs_embeds is None: 83 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) 84 | 85 | if dpo_forward: 86 | outputs = self.model( 87 | input_ids=input_ids, 88 | attention_mask=attention_mask, 89 | position_ids=position_ids, 90 | past_key_values=past_key_values, 91 | inputs_embeds=inputs_embeds, 92 | use_cache=use_cache, 93 | output_attentions=output_attentions, 94 | output_hidden_states=output_hidden_states, 95 | return_dict=return_dict, 96 | ) 97 | 98 | hidden_states = outputs[0] 99 | logits = self.lm_head(hidden_states) 100 | return logits, labels 101 | 102 | else: 103 | return super().forward( 104 | input_ids=input_ids, 105 | attention_mask=attention_mask, 106 | position_ids=position_ids, 107 | past_key_values=past_key_values, 108 | inputs_embeds=inputs_embeds, 109 | labels=labels, 110 | use_cache=use_cache, 111 | 
output_attentions=output_attentions, 112 | output_hidden_states=output_hidden_states, 113 | return_dict=return_dict, 114 | ) 115 | 116 | @torch.no_grad() 117 | def generate( 118 | self, 119 | inputs: Optional[torch.Tensor] = None, 120 | images: Optional[torch.Tensor] = None, 121 | image_sizes: Optional[torch.Tensor] = None, 122 | modalities: Optional[List[str]] = ["image"], 123 | **kwargs, 124 | ) -> Union[GenerateOutput, torch.LongTensor]: 125 | position_ids = kwargs.pop("position_ids", None) 126 | attention_mask = kwargs.pop("attention_mask", None) 127 | if "inputs_embeds" in kwargs: 128 | raise NotImplementedError("`inputs_embeds` is not supported") 129 | 130 | if images is not None: 131 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 132 | else: 133 | inputs_embeds = self.get_model().embed_tokens(inputs) 134 | 135 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 136 | 137 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 138 | images = kwargs.pop("images", None) 139 | image_sizes = kwargs.pop("image_sizes", None) 140 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 141 | if images is not None: 142 | inputs["images"] = images 143 | if image_sizes is not None: 144 | inputs["image_sizes"] = image_sizes 145 | return inputs 146 | 147 | 148 | AutoConfig.register("llava_qwen_moe", LlavaQwenMoeConfig) 149 | AutoModelForCausalLM.register(LlavaQwenMoeConfig, LlavaQwenMoeForCausalLM) 150 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from .utils import auto_upgrade 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | 
else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .imagebind import ImageBindWrapper 4 | from .open_clip_encoder import OpenCLIPVisionTower 5 | from .hf_vision import HFVisionTower 6 | from .siglip_encoder import SigLipVisionTower 7 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 8 | 9 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower 10 | # from .dev_eva_clip.eva_vit import EvaViTWrapper 11 | 12 | 13 | def build_vision_tower(vision_tower_cfg, **kwargs): 14 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) 15 | is_absolute_path_exists = os.path.exists(vision_tower) 16 | use_s2 = getattr(vision_tower_cfg, "s2", False) 17 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 18 | if use_s2: 19 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 20 | else: 21 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 22 | elif "siglip" in vision_tower: 23 | return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) 24 | elif vision_tower.startswith("hf:"): 25 | return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 26 | elif vision_tower in ["imagebind_huge"]: 27 | return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 28 | elif vision_tower.startswith("open_clip_hub"): 29 | return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 30 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): 31 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 32 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: 33 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 34 | 35 | raise ValueError(f"Unknown vision tower: {vision_tower}") 36 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ...utils import rank0_print 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | try: 7 | from s2wrapper import forward as multiscale_forward 8 | except: 9 | pass 10 | 11 | 12 | class CLIPVisionTower(nn.Module): 13 | def __init__(self, vision_tower, args, delay_load=False): 14 | super().__init__() 15 | 16 | self.is_loaded = False 17 | 18 | self.vision_tower_name = vision_tower 19 | self.select_layer = 
args.mm_vision_select_layer 20 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 21 | 22 | if not delay_load: 23 | rank0_print(f"Loading vision tower: {vision_tower}") 24 | self.load_model() 25 | elif getattr(args, "unfreeze_mm_vision_tower", False): 26 | # TODO: better detector is needed. 27 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 28 | self.load_model() 29 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 30 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 31 | self.load_model() 32 | else: 33 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 34 | 35 | def load_model(self, device_map=None): 36 | if self.is_loaded: 37 | rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) 38 | return 39 | 40 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 41 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 42 | self.vision_tower.requires_grad_(False) 43 | 44 | self.is_loaded = True 45 | 46 | def feature_select(self, image_forward_outs): 47 | select_feature_type = self.select_feature 48 | 49 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 50 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4 51 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) 52 | select_feature_type = select_feature_type.replace("slicefour_", "") 53 | elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: 54 | select_layers = [-2, -5, -8, -11, 6] 55 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1) 56 | select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") 57 | else: 58 | image_features = image_forward_outs.hidden_states[self.select_layer] 59 | 60 | if select_feature_type == "patch": 61 | image_features = image_features[:, 1:] 62 | elif select_feature_type == "cls_patch": 63 | image_features = image_features 64 | else: 65 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 66 | return image_features 67 | 68 | def forward(self, images): 69 | if type(images) is list: 70 | image_features = [] 71 | for image in images: 72 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 73 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 74 | image_features.append(image_feature) 75 | else: 76 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 77 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 78 | 79 | return image_features 80 | 81 | @property 82 | def dummy_feature(self): 83 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 84 | 85 | @property 86 | def dtype(self): 87 | return self.vision_tower.dtype 88 | 89 | @property 90 | def device(self): 91 | return self.vision_tower.device 92 | 93 | @property 94 | def config(self): 95 | if self.is_loaded: 96 | return self.vision_tower.config 97 | else: 98 | return self.cfg_only 99 | 100 | @property 101 | def hidden_size(self): 102 | 
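        # Matches the layer-concatenation modes handled in feature_select above:
        # "slicefour_*" concatenates 4 hidden states (4x the backbone width) and
        # "slice_m25811_f6_*" concatenates 5 (5x); otherwise the vision config's
        # hidden_size is returned unchanged.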
_hidden_size = self.config.hidden_size 103 | if "slicefour" in self.select_feature: 104 | _hidden_size *= 4 105 | if "slice_m25811_f6" in self.select_feature: 106 | _hidden_size *= 5 107 | return _hidden_size 108 | 109 | @property 110 | def num_patches_per_side(self): 111 | return self.config.image_size // self.config.patch_size 112 | 113 | @property 114 | def num_patches(self): 115 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2 116 | if "cls_patch" in self.select_feature: 117 | _num_patches += 1 118 | return _num_patches 119 | 120 | @property 121 | def image_size(self): 122 | return self.config.image_size 123 | 124 | 125 | class CLIPVisionTowerS2(CLIPVisionTower): 126 | def __init__(self, vision_tower, args, delay_load=False): 127 | 128 | self.s2_scales = getattr(args, "s2_scales", "336,672,1008") 129 | self.s2_scales = list(map(int, self.s2_scales.split(","))) 130 | self.s2_scales.sort() 131 | self.s2_split_size = self.s2_scales[0] 132 | self.s2_image_size = self.s2_scales[-1] 133 | 134 | super().__init__(vision_tower, args, delay_load) 135 | 136 | # change resize/crop size in preprocessing to the largest image size in s2_scale 137 | if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False): 138 | self.image_processor.size["shortest_edge"] = self.s2_image_size 139 | self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size 140 | 141 | def load_model(self, device_map=None): 142 | if self.is_loaded: 143 | rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) 144 | return 145 | 146 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 147 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 148 | self.vision_tower.requires_grad_(False) 149 | 150 | self.image_processor.size["shortest_edge"] = self.s2_image_size 151 | self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size 152 | 153 | self.is_loaded = True 154 | 155 | def forward_feature(self, images): 156 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 157 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 158 | return image_features 159 | 160 | def forward(self, images): 161 | if type(images) is list: 162 | image_features = [] 163 | for image in images: 164 | image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) 165 | image_features.append(image_feature) 166 | else: 167 | image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) 168 | 169 | return image_features 170 | 171 | @property 172 | def hidden_size(self): 173 | return self.config.hidden_size * len(self.s2_scales) 174 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .loss import ClipLoss 5 | from 
.model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 6 | from .openai import load_openai_model, list_openai_models 7 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 8 | from .tokenizer import SimpleTokenizer, tokenize 9 | from .transform import image_transform 10 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-LLaVA-OneVision/7feecc57a10f4b15fc32c0c54e0a6fba5e653606/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings", 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings", 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens", 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings", 54 | }, 55 | "pooler": "mean_pooler", 56 | }, 57 | } 58 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py: 
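The arch_dict table above is what the HFTextEncoder in hf_model.py, whose contents follow, uses to look up architecture-specific attribute names on a Hugging Face config. A minimal sketch, not part of the repository, of resolving those names directly; the import path assumes the tree layout shown earlier and "roberta-base" is only an example with a supported model_type:

    # Editorial sketch; the import path and model id are assumptions, not repo code.
    from transformers import AutoConfig
    from llava.model.multimodal_encoder.dev_eva_clip.eva_clip.hf_configs import arch_dict

    cfg = AutoConfig.from_pretrained("roberta-base")
    names = arch_dict[cfg.model_type]["config_names"]   # cfg.model_type == "roberta"
    width = getattr(cfg, names["width"])                # hidden_size -> 768
    layers = getattr(cfg, names["layers"])              # num_hidden_layers -> 12
    heads = getattr(cfg, names["heads"])                # num_attention_heads -> 12
    pooler_type = arch_dict[cfg.model_type]["pooler"]   # "mean_pooler"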
-------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 4 | """ 5 | 6 | import re 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | from torch import TensorType 12 | 13 | try: 14 | import transformers 15 | from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig 16 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions 17 | except ImportError as e: 18 | transformers = None 19 | 20 | class BaseModelOutput: 21 | pass 22 | 23 | class PretrainedConfig: 24 | pass 25 | 26 | 27 | from .hf_configs import arch_dict 28 | 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r"(? TensorType: 135 | # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device) 136 | # attn_mask = (x != self.config.pad_token_id).long() 137 | # out = self.transformer( 138 | # input_ids=x, 139 | # attention_mask=attn_mask, 140 | # encoder_hidden_states = image_embeds, 141 | # encoder_attention_mask = image_atts, 142 | # ) 143 | # pooled_out = self.pooler(out, attn_mask) 144 | 145 | # return self.itm_proj(pooled_out) 146 | 147 | def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None): 148 | if masked_indices is None: 149 | masked_indices = torch.bernoulli(probability_matrix).bool() 150 | 151 | masked_indices[input_ids == self.tokenizer.pad_token_id] = False 152 | masked_indices[input_ids == self.tokenizer.cls_token_id] = False 153 | 154 | if targets is not None: 155 | targets[~masked_indices] = -100 # We only compute loss on masked tokens 156 | 157 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 158 | indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices 159 | input_ids[indices_replaced] = self.tokenizer.mask_token_id 160 | 161 | # 10% of the time, we replace masked input tokens with random word 162 | indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced 163 | random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) 164 | input_ids[indices_random] = random_words[indices_random] 165 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 166 | 167 | if targets is not None: 168 | return input_ids, targets 169 | else: 170 | return input_ids 171 | 172 | def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25): 173 | labels = input_ids.clone() 174 | attn_mask = (input_ids != self.config.pad_token_id).long() 175 | image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device) 176 | vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"]) 177 | probability_matrix = torch.full(labels.shape, mlm_probability) 178 | input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix) 179 | mlm_output = self.transformer( 180 | input_ids, 181 | attention_mask=attn_mask, 182 | encoder_hidden_states=image_embeds, 183 | encoder_attention_mask=image_atts, 184 | return_dict=True, 185 | labels=labels, 186 | ) 187 | return mlm_output.loss 188 | # mlm_output = self.transformer(input_ids, 
189 | # attention_mask = attn_mask, 190 | # encoder_hidden_states = image_embeds, 191 | # encoder_attention_mask = image_atts, 192 | # return_dict = True, 193 | # ).last_hidden_state 194 | # logits = self.mlm_proj(mlm_output) 195 | 196 | # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size) 197 | # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size) 198 | # labels = labels[:, 1:].contiguous().view(-1) 199 | 200 | # mlm_loss = F.cross_entropy( 201 | # logits, 202 | # labels, 203 | # # label_smoothing=0.1, 204 | # ) 205 | # return mlm_loss 206 | 207 | def forward(self, x: TensorType) -> TensorType: 208 | attn_mask = (x != self.config.pad_token_id).long() 209 | out = self.transformer(input_ids=x, attention_mask=attn_mask) 210 | pooled_out = self.pooler(out, attn_mask) 211 | 212 | return self.proj(pooled_out) 213 | 214 | def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): 215 | if not unlocked_layers: # full freezing 216 | for n, p in self.transformer.named_parameters(): 217 | p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False 218 | return 219 | 220 | encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer 221 | layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) 222 | print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model") 223 | embeddings = getattr(self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"]) 224 | modules = [embeddings, *layer_list][:-unlocked_layers] 225 | # freeze layers 226 | for module in modules: 227 | for n, p in module.named_parameters(): 228 | p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False 229 | 230 | @torch.jit.ignore 231 | def set_grad_checkpointing(self, enable=True): 232 | self.transformer.gradient_checkpointing_enable() 233 | 234 | def get_num_layers(self): 235 | encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer 236 | layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) 237 | return len(layer_list) 238 | 239 | def init_parameters(self): 240 | pass 241 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import torch.distributed.nn 8 | from torch import distributed as dist 9 | 10 | has_distributed = True 11 | except ImportError: 12 | has_distributed = False 13 | 14 | try: 15 | import horovod.torch as hvd 16 | except ImportError: 17 | hvd = None 18 | 19 | from timm.loss import LabelSmoothingCrossEntropy 20 | 21 | 22 | def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False): 23 | assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support." 
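    # Both branches below collect image/text embeddings from every rank so the
    # contrastive loss sees the full global batch. With gather_with_grad the gathered
    # copies keep their autograd graph; otherwise the gather is gradient-free and,
    # unless local_loss is set, this rank's own tensors are re-inserted at position
    # rank so the local slice still carries gradient.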
24 | if use_horovod: 25 | assert hvd is not None, "Please install horovod" 26 | if gather_with_grad: 27 | all_image_features = hvd.allgather(image_features) 28 | all_text_features = hvd.allgather(text_features) 29 | else: 30 | with torch.no_grad(): 31 | all_image_features = hvd.allgather(image_features) 32 | all_text_features = hvd.allgather(text_features) 33 | if not local_loss: 34 | # ensure grads for local rank when all_* features don't have a gradient 35 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 36 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 37 | gathered_image_features[rank] = image_features 38 | gathered_text_features[rank] = text_features 39 | all_image_features = torch.cat(gathered_image_features, dim=0) 40 | all_text_features = torch.cat(gathered_text_features, dim=0) 41 | else: 42 | # We gather tensors from all gpus 43 | if gather_with_grad: 44 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 45 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 46 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 47 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 48 | else: 49 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 50 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 51 | dist.all_gather(gathered_image_features, image_features) 52 | dist.all_gather(gathered_text_features, text_features) 53 | if not local_loss: 54 | # ensure grads for local rank when all_* features don't have a gradient 55 | gathered_image_features[rank] = image_features 56 | gathered_text_features[rank] = text_features 57 | all_image_features = torch.cat(gathered_image_features, dim=0) 58 | all_text_features = torch.cat(gathered_text_features, dim=0) 59 | 60 | return all_image_features, all_text_features 61 | 62 | 63 | class ClipLoss(nn.Module): 64 | 65 | def __init__( 66 | self, 67 | local_loss=False, 68 | gather_with_grad=False, 69 | cache_labels=False, 70 | rank=0, 71 | world_size=1, 72 | use_horovod=False, 73 | smoothing=0.0, 74 | ): 75 | super().__init__() 76 | self.local_loss = local_loss 77 | self.gather_with_grad = gather_with_grad 78 | self.cache_labels = cache_labels 79 | self.rank = rank 80 | self.world_size = world_size 81 | self.use_horovod = use_horovod 82 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 83 | 84 | # cache state 85 | self.prev_num_logits = 0 86 | self.labels = {} 87 | 88 | def forward(self, image_features, text_features, logit_scale=1.0): 89 | device = image_features.device 90 | if self.world_size > 1: 91 | all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 92 | 93 | if self.local_loss: 94 | logits_per_image = logit_scale * image_features @ all_text_features.T 95 | logits_per_text = logit_scale * text_features @ all_image_features.T 96 | else: 97 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 98 | logits_per_text = logits_per_image.T 99 | else: 100 | logits_per_image = logit_scale * image_features @ text_features.T 101 | logits_per_text = logit_scale * text_features @ image_features.T 102 | # calculated ground-truth and cache if enabled 103 | 
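        # The contrastive targets are the identity pairing: image i matches text i, so
        # the labels are arange(num_logits). When local_loss is used with world_size > 1,
        # the local logits are scored against the globally gathered features, so this
        # rank's positives start at num_logits * rank. Labels are cached per device when
        # cache_labels is set, since the batch size normally stays fixed between steps.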
num_logits = logits_per_image.shape[0] 104 | if self.prev_num_logits != num_logits or device not in self.labels: 105 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 106 | if self.world_size > 1 and self.local_loss: 107 | labels = labels + num_logits * self.rank 108 | if self.cache_labels: 109 | self.labels[device] = labels 110 | self.prev_num_logits = num_logits 111 | else: 112 | labels = self.labels[device] 113 | 114 | if self.label_smoothing_cross_entropy: 115 | total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2 116 | else: 117 | total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 118 | 119 | acc = None 120 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 121 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 122 | acc = {"i2t": i2t_acc, "t2i": t2i_acc} 123 | return total_loss, acc 124 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- 
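Each JSON under model_configs pairs a shared embed_dim with a vision_cfg and a text_cfg describing the two towers. A minimal sketch, not part of the repository, of inspecting one of these files; the relative path is only an assumption about where the repo is checked out:

    # Editorial sketch; the relative path is an assumption, not repo code.
    import json
    from pathlib import Path

    cfg_dir = Path("llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs")
    cfg = json.loads((cfg_dir / "EVA-CLIP-8B.json").read_text())

    vision = cfg["vision_cfg"]
    patches_per_side = vision["image_size"] // vision["patch_size"]   # 224 // 14 = 16
    print(cfg["embed_dim"], vision["layers"], vision["width"])        # 1280 32 4096
    print(patches_per_side ** 2)                                      # 256 patch tokens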
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | 
"context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))])) 37 | 38 | def forward(self, x: torch.Tensor): 39 | identity = x 40 | 41 | out = self.act1(self.bn1(self.conv1(x))) 42 | out = self.act2(self.bn2(self.conv2(out))) 43 | out = self.avgpool(out) 44 | out = self.bn3(self.conv3(out)) 45 | 46 | if self.downsample is not None: 47 | identity = self.downsample(x) 48 | 49 | out += identity 50 | out = self.act3(out) 51 | return out 52 | 53 | 54 | class AttentionPool2d(nn.Module): 55 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 56 | super().__init__() 57 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) 58 | self.k_proj = nn.Linear(embed_dim, embed_dim) 59 | self.q_proj = nn.Linear(embed_dim, embed_dim) 60 | self.v_proj = nn.Linear(embed_dim, embed_dim) 61 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 62 | self.num_heads = num_heads 63 | 64 | def forward(self, x): 65 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 66 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 67 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 68 | x, _ = F.multi_head_attention_forward( 69 | query=x, 70 | key=x, 71 | value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | 
num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0.0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False, 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.image_size = image_size 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.act1 = nn.ReLU(inplace=True) 110 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 111 | self.bn2 = nn.BatchNorm2d(width // 2) 112 | self.act2 = nn.ReLU(inplace=True) 113 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 114 | self.bn3 = nn.BatchNorm2d(width) 115 | self.act3 = nn.ReLU(inplace=True) 116 | self.avgpool = nn.AvgPool2d(2) 117 | 118 | # residual layers 119 | self._inplanes = width # this is a *mutable* variable used during construction 120 | self.layer1 = self._make_layer(width, layers[0]) 121 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 122 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 123 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 124 | 125 | embed_dim = width * 32 # the ResNet feature dimension 126 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 127 | 128 | self.init_parameters() 129 | 130 | def _make_layer(self, planes, blocks, stride=1): 131 | layers = [Bottleneck(self._inplanes, planes, stride)] 132 | 133 | self._inplanes = planes * Bottleneck.expansion 134 | for _ in range(1, blocks): 135 | layers.append(Bottleneck(self._inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def init_parameters(self): 140 | if self.attnpool is not None: 141 | std = self.attnpool.c_proj.in_features**-0.5 142 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 143 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 144 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 146 | 147 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 148 | for name, param in resnet_block.named_parameters(): 149 | if name.endswith("bn3.weight"): 150 | nn.init.zeros_(param) 151 | 152 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 153 | assert unlocked_groups == 0, "partial locking not currently supported for this model" 154 | for param in self.parameters(): 155 | param.requires_grad = False 156 | if freeze_bn_stats: 157 | freeze_batch_norm_2d(self) 158 | 159 | @torch.jit.ignore 160 | def 
set_grad_checkpointing(self, enable=True): 161 | # FIXME support for non-transformer 162 | pass 163 | 164 | def stem(self, x): 165 | x = self.act1(self.bn1(self.conv1(x))) 166 | x = self.act2(self.bn2(self.conv2(x))) 167 | x = self.act3(self.bn3(self.conv3(x))) 168 | x = self.avgpool(x) 169 | return x 170 | 171 | def forward(self, x): 172 | x = self.stem(x) 173 | x = self.layer1(x) 174 | x = self.layer2(x) 175 | x = self.layer3(x) 176 | x = self.layer4(x) 177 | x = self.attnpool(x) 178 | 179 | return x 180 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag("openai") 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = "fp32" if device == "cpu" else "fp16" 56 | 57 | if get_pretrained_url(name, "openai"): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith("amp") or precision == "fp32": 87 | model.float() 88 | elif precision == "bf16": 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == "fp32": 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Dict, Union 6 | 7 | from tqdm import tqdm 8 | 9 | try: 10 | from huggingface_hub import hf_hub_download 11 | 12 | _has_hf_hub = True 13 | except ImportError: 14 | hf_hub_download = None 15 | _has_hf_hub = False 16 | 17 | 18 | def _pcfg(url="", hf_hub="", filename="", mean=None, std=None): 19 | return dict( 20 | url=url, 21 | hf_hub=hf_hub, 22 | mean=mean, 23 | std=std, 24 | ) 25 | 26 | 27 | _VITB32 = dict( 
28 | openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 29 | laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 30 | laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 31 | laion2b_e16=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), 32 | laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"), 33 | ) 34 | 35 | _VITB32_quickgelu = dict( 36 | openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 37 | laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 38 | laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 39 | ) 40 | 41 | _VITB16 = dict( 42 | openai=_pcfg("https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), 43 | laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), 44 | laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), 45 | laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"), 46 | ) 47 | 48 | _EVAB16 = dict( 49 | eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"), 50 | eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"), 51 | eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"), 52 | eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"), 53 | ) 54 | 55 | _VITB16_PLUS_240 = dict( 56 | laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), 57 | laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), 58 | ) 59 | 60 | _VITL14 = dict( 61 | openai=_pcfg("https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), 62 | laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), 63 | laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), 64 | laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 65 | ) 66 | 67 | _EVAL14 = dict( 68 | eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"), 69 | eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"), 70 | eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"), 71 | eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"), 72 | ) 73 | 74 | _VITL14_336 = dict( 75 | openai=_pcfg("https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), 76 | ) 77 | 78 | _EVAL14_336 = dict( 79 | eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"), 80 | 
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"), 81 | eva_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"), 82 | eva02_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"), 83 | ) 84 | 85 | _VITH14 = dict( 86 | laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"), 87 | ) 88 | 89 | _VITg14 = dict( 90 | laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"), 91 | laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"), 92 | ) 93 | 94 | _EVAg14 = dict( 95 | eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"), 96 | eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"), 97 | eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"), 98 | eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"), 99 | ) 100 | 101 | _EVAg14_PLUS = dict( 102 | eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"), 103 | eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"), 104 | eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"), 105 | eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"), 106 | ) 107 | 108 | _VITbigG14 = dict( 109 | laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"), 110 | ) 111 | 112 | _EVAbigE14 = dict( 113 | eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), 114 | eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), 115 | eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"), 116 | eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"), 117 | ) 118 | 119 | _EVAbigE14_PLUS = dict( 120 | eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), 121 | eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), 122 | eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"), 123 | eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"), 124 | ) 125 | 126 | _EVA_8B = dict( 127 | eva=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_8B_psz14.bin"), 128 | eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_CLIP_8B_psz14_s9B.pt"), 129 | ) 130 | 131 | _EVA_8B_PLUS = dict( 132 | eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B-448/EVA_CLIP_8B_psz14_plus_s0.6B.pt"), 133 | ) 134 | 135 | 136 | _PRETRAINED = { 137 | # "ViT-B-32": _VITB32, 138 | "OpenaiCLIP-B-32": _VITB32, 139 | "OpenCLIP-B-32": _VITB32, 140 | # "ViT-B-32-quickgelu": _VITB32_quickgelu, 141 | "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu, 142 | "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu, 143 | # "ViT-B-16": _VITB16, 144 | "OpenaiCLIP-B-16": _VITB16, 145 | "OpenCLIP-B-16": _VITB16, 146 | "EVA02-B-16": _EVAB16, 147 | "EVA02-CLIP-B-16": _EVAB16, 148 | # "ViT-B-16-plus-240": _VITB16_PLUS_240, 149 | "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240, 150 | # "ViT-L-14": _VITL14, 151 | "OpenaiCLIP-L-14": _VITL14, 152 | "OpenCLIP-L-14": _VITL14, 153 | "EVA02-L-14": _EVAL14, 154 | "EVA02-CLIP-L-14": _EVAL14, 155 | # "ViT-L-14-336": _VITL14_336, 156 | "OpenaiCLIP-L-14-336": _VITL14_336, 157 | "EVA02-CLIP-L-14-336": _EVAL14_336, 158 | # "ViT-H-14": _VITH14, 159 | # "ViT-g-14": _VITg14, 160 | "OpenCLIP-H-14": _VITH14, 161 | "OpenCLIP-g-14": _VITg14, 162 | "EVA01-CLIP-g-14": _EVAg14, 163 | "EVA01-CLIP-g-14-plus": _EVAg14_PLUS, 164 | # "ViT-bigG-14": _VITbigG14, 165 | "OpenCLIP-bigG-14": _VITbigG14, 166 | "EVA02-CLIP-bigE-14": _EVAbigE14, 167 | "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS, 168 | "EVA-CLIP-8B": _EVA_8B, 169 | "EVA-CLIP-8B-448": _EVA_8B_PLUS, 170 | "EVA-CLIP-8B-plus": _EVA_8B_PLUS, 171 | } 172 | 173 | 
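# Illustrative usage of this registry through the helpers defined below (editorial
# sketch, not part of the original file; the import path assumes the tree layout):
#
#     from llava.model.multimodal_encoder.dev_eva_clip.eva_clip.pretrained import (
#         get_pretrained_cfg, list_pretrained_tags_by_model,
#     )
#     list_pretrained_tags_by_model("EVA02-CLIP-L-14-336")
#     # -> ['eva_clip', 'eva02_clip', 'eva_clip_224to336', 'eva02_clip_224to336']
#     get_pretrained_cfg("EVA-CLIP-8B", "eva_clip")["hf_hub"]
#     # -> 'BAAI/EVA-CLIP-8B/EVA_CLIP_8B_psz14_s9B.pt'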
174 | def _clean_tag(tag: str): 175 | # normalize pretrained tags 176 | return tag.lower().replace("-", "_") 177 | 178 | 179 | def list_pretrained(as_str: bool = False): 180 | """returns list of pretrained models 181 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 182 | """ 183 | return [":".join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 184 | 185 | 186 | def list_pretrained_models_by_tag(tag: str): 187 | """return all models having the specified pretrain tag""" 188 | models = [] 189 | tag = _clean_tag(tag) 190 | for k in _PRETRAINED.keys(): 191 | if tag in _PRETRAINED[k]: 192 | models.append(k) 193 | return models 194 | 195 | 196 | def list_pretrained_tags_by_model(model: str): 197 | """return all pretrain tags for the specified model architecture""" 198 | tags = [] 199 | if model in _PRETRAINED: 200 | tags.extend(_PRETRAINED[model].keys()) 201 | return tags 202 | 203 | 204 | def is_pretrained_cfg(model: str, tag: str): 205 | if model not in _PRETRAINED: 206 | return False 207 | return _clean_tag(tag) in _PRETRAINED[model] 208 | 209 | 210 | def get_pretrained_cfg(model: str, tag: str): 211 | if model not in _PRETRAINED: 212 | return {} 213 | model_pretrained = _PRETRAINED[model] 214 | return model_pretrained.get(_clean_tag(tag), {}) 215 | 216 | 217 | def get_pretrained_url(model: str, tag: str): 218 | cfg = get_pretrained_cfg(model, _clean_tag(tag)) 219 | return cfg.get("url", "") 220 | 221 | 222 | def download_pretrained_from_url( 223 | url: str, 224 | cache_dir: Union[str, None] = None, 225 | ): 226 | if not cache_dir: 227 | cache_dir = os.path.expanduser("~/.cache/clip") 228 | os.makedirs(cache_dir, exist_ok=True) 229 | filename = os.path.basename(url) 230 | 231 | if "openaipublic" in url: 232 | expected_sha256 = url.split("/")[-2] 233 | elif "mlfoundations" in url: 234 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] 235 | else: 236 | expected_sha256 = "" 237 | 238 | download_target = os.path.join(cache_dir, filename) 239 | 240 | if os.path.exists(download_target) and not os.path.isfile(download_target): 241 | raise RuntimeError(f"{download_target} exists and is not a regular file") 242 | 243 | if os.path.isfile(download_target): 244 | if expected_sha256: 245 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 246 | return download_target 247 | else: 248 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 249 | else: 250 | return download_target 251 | 252 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 253 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop: 254 | while True: 255 | buffer = source.read(8192) 256 | if not buffer: 257 | break 258 | 259 | output.write(buffer) 260 | loop.update(len(buffer)) 261 | 262 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 263 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 264 | 265 | return download_target 266 | 267 | 268 | def has_hf_hub(necessary=False): 269 | if not _has_hf_hub and necessary: 270 | # if no HF Hub module installed, and it is necessary to continue, raise error 271 | raise RuntimeError("Hugging Face hub model specified but package not installed. 
Run `pip install huggingface_hub`.") 272 | return _has_hf_hub 273 | 274 | 275 | def download_pretrained_from_hf( 276 | model_id: str, 277 | filename: str = "open_clip_pytorch_model.bin", 278 | revision=None, 279 | cache_dir: Union[str, None] = None, 280 | ): 281 | has_hf_hub(True) 282 | cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) 283 | return cached_file 284 | 285 | 286 | def download_pretrained( 287 | cfg: Dict, 288 | force_hf_hub: bool = False, 289 | cache_dir: Union[str, None] = None, 290 | ): 291 | target = "" 292 | if not cfg: 293 | return target 294 | 295 | download_url = cfg.get("url", "") 296 | download_hf_hub = cfg.get("hf_hub", "") 297 | if download_hf_hub and force_hf_hub: 298 | # use HF hub even if url exists 299 | download_url = "" 300 | 301 | if download_url: 302 | target = download_pretrained_from_url(download_url, cache_dir=cache_dir) 303 | elif download_hf_hub: 304 | has_hf_hub(True) 305 | # we assume the hf_hub entries in pretrained config combine model_id + filename in 306 | # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and 307 | # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 308 | model_id, filename = os.path.split(download_hf_hub) 309 | if filename: 310 | target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) 311 | else: 312 | target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) 313 | 314 | return target 315 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | 8 | def broadcat(tensors, dim=-1): 9 | num_tensors = len(tensors) 10 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 11 | assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" 12 | shape_len = list(shape_lens)[0] 13 | dim = (dim + shape_len) if dim < 0 else dim 14 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 15 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 16 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation" 17 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 18 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 19 | expanded_dims.insert(dim, (dim, dims[dim])) 20 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 21 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 22 | return torch.cat(tensors, dim=dim) 23 | 24 | 25 | def rotate_half(x): 26 | x = rearrange(x, "... (d r) -> ... d r", r=2) 27 | x1, x2 = x.unbind(dim=-1) 28 | x = torch.stack((-x2, x1), dim=-1) 29 | return rearrange(x, "... d r -> ... 
(d r)") 30 | 31 | 32 | class VisionRotaryEmbedding(nn.Module): 33 | def __init__( 34 | self, 35 | dim, 36 | pt_seq_len, 37 | ft_seq_len=None, 38 | custom_freqs=None, 39 | freqs_for="lang", 40 | theta=10000, 41 | max_freq=10, 42 | num_freqs=1, 43 | ): 44 | super().__init__() 45 | if custom_freqs: 46 | freqs = custom_freqs 47 | elif freqs_for == "lang": 48 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 49 | elif freqs_for == "pixel": 50 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 51 | elif freqs_for == "constant": 52 | freqs = torch.ones(num_freqs).float() 53 | else: 54 | raise ValueError(f"unknown modality {freqs_for}") 55 | 56 | if ft_seq_len is None: 57 | ft_seq_len = pt_seq_len 58 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 59 | 60 | freqs_h = torch.einsum("..., f -> ... f", t, freqs) 61 | freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) 62 | 63 | freqs_w = torch.einsum("..., f -> ... f", t, freqs) 64 | freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) 65 | 66 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) 67 | 68 | self.register_buffer("freqs_cos", freqs.cos()) 69 | self.register_buffer("freqs_sin", freqs.sin()) 70 | 71 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 72 | 73 | def forward(self, t, start_index=0): 74 | rot_dim = self.freqs_cos.shape[-1] 75 | end_index = start_index + rot_dim 76 | assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" 77 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 78 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 79 | 80 | return torch.cat((t_left, t, t_right), dim=-1) 81 | 82 | 83 | class VisionRotaryEmbeddingFast(nn.Module): 84 | def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0): 85 | super().__init__() 86 | if custom_freqs: 87 | freqs = custom_freqs 88 | elif freqs_for == "lang": 89 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 90 | elif freqs_for == "pixel": 91 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 92 | elif freqs_for == "constant": 93 | freqs = torch.ones(num_freqs).float() 94 | else: 95 | raise ValueError(f"unknown modality {freqs_for}") 96 | 97 | if ft_seq_len is None: 98 | ft_seq_len = pt_seq_len 99 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 100 | 101 | freqs = torch.einsum("..., f -> ... f", t, freqs) 102 | freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) 103 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) 104 | 105 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 106 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 107 | 108 | self.patch_dropout = patch_dropout 109 | 110 | self.register_buffer("freqs_cos", freqs_cos) 111 | self.register_buffer("freqs_sin", freqs_sin) 112 | 113 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 114 | 115 | def forward(self, t, patch_indices_keep=None): 116 | if patch_indices_keep is not None: 117 | batch = t.size()[0] 118 | batch_indices = torch.arange(batch) 119 | batch_indices = batch_indices[..., None] 120 | 121 | freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 122 | freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 123 | 124 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 125 | freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j") 126 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 127 | freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j") 128 | 129 | return t * freqs_cos + rotate_half(t) * freqs_sin 130 | 131 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin 132 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | 6 | import logging 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | try: 13 | import timm 14 | from timm.models.layers import Mlp, to_2tuple 15 | 16 | try: 17 | # old timm imports < 0.8.1 18 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 19 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 20 | except ImportError: 21 | # new timm imports >= 0.8.1 22 | from timm.layers import RotAttentionPool2d 23 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 24 | except ImportError: 25 | timm = None 26 | 27 | from .utils import freeze_batch_norm_2d 28 | 29 | 30 | class TimmModel(nn.Module): 31 | """timm model adapter 32 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 33 | """ 34 | 35 | def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False): 36 | super().__init__() 37 | if timm is None: 38 | raise RuntimeError("Please `pip install timm` to use timm models.") 39 | 40 | self.image_size = to_2tuple(image_size) 41 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 42 | feat_size = self.trunk.default_cfg.get("pool_size", None) 43 | feature_ndim = 1 if not feat_size else 2 44 | if pool in ("abs_attn", "rot_attn"): 45 | assert feature_ndim == 2 46 | # if attn pooling used, remove both classifier and default pool 47 | self.trunk.reset_classifier(0, global_pool="") 48 | else: 49 | # reset global pool if pool config set, otherwise leave as network default 50 | reset_kwargs = dict(global_pool=pool) if pool else {} 51 | self.trunk.reset_classifier(0, **reset_kwargs) 52 | prev_chs = self.trunk.num_features 53 | 54 | head_layers = OrderedDict() 55 | if pool == "abs_attn": 56 | head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 57 | prev_chs = 
embed_dim 58 | elif pool == "rot_attn": 59 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 60 | prev_chs = embed_dim 61 | else: 62 | assert proj, "projection layer needed if non-attention pooling is used." 63 | 64 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 65 | if proj == "linear": 66 | head_layers["drop"] = nn.Dropout(drop) 67 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 68 | elif proj == "mlp": 69 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 70 | 71 | self.head = nn.Sequential(head_layers) 72 | 73 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 74 | """lock modules 75 | Args: 76 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 77 | """ 78 | if not unlocked_groups: 79 | # lock full model 80 | for param in self.trunk.parameters(): 81 | param.requires_grad = False 82 | if freeze_bn_stats: 83 | freeze_batch_norm_2d(self.trunk) 84 | else: 85 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 86 | try: 87 | # FIXME import here until API stable and in an official release 88 | from timm.models.helpers import group_parameters, group_modules 89 | except ImportError: 90 | raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`") 91 | matcher = self.trunk.group_matcher() 92 | gparams = group_parameters(self.trunk, matcher) 93 | max_layer_id = max(gparams.keys()) 94 | max_layer_id = max_layer_id - unlocked_groups 95 | for group_idx in range(max_layer_id + 1): 96 | group = gparams[group_idx] 97 | for param in group: 98 | self.trunk.get_parameter(param).requires_grad = False 99 | if freeze_bn_stats: 100 | gmodules = group_modules(self.trunk, matcher, reverse=True) 101 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 102 | freeze_batch_norm_2d(self.trunk, gmodules) 103 | 104 | @torch.jit.ignore 105 | def set_grad_checkpointing(self, enable=True): 106 | try: 107 | self.trunk.set_grad_checkpointing(enable) 108 | except Exception as e: 109 | logging.warning("grad checkpointing not supported for this timm image tower, continuing without...") 110 | 111 | def forward(self, x): 112 | x = self.trunk(x) 113 | x = self.head(x) 114 | return x 115 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import gzip 7 | import html 8 | import os 9 | from functools import lru_cache 10 | from typing import Union, List 11 | 12 | import ftfy 13 | import regex as re 14 | import torch 15 | 16 | # https://stackoverflow.com/q/62691279 17 | import os 18 | 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 20 | 21 | 22 | @lru_cache() 23 | def default_bpe(): 24 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 25 | 26 | 27 | @lru_cache() 28 | def bytes_to_unicode(): 29 | """ 30 | Returns list of utf-8 byte and a corresponding list of unicode strings. 31 | The reversible bpe codes work on unicode strings. 32 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
33 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 34 | This is a significant percentage of your normal, say, 32K bpe vocab. 35 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 36 | And avoids mapping to whitespace/control characters the bpe code barfs on. 37 | """ 38 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 39 | cs = bs[:] 40 | n = 0 41 | for b in range(2**8): 42 | if b not in bs: 43 | bs.append(b) 44 | cs.append(2**8 + n) 45 | n += 1 46 | cs = [chr(n) for n in cs] 47 | return dict(zip(bs, cs)) 48 | 49 | 50 | def get_pairs(word): 51 | """Return set of symbol pairs in a word. 52 | Word is represented as tuple of symbols (symbols being variable-length strings). 53 | """ 54 | pairs = set() 55 | prev_char = word[0] 56 | for char in word[1:]: 57 | pairs.add((prev_char, char)) 58 | prev_char = char 59 | return pairs 60 | 61 | 62 | def basic_clean(text): 63 | text = ftfy.fix_text(text) 64 | text = html.unescape(html.unescape(text)) 65 | return text.strip() 66 | 67 | 68 | def whitespace_clean(text): 69 | text = re.sub(r"\s+", " ", text) 70 | text = text.strip() 71 | return text 72 | 73 | 74 | class SimpleTokenizer(object): 75 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 76 | self.byte_encoder = bytes_to_unicode() 77 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 78 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 79 | merges = merges[1 : 49152 - 256 - 2 + 1] 80 | merges = [tuple(merge.split()) for merge in merges] 81 | vocab = list(bytes_to_unicode().values()) 82 | vocab = vocab + [v + "</w>" for v in vocab] 83 | for merge in merges: 84 | vocab.append("".join(merge)) 85 | if not special_tokens: 86 | special_tokens = ["<start_of_text>", "<end_of_text>"] 87 | else: 88 | special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens 89 | vocab.extend(special_tokens) 90 | self.encoder = dict(zip(vocab, range(len(vocab)))) 91 | self.decoder = {v: k for k, v in self.encoder.items()} 92 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 93 | self.cache = {t: t for t in special_tokens} 94 | special = "|".join(special_tokens) 95 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 96 | 97 | self.vocab_size = len(self.encoder) 98 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 99 | 100 | def bpe(self, token): 101 | if token in self.cache: 102 | return self.cache[token] 103 | word = tuple(token[:-1]) + (token[-1] + "</w>",) 104 | pairs = get_pairs(word) 105 | 106 | if not pairs: 107 | return token + "</w>" 108 | 109 | while True: 110 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 111 | if bigram not in self.bpe_ranks: 112 | break 113 | first, second = bigram 114 | new_word = [] 115 | i = 0 116 | while i < len(word): 117 | try: 118 | j = word.index(first, i) 119 | new_word.extend(word[i:j]) 120 | i = j 121 | except: 122 | new_word.extend(word[i:]) 123 | break 124 | 125 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 126 | new_word.append(first + second) 127 | i += 2 128 | else: 129 | new_word.append(word[i]) 130 | i += 1 131 | new_word = tuple(new_word) 132 | word = new_word 133 | if len(word) == 1: 134 | break 135 | else: 136 | pairs = get_pairs(word) 137 | word = " ".join(word) 138 | self.cache[token] = word 139 | return word 140 | 141 | def encode(self, text): 142 | bpe_tokens = [] 143 | text = 
whitespace_clean(basic_clean(text)).lower() 144 | for token in re.findall(self.pat, text): 145 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 146 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) 147 | return bpe_tokens 148 | 149 | def decode(self, tokens): 150 | text = "".join([self.decoder[token] for token in tokens]) 151 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ") 152 | return text 153 | 154 | 155 | _tokenizer = SimpleTokenizer() 156 | 157 | 158 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 159 | """ 160 | Returns the tokenized representation of given input string(s) 161 | 162 | Parameters 163 | ---------- 164 | texts : Union[str, List[str]] 165 | An input string or a list of input strings to tokenize 166 | context_length : int 167 | The context length to use; all CLIP models use 77 as the context length 168 | 169 | Returns 170 | ------- 171 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 172 | """ 173 | if isinstance(texts, str): 174 | texts = [texts] 175 | 176 | sot_token = _tokenizer.encoder["<start_of_text>"] 177 | eot_token = _tokenizer.encoder["<end_of_text>"] 178 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 179 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 180 | 181 | for i, tokens in enumerate(all_tokens): 182 | if len(tokens) > context_length: 183 | tokens = tokens[:context_length] # Truncate 184 | tokens[-1] = eot_token 185 | result[i, : len(tokens)] = torch.tensor(tokens) 186 | 187 | return result 188 | 189 | 190 | class HFTokenizer: 191 | "HuggingFace tokenizer wrapper" 192 | 193 | def __init__(self, tokenizer_name: str): 194 | from transformers import AutoTokenizer 195 | 196 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 197 | 198 | def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: 199 | # same cleaning as for default tokenizer, except lowercasing 200 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 201 | if isinstance(texts, str): 202 | texts = [texts] 203 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 204 | input_ids = self.tokenizer(texts, return_tensors="pt", max_length=context_length, padding="max_length", truncation=True).input_ids 205 | return input_ids 206 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop 8 | 9 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 10 | 11 | 12 | class ResizeMaxSize(nn.Module): 13 | 14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): 15 | super().__init__() 16 | if not isinstance(max_size, int): 17 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 18 | self.max_size = max_size 19 | self.interpolation = interpolation 20 | self.fn = min if fn == "min" else min 21 | self.fill = fill 22 | 23 | def forward(self, img): 24 | if isinstance(img, torch.Tensor): 25 | height, width = img.shape[:2] 26 | else: 27 | width, height = img.size 28 | scale = self.max_size / float(max(height, width)) 29 | if scale != 1.0: 30 | new_size = tuple(round(dim * scale) for dim in (height, width)) 31 | img = F.resize(img, new_size, self.interpolation) 32 | pad_h = self.max_size - new_size[0] 33 | pad_w = self.max_size - new_size[1] 34 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) 35 | return img 36 | 37 | 38 | def _convert_to_rgb(image): 39 | return image.convert("RGB") 40 | 41 | 42 | # class CatGen(nn.Module): 43 | # def __init__(self, num=4): 44 | # self.num = num 45 | # def mixgen_batch(image, text): 46 | # batch_size = image.shape[0] 47 | # index = np.random.permutation(batch_size) 48 | 49 | # cat_images = [] 50 | # for i in range(batch_size): 51 | # # image mixup 52 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 53 | # # text concat 54 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 55 | # text = torch.stack(text) 56 | # return image, text 57 | 58 | 59 | def image_transform( 60 | image_size: int, 61 | is_train: bool, 62 | mean: Optional[Tuple[float, ...]] = None, 63 | std: Optional[Tuple[float, ...]] = None, 64 | resize_longest_max: bool = False, 65 | fill_color: int = 0, 66 | ): 67 | mean = mean or OPENAI_DATASET_MEAN 68 | if not isinstance(mean, (list, tuple)): 69 | mean = (mean,) * 3 70 | 71 | std = std or OPENAI_DATASET_STD 72 | if not isinstance(std, (list, tuple)): 73 | std = (std,) * 3 74 | 75 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 76 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 77 | image_size = image_size[0] 78 | 79 | normalize = Normalize(mean=mean, std=std) 80 | if is_train: 81 | return Compose( 82 | [ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ] 88 | ) 89 | else: 90 | if resize_longest_max: 91 | transforms = [ResizeMaxSize(image_size, fill=fill_color)] 92 | else: 93 | transforms = [ 94 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 95 | CenterCrop(image_size), 96 | ] 97 | transforms.extend( 98 | [ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ] 103 | ) 104 | return Compose(transforms) 105 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py: -------------------------------------------------------------------------------- 1 | # Based on EVA, BEIT, timm and DeiT code bases 2 | # https://github.com/baaivision/EVA 3 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 4 | # https://github.com/microsoft/unilm/tree/master/beit 5 | # https://github.com/facebookresearch/deit/ 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | # not tested yet 9 | import math 10 | from transformers import CLIPImageProcessor 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint as checkpoint 16 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 17 | from .eva_clip 
import create_model_and_transforms, get_model_config 18 | import torch 19 | import torchvision 20 | import time 21 | 22 | from ....utils import rank0_print 23 | 24 | 25 | class EvaViTWrapper(nn.Module): 26 | def __init__(self, vision_tower, args, delay_load=False): 27 | super().__init__() 28 | 29 | self.is_loaded = False 30 | self.vision_tower_name = vision_tower 31 | self.pretrained = args.vision_tower_pretrained 32 | self.args = args 33 | 34 | self.select_layer = args.mm_vision_select_layer 35 | if self.select_layer < -1: 36 | self.select_layer += 1 37 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 38 | 39 | self.model_config = get_model_config(self.vision_tower_name) 40 | 41 | if not delay_load: 42 | rank0_print(f"Loading vision tower: {vision_tower}") 43 | self.load_model() 44 | elif getattr(args, "unfreeze_mm_vision_tower", False): 45 | # TODO: better detector is needed. 46 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 47 | self.load_model() 48 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 49 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 50 | self.load_model() 51 | 52 | def load_model(self): 53 | rank0_print(f"Loading: {self.vision_tower_name}") 54 | rank0_print(f"Pretrained: {self.pretrained}") 55 | time_start = time.time() 56 | model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16") 57 | time_end = time.time() 58 | rank0_print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s") 59 | self.device = next(model.parameters()).device 60 | self.dtype = next(model.parameters()).dtype 61 | if self.device.type != "meta": 62 | model = model.to("cuda") 63 | self.vision_tower = model.visual 64 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] 65 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] 66 | self.resize_transform_size = resize_transform.size 67 | self.image_processor = CLIPImageProcessor.from_pretrained( 68 | "openai/clip-vit-large-patch14", 69 | crop_size=resize_transform.size, 70 | size={"shortest_edge": resize_transform.size}, 71 | image_mean=list(normalize_transform.mean), 72 | image_std=list(normalize_transform.std), 73 | ) 74 | rank0_print(f"Loaded image processor: {self.image_processor}") 75 | self.vision_tower.requires_grad_(False) 76 | self.is_loaded = True 77 | 78 | def feature_select(self, image_features): 79 | select_feature_type = self.select_feature 80 | 81 | # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 82 | # select_every_k_layer = len(image_features) // 4 83 | # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1) 84 | # select_feature_type = select_feature_type.replace("slicefour_", "") 85 | # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: 86 | # select_layers = [-1, -4, -7, -10, 6] 87 | # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1) 88 | # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") 89 | # else: 90 | # image_features = image_features[self.select_layer] 91 | 92 | if select_feature_type == "patch": 93 | 
image_features = image_features[:, 1:] 94 | elif select_feature_type == "cls_patch": 95 | image_features = image_features 96 | else: 97 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 98 | return image_features 99 | 100 | def train(self, mode=True): 101 | self.training = mode 102 | 103 | if self.is_loaded: 104 | self.vision_tower.eval() 105 | 106 | def forward(self, images): 107 | if type(images) is list: 108 | image_features = [] 109 | for image in images: 110 | image_feature = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True) 111 | image_feature = self.feature_select(image_feature).to(self.dtype) 112 | image_features.append(image_feature) 113 | else: 114 | image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True) 115 | image_features = self.feature_select(image_features).to(self.dtype) 116 | 117 | return image_features 118 | 119 | @property 120 | def dummy_feature(self): 121 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 122 | 123 | @property 124 | def hidden_size(self): 125 | return self.model_config["vision_cfg"]["width"] 126 | 127 | @property 128 | def num_patches(self): 129 | return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2 130 | 131 | @property 132 | def num_patches_per_side(self): 133 | return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] 134 | 135 | @property 136 | def config(self): 137 | return self.model_config 138 | 139 | @property 140 | def image_size(self): 141 | return self.model_config["vision_cfg"]["image_size"] 142 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .eva_clip_processors import EvaClipImageTrainProcessor 5 | from .eva_vit import EVAEncoderWrapper 6 | from .factory import list_models, add_model_config, get_model_config 7 | 8 | from ....utils import rank0_print 9 | 10 | 11 | class EvaClipVisionTower(nn.Module): 12 | def __init__(self, vision_tower, args, delay_load=False): 13 | super().__init__() 14 | 15 | self.is_loaded = False 16 | self.vision_tower_name = vision_tower 17 | self.vision_tower_pretrained = args.vision_tower_pretrained 18 | self.config = get_model_config(vision_tower) 19 | 20 | if not delay_load: 21 | rank0_print(f"Loading EVA ViT: {self.vision_tower_name}") 22 | self.load_model() 23 | elif getattr(args, "unfreeze_mm_vision_tower", False): 24 | # TODO: better detector is needed. 
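# When the vision tower is trainable (`unfreeze_mm_vision_tower`) or listed in `mm_tunable_parts`,
# the checkpoint is expected to carry `vision_tower` weights, so the encoder is built eagerly here
# even though `delay_load=True` was requested.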
25 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 26 | self.load_model() 27 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 28 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 29 | self.load_model() 30 | else: 31 | self.cfg_only = self.config 32 | 33 | def load_model(self, device_map=None): 34 | rank0_print(f"Pretrained: {self.vision_tower_pretrained}") 35 | self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"]) 36 | self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config) 37 | rank0_print(f"Loaded image processor: {self.image_processor}") 38 | self.vision_tower.requires_grad_(False) 39 | self.is_loaded = True 40 | 41 | def forward(self, images): 42 | if type(images) is list: 43 | image_features = [] 44 | for image in images: 45 | image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype) 49 | 50 | return image_features 51 | 52 | @property 53 | def dtype(self): 54 | return self.vision_tower.dtype 55 | 56 | @property 57 | def device(self): 58 | return self.vision_tower.device 59 | 60 | @property 61 | def hidden_size(self): 62 | return self.config["vision_cfg"]["width"] 63 | 64 | @property 65 | def num_patches(self): 66 | return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2 67 | 68 | @property 69 | def num_patches_per_side(self): 70 | return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"] 71 | 72 | @property 73 | def image_size(self): 74 | return self.config["vision_cfg"]["image_size"] 75 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP 3 | """ 4 | 5 | from torchvision import transforms 6 | from torchvision.transforms.functional import InterpolationMode 7 | from transformers.image_processing_utils import BatchFeature 8 | from PIL import Image 9 | from transformers.image_transforms import convert_to_rgb 10 | 11 | 12 | class BaseProcessor: 13 | def __init__(self): 14 | self.transform = lambda x: x 15 | return 16 | 17 | def __call__(self, item): 18 | return self.transform(item) 19 | 20 | 21 | class EvaClipImageBaseProcessor(BaseProcessor): 22 | def __init__(self, mean=None, std=None): 23 | self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean 24 | self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std 25 | 26 | self.normalize = transforms.Normalize(self.mean, self.std) 27 | 28 | @property 29 | def image_mean(self): 30 | return self.mean 31 | 32 | 33 | class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor): 34 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 35 | super().__init__(mean=mean, std=std) 36 | 37 | self.transform = transforms.Compose( 38 | [ 39 | convert_to_rgb, 40 | transforms.Resize( 41 | image_size, 42 | interpolation=InterpolationMode.BICUBIC, 43 | ), 44 | transforms.CenterCrop(image_size), 45 | transforms.ToTensor(), 46 | 
self.normalize, 47 | ] 48 | ) 49 | 50 | self.image_size = image_size 51 | 52 | def preprocess(self, images, return_tensors): 53 | if isinstance(images, Image.Image): 54 | images = [images] 55 | else: 56 | assert isinstance(images, list) 57 | 58 | transformed_images = [self.transform(image).numpy() for image in images] 59 | data = {"pixel_values": transformed_images} 60 | 61 | return BatchFeature(data=data, tensor_type=return_tensors) 62 | 63 | def __call__(self, item): 64 | return self.transform(item) 65 | 66 | @property 67 | def crop_size(self): 68 | return {"height": self.image_size, "width": self.image_size} 69 | 70 | @property 71 | def size(self): 72 | return {"shortest_edge": self.image_size} 73 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Optional, Tuple, Union, Dict, Any 9 | import torch 10 | 11 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 12 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 13 | 14 | 15 | def _natural_key(string_): 16 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 17 | 18 | 19 | def _rescan_model_configs(): 20 | global _MODEL_CONFIGS 21 | 22 | config_ext = (".json",) 23 | config_files = [] 24 | for config_path in _MODEL_CONFIG_PATHS: 25 | if config_path.is_file() and config_path.suffix in config_ext: 26 | config_files.append(config_path) 27 | elif config_path.is_dir(): 28 | for ext in config_ext: 29 | config_files.extend(config_path.glob(f"*{ext}")) 30 | 31 | for cf in config_files: 32 | with open(cf, "r", encoding="utf8") as f: 33 | model_cfg = json.load(f) 34 | if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): 35 | _MODEL_CONFIGS[cf.stem] = model_cfg 36 | 37 | _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) 38 | 39 | 40 | _rescan_model_configs() # initial populate of model config registry 41 | 42 | 43 | def list_models(): 44 | """enumerate available model architectures based on config files""" 45 | return list(_MODEL_CONFIGS.keys()) 46 | 47 | 48 | def add_model_config(path): 49 | """add model config path or file and update registry""" 50 | if not isinstance(path, Path): 51 | path = Path(path) 52 | _MODEL_CONFIG_PATHS.append(path) 53 | _rescan_model_configs() 54 | 55 | 56 | def get_model_config(model_name): 57 | if model_name in _MODEL_CONFIGS: 58 | return deepcopy(_MODEL_CONFIGS[model_name]) 59 | else: 60 | return None 61 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | 
"fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 
49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/hf_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor 5 | from ...utils import rank0_print 6 | 7 | 8 | class HFVisionTower(nn.Module): 9 | def __init__(self, vision_tower, args, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | 14 | self.vision_tower_name = vision_tower.replace("hf:", "", 1) 15 | self.select_layer = args.mm_vision_select_layer 16 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 17 | 18 | if not delay_load: 19 | self.load_model() 20 | else: 21 | self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) 22 | 23 | def load_model(self): 24 | try: 25 | self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name) 26 | except Exception as e: 27 | if "448" in self.vision_tower_name: 28 | image_size = 448 29 | # use image processor with conig 30 | self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size) 31 | else: 32 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 33 | rank0_print(f"Loaded image processor: 
{self.image_processor}") 34 | self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda") 35 | self.device = self.vision_tower.device 36 | self.dtype = self.vision_tower.dtype 37 | self.config = self.vision_tower.config 38 | 39 | if hasattr(self.vision_tower, "vision_model"): 40 | self.vision_tower = self.vision_tower.vision_model 41 | self.vision_tower.requires_grad_(False) 42 | # self.vision_tower.eval() 43 | self.is_loaded = True 44 | 45 | def feature_select(self, image_forward_outs): 46 | select_feature_type = self.select_feature 47 | 48 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 49 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4 50 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) 51 | select_feature_type = select_feature_type.replace("slicefour_", "") 52 | else: 53 | image_features = image_forward_outs.hidden_states[self.select_layer] 54 | 55 | if select_feature_type == "patch": 56 | image_features = image_features[:, 1:] 57 | elif select_feature_type == "cls_patch": 58 | image_features = image_features 59 | else: 60 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 61 | return image_features 62 | 63 | def forward(self, images): 64 | if type(images) is list: 65 | image_features = [] 66 | for image in images: 67 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 68 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 69 | image_features.append(image_feature) 70 | else: 71 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 72 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 73 | 74 | return image_features 75 | 76 | @property 77 | def dummy_feature(self): 78 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 79 | 80 | # @property 81 | # def dtype(self): 82 | # return self.vision_tower.dtype 83 | 84 | # @property 85 | # def device(self): 86 | # return self.vision_tower.device 87 | 88 | @property 89 | def hidden_size(self): 90 | try: 91 | _hidden_size = self.config.hidden_size 92 | except: 93 | _hidden_size = self.config.vision_config.hidden_size 94 | if "slicefour" in self.select_feature: 95 | _hidden_size *= 4 96 | return _hidden_size 97 | 98 | @property 99 | def num_patches(self): 100 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2 101 | if "cls_patch" in self.select_feature: 102 | _num_patches += 1 103 | return _num_patches 104 | 105 | @property 106 | def num_patches_per_side(self): 107 | return self.config.image_size // self.config.patch_size 108 | 109 | @property 110 | def image_size(self): 111 | return self.config.image_size 112 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/imagebind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPImageProcessor 5 | 6 | try: 7 | from imagebind.models import imagebind_model 8 | from imagebind.models.imagebind_model import ModalityType 9 | from imagebind.data import load_and_transform_audio_data 10 | except ImportError: 11 | pass 12 | 13 | 14 | 
class ImageBindWrapper(nn.Module): 15 | def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): 16 | super().__init__() 17 | 18 | self.is_loaded = False 19 | 20 | self.vision_tower_name = vision_tower 21 | self.select_layer = select_layer 22 | self.select_feature = select_feature 23 | 24 | if not delay_load: 25 | self.load_model() 26 | 27 | def load_model(self): 28 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 29 | self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) 30 | for p in self.vision_tower.parameters(): 31 | p.requires_grad = False 32 | self.vision_tower.eval() 33 | self.is_loaded = True 34 | 35 | def train(self, mode=True): 36 | self.training = mode 37 | 38 | if self.is_loaded: 39 | self.vision_tower.eval() 40 | 41 | @torch.no_grad() 42 | def forward(self, x): 43 | if type(x) == dict: 44 | if x["audios"] is not None: 45 | inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} 46 | embeddings = self.vision_tower(inputs) 47 | audio_embedding = embeddings[ModalityType.AUDIO] 48 | return audio_embedding.unsqueeze(1) 49 | else: 50 | inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} 51 | embeddings = self.vision_tower(inputs) 52 | vision_embedding = embeddings[ModalityType.VISION] 53 | if vision_embedding.ndim == 2: 54 | return vision_embedding.unsqueeze(1) 55 | if vision_embedding.shape[1] == 257: 56 | return vision_embedding[:, 1:] 57 | raise ValueError(f"Unexpected shape: {vision_embedding.shape}") 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.modality_preprocessors.vision.cls_token.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.modality_preprocessors.vision.cls_token.device 70 | 71 | @property 72 | def hidden_size(self): 73 | return 1024 74 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/open_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import CLIPImageProcessor 4 | from ...utils import rank0_print 5 | 6 | try: 7 | import open_clip 8 | import torchvision 9 | from open_clip.transformer import _expand_token 10 | except ImportError: 11 | print("OpenCLIP not installed") 12 | open_clip = None 13 | 14 | HIDDEN_SIZE_DICT = { 15 | "ViT-H-14-378-quickgelu": 1280, 16 | } 17 | 18 | 19 | class OpenCLIPVisionTower(nn.Module): 20 | def __init__(self, vision_tower, args, delay_load=False): 21 | super().__init__() 22 | 23 | self.is_loaded = False 24 | self.model_name = vision_tower.replace("open_clip_hub:", "") 25 | self.pretrained = args.vision_tower_pretrained 26 | self.select_layer = args.mm_vision_select_layer 27 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 28 | 29 | if not delay_load: 30 | rank0_print(f"Loading vision tower: {vision_tower}") 31 | self.load_model() 32 | elif getattr(args, "unfreeze_mm_vision_tower", False): 33 | # TODO: better detector is needed. 
34 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 35 | self.load_model() 36 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 37 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 38 | self.load_model() 39 | 40 | def load_model(self, device_map="auto"): 41 | rank0_print(f"Loading OpenCLIP model: {self.model_name}") 42 | rank0_print(f"Pretrained: {self.pretrained}") 43 | vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda") 44 | 45 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] 46 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] 47 | self.resize_transform_size = resize_transform.size # 224 or 384 48 | self.patch_size = vision_tower.visual.conv1.kernel_size[0] # 14 or 16 49 | 50 | self.image_processor = CLIPImageProcessor.from_pretrained( 51 | "openai/clip-vit-large-patch14", 52 | crop_size=resize_transform.size, 53 | size={"shortest_edge": resize_transform.size}, 54 | image_mean=list(normalize_transform.mean), 55 | image_std=list(normalize_transform.std), 56 | ) 57 | rank0_print(f"Loaded image processor: {self.image_processor}") 58 | self.vision_tower = vision_tower.visual 59 | self.vision_tower.requires_grad_(False) 60 | 61 | self.is_loaded = True 62 | 63 | def feature_select(self, image_forward_outs): 64 | image_features = image_forward_outs[self.select_layer] 65 | if self.select_feature == "patch": 66 | image_features = image_features[:, 1:] 67 | elif self.select_feature == "cls_patch": 68 | image_features = image_features 69 | elif self.select_feature == "conv_flatten": 70 | image_features = image_features.flatten(2).transpose(1, 2) 71 | else: 72 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 73 | return image_features 74 | 75 | def forward_visual(self, x, output_hidden_states=False): 76 | if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"): 77 | return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer)) 78 | else: 79 | 80 | def forward_openclip(self, x: torch.Tensor): 81 | features = [] 82 | x = self.conv1(x) # shape = [*, width, grid, grid] 83 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 84 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 85 | 86 | # class embeddings and positional embeddings 87 | x = torch.cat( 88 | [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], 89 | dim=1, 90 | ) 91 | # shape = [*, grid ** 2 + 1, width] 92 | x = x + self.positional_embedding.to(x.dtype) 93 | 94 | x = self.patch_dropout(x) 95 | x = self.ln_pre(x) 96 | 97 | x = x.permute(1, 0, 2) # NLD -> LND 98 | for r in self.transformer.resblocks: 99 | x = r(x, attn_mask=None) 100 | features.append(x) 101 | return features 102 | 103 | return forward_openclip(self.vision_tower, x) 104 | 105 | def forward(self, images): 106 | if type(images) is list: 107 | image_features = [] 108 | for image in images: 109 | image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True) 110 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 111 | image_features.append(image_feature) 112 | else: 113 | image_forward_outs = 
self.forward_visual(images.to(self.dtype), output_hidden_states=True) 114 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 115 | 116 | return image_features 117 | 118 | @property 119 | def dummy_feature(self): 120 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 121 | 122 | @property 123 | def dtype(self): 124 | if hasattr(self.vision_tower, "conv1"): 125 | return self.vision_tower.conv1.weight.dtype 126 | if hasattr(self.vision_tower, "trunk"): 127 | return self.vision_tower.trunk.patch_embed.proj.weight.dtype 128 | raise NotImplementedError 129 | 130 | @property 131 | def device(self): 132 | if hasattr(self.vision_tower, "conv1"): 133 | return self.vision_tower.conv1.weight.device 134 | if hasattr(self.vision_tower, "trunk"): 135 | return self.vision_tower.trunk.patch_embed.proj.weight.device 136 | raise NotImplementedError 137 | 138 | @property 139 | def config(self): 140 | return None 141 | 142 | @property 143 | def hidden_size(self): 144 | if self.model_name in HIDDEN_SIZE_DICT: 145 | return HIDDEN_SIZE_DICT[self.model_name] 146 | else: 147 | raise NotImplementedError 148 | 149 | @property 150 | def num_patches(self): 151 | image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0] 152 | _num_patches = (image_size // self.patch_size) ** 2 153 | if "cls_patch" in self.select_feature: 154 | _num_patches += 1 155 | return _num_patches 156 | 157 | @property 158 | def image_size(self): 159 | return self.resize_transform_size 160 | 161 | @property 162 | def num_patches_per_side(self): 163 | return self.resize_transform_size // self.patch_size 164 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | from .pooler_projector import PoolerProjector 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": "identity"} 18 | 19 | 20 | class SimpleResBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.pre_norm = nn.LayerNorm(channels) 24 | 25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) 26 | 27 | def forward(self, x): 28 | x = self.pre_norm(x) 29 | return x + self.proj(x) 30 | 31 | 32 | def build_vision_projector(config, delay_load=False, **kwargs): 33 | projector_type = getattr(config, "mm_projector_type", "linear") 34 | 35 | if projector_type == "linear": 36 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | 38 | if projector_type == "pooler": 39 | return PoolerProjector(config, kwargs["vision_cfg"]) 40 | 41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) 42 | if mlp_gelu_match: 43 | mlp_depth = int(mlp_gelu_match.group(1)) 44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 45 | for _ in range(1, mlp_depth): 46 | modules.append(nn.GELU()) 47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 48 | return nn.Sequential(*modules) 49 | 50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) 51 | if mlp_gelu_resnet_match: 52 | mlp_depth = int(mlp_gelu_resnet_match.group(1)) 53 | res_depth = 
int(mlp_gelu_resnet_match.group(2)) 54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 55 | for _ in range(1, mlp_depth): 56 | modules.append(nn.GELU()) 57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 58 | for _ in range(res_depth): 59 | modules.append(SimpleResBlock(config.hidden_size)) 60 | return nn.Sequential(*modules) 61 | 62 | if projector_type == "identity": 63 | return IdentityMap() 64 | 65 | raise ValueError(f"Unknown projector type: {projector_type}") 66 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/pooler_projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | from transformers.models.clip.modeling_clip import CLIPVisionModel 7 | 8 | 9 | class PoolerProjector(nn.Module): 10 | def __init__(self, config, vision_cfg): 11 | super().__init__() 12 | self._config = config 13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size 14 | 15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) 16 | 17 | self.proj = nn.Sequential( 18 | nn.GELU(), 19 | nn.Linear(config.hidden_size, config.hidden_size), 20 | ) 21 | 22 | def forward(self, x, *args, **kwargs): 23 | height = width = self.hw 24 | assert height * width == x.shape[1] 25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) 26 | x = self.conv_pool(x) 27 | x = x.flatten(2).transpose(1, 2) 28 | x = self.proj(x) 29 | return x 30 | 31 | @property 32 | def config(self): 33 | return {"mm_projector_type": "pooler"} 34 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .masked_drop import MaskedDrop 4 | from .spatial_pool import SpatialPool 5 | from .perceiver import PerceiverResampler 6 | from .qformer import Qformer 7 | 8 | 9 | class IdentityMap(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | @property 17 | def config(self): 18 | return {"mm_resampler_type": None} 19 | 20 | 21 | def build_vision_resampler(model_args, delay_load=False, **kwargs): 22 | resampler_type = getattr(model_args, "mm_resampler_type", None) 23 | if resampler_type == "masked_drop": 24 | return MaskedDrop(model_args) 25 | elif resampler_type == "spatial_pool": 26 | return SpatialPool(model_args, **kwargs) 27 | elif resampler_type == "perceiver": 28 | return PerceiverResampler(model_args, **kwargs) 29 | elif resampler_type == "qformer": 30 | return Qformer(model_args, **kwargs) 31 | elif resampler_type is None: 32 | return IdentityMap() 33 | 34 | raise ValueError(f"Unknown resampler type: {resampler_type}") 35 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/masked_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import random 5 | 6 | 7 | class MaskedDrop(nn.Module): 8 | def __init__(self, model_args): 9 | super().__init__() 10 | 11 | self.mode = model_args.mm_mask_drop_mode 12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage 13 | self.ratio = model_args.mm_mask_drop_ratio 14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper 15 | 
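        # `ratio` drives the "fixed" mode (keep ratio * N tokens); the two bounds drive
        # the "range" mode, where forward() samples a keep fraction per image from
        # [ratio_lower, ratio_upper].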
self.ratio_lower = model_args.mm_mask_drop_ratio_lower 16 | 17 | def forward(self, image_features, *args, **kwargs): 18 | 19 | if not self.training: 20 | return image_features 21 | 22 | if self.skip_percentage > random.random(): 23 | return image_features 24 | 25 | masked_features = [] 26 | 27 | for image_feature in image_features: 28 | num_tokens = image_feature.shape[0] 29 | if self.mode == "fixed": 30 | num_keep = int(num_tokens * self.ratio) 31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) 32 | elif self.mode == "range": 33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) 34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) 35 | elif self.mode == "cls_only": 36 | masked_features.append(image_feature[0:1]) 37 | else: 38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}") 39 | 40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): 41 | masked_features = torch.stack(masked_features, dim=0) 42 | 43 | return masked_features 44 | 45 | @property 46 | def config(self): 47 | return { 48 | "mm_resampler_type": "masked_drop", 49 | "mm_mask_drop_mode": self.mode, 50 | "mm_mask_drop_skip_percentage": self.skip_percentage, 51 | "mm_mask_drop_ratio": self.ratio, 52 | "mm_mask_drop_ratio_upper": self.ratio_upper, 53 | "mm_mask_drop_ratio_lower": self.ratio_lower, 54 | } 55 | 56 | def random_masking(self, x, len_keep): 57 | """ 58 | Perform per-sample random masking by per-sample shuffling. 59 | Per-sample shuffling is done by argsort random noise. 60 | x: [N, L, D], sequence 61 | """ 62 | N, L, D = x.shape # batch, length, dim 63 | 64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 65 | 66 | # sort noise for each sample 67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 68 | ids_restore = torch.argsort(ids_shuffle, dim=1) 69 | 70 | # keep the first subset 71 | ids_keep = ids_shuffle[:, :len_keep] 72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 73 | 74 | # generate the binary mask: 0 is keep, 1 is remove 75 | mask = torch.ones([N, L], device=x.device) 76 | mask[:, :len_keep] = 0 77 | # unshuffle to get the binary mask 78 | mask = torch.gather(mask, dim=1, index=ids_restore) 79 | 80 | return x_masked, mask, ids_restore 81 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/perceiver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | try: 9 | from einops_exts import rearrange_many 10 | except: 11 | pass 12 | 13 | from torch import einsum, nn 14 | 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | 20 | def FeedForward(dim, mult=4): 21 | inner_dim = int(dim * mult) 22 | return nn.Sequential( 23 | nn.LayerNorm(dim), 24 | nn.Linear(dim, inner_dim, bias=False), 25 | nn.GELU(), 26 | nn.Linear(inner_dim, dim, bias=False), 27 | ) 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.heads = heads 35 | inner_dim = dim_head * heads 36 | 37 | self.norm_media = nn.LayerNorm(dim) 38 | self.norm_latents = nn.LayerNorm(dim) 39 | 40 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 41 | 
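        # Keys and values are computed over the concatenation of media features and
        # latents (see forward()), so a single fused projection producing 2 * inner_dim
        # channels is used below.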
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 42 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 43 | 44 | def forward(self, x, latents): 45 | """ 46 | Args: 47 | x (torch.Tensor): image features 48 | shape (b, T, n1, D) 49 | latent (torch.Tensor): latent features 50 | shape (b, T, n2, D) 51 | """ 52 | x = self.norm_media(x) 53 | latents = self.norm_latents(latents) 54 | 55 | h = self.heads 56 | 57 | q = self.to_q(latents) 58 | kv_input = torch.cat((x, latents), dim=-2) 59 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 60 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 61 | q = q * self.scale 62 | 63 | # attention 64 | sim = einsum("... i d, ... j d -> ... i j", q, k) 65 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 66 | attn = sim.softmax(dim=-1) 67 | 68 | out = einsum("... i j, ... j d -> ... i d", attn, v) 69 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 70 | return self.to_out(out) 71 | 72 | 73 | class PerceiverResamplerModule(nn.Module): 74 | def __init__( 75 | self, 76 | *, 77 | dim, 78 | depth=6, 79 | dim_head=64, 80 | heads=8, 81 | num_latents=64, 82 | max_num_media=None, 83 | max_num_frames=None, 84 | ff_mult=4, 85 | ): 86 | super().__init__() 87 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 88 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None 89 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None 90 | 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append( 94 | nn.ModuleList( 95 | [ 96 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 97 | FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), 98 | ] 99 | ) 100 | ) 101 | 102 | self.norm = nn.LayerNorm(dim) 103 | 104 | def forward(self, x): 105 | """ 106 | Args: 107 | x (torch.Tensor): image features 108 | shape (b, T, F, v, D) 109 | Returns: 110 | shape (b, T, n, D) where n is self.num_latents 111 | """ 112 | b, T, F, v = x.shape[:4] 113 | 114 | # frame and media time embeddings 115 | if exists(self.frame_embs): 116 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 117 | x = x + frame_embs 118 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions 119 | if exists(self.media_time_embs): 120 | x = x + self.media_time_embs[:T] 121 | 122 | # blocks 123 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 124 | for attn, ff in self.layers: 125 | latents = attn(x, latents) + latents 126 | latents = ff(latents) + latents 127 | return self.norm(latents) 128 | 129 | 130 | class PerceiverResampler(nn.Module): 131 | def __init__(self, model_args, vision_tower): 132 | super().__init__() 133 | 134 | self.depth = model_args.mm_perceiver_depth 135 | self.num_latents = model_args.mm_perceiver_latents 136 | self.ff_mult = model_args.mm_perceiver_ff_mult 137 | self.pretrained = model_args.mm_perceiver_pretrained 138 | 139 | self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) 140 | 141 | if self.pretrained is not None: 142 | self.load_state_dict(torch.load(self.pretrained)) 143 | 144 | def forward(self, image_features, *args, **kwargs): 145 | return self.perceiver(image_features[:, None, None]).squeeze(1) 146 | 147 | @property 148 | def config(self): 149 | return { 150 | "mm_resampler_type": "perceiver", 151 | "mm_perceiver_depth": self.depth, 152 | 
"mm_perceiver_latents": self.num_latents, 153 | "mm_perceiver_ff_mult": self.ff_mult, 154 | "mm_perceiver_pretrained": self.pretrained, 155 | } 156 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/spatial_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class SpatialPool(nn.Module): 7 | def __init__(self, model_args, vision_tower): 8 | super().__init__() 9 | 10 | self.mode = model_args.mm_spatial_pool_mode 11 | self.stride = model_args.mm_spatial_pool_stride 12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) 13 | 14 | if self.mode == "average": 15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) 16 | elif self.mode == "max": 17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) 18 | elif self.mode == "conv": 19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) 20 | else: 21 | raise ValueError(f"Unknown pooling mode: {self.pool}.") 22 | 23 | def forward(self, image_features, images, *args, **kwargs): 24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) 25 | ori_H = int(ori_W * images.shape[2] // images.shape[3]) 26 | 27 | B, _, F = image_features.shape 28 | 29 | image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2) 30 | image_features_spatial_pool = self.pool(image_features_spatial) 31 | 32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() 33 | 34 | @property 35 | def config(self): 36 | return { 37 | "mm_resampler_type": "spatial_pool", 38 | "mm_spatial_pool_stride": self.stride, 39 | "mm_spatial_pool_mode": self.mode, 40 | "mm_spatial_pool_out_channels": self.out_channels, 41 | } 42 | 43 | @property 44 | def hidden_size(self): 45 | return self.out_channels 46 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | import numpy as np 7 | 8 | import requests 9 | 10 | from .constants import LOGDIR 11 | 12 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. 
PLEASE REGENERATE OR REFRESH THIS PAGE.**" 13 | moderation_msg = "I am sorry. Your input may violate our content moderation guidelines. Please avoid using harmful or offensive content." 14 | 15 | handler = None 16 | 17 | import torch.distributed as dist 18 | 19 | try: 20 | import av 21 | from decord import VideoReader, cpu 22 | except ImportError: 23 | print("Please install pyav to use video processing functions.") 24 | 25 | def process_video_with_decord(video_file, data_args): 26 | vr = VideoReader(video_file, ctx=cpu(0), num_threads=1) 27 | total_frame_num = len(vr) 28 | avg_fps = round(vr.get_avg_fps() / data_args.video_fps) 29 | frame_idx = [i for i in range(0, total_frame_num, avg_fps)] 30 | 31 | if data_args.frames_upbound > 0: 32 | if len(frame_idx) > data_args.frames_upbound: 33 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int) 34 | frame_idx = uniform_sampled_frames.tolist() 35 | 36 | video = vr.get_batch(frame_idx).asnumpy() 37 | # https://github.com/dmlc/decord/issues/208 38 | vr.seek(0) 39 | return video 40 | 41 | def process_video_with_pyav(video_file, data_args): 42 | container = av.open(video_file) 43 | # !!! This is the only difference. Using auto threading 44 | container.streams.video[0].thread_type = "AUTO" 45 | 46 | video_frames = [] 47 | for packet in container.demux(): 48 | if packet.stream.type == 'video': 49 | for frame in packet.decode(): 50 | video_frames.append(frame) 51 | total_frame_num = len(video_frames) 52 | video_time = video_frames[-1].time 53 | avg_fps = round(total_frame_num / video_time / data_args.video_fps) 54 | frame_idx = [i for i in range(0, total_frame_num, avg_fps)] 55 | 56 | if data_args.frames_upbound > 0: 57 | if len(frame_idx) > data_args.frames_upbound: 58 | uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int) 59 | frame_idx = uniform_sampled_frames.tolist() 60 | 61 | 62 | frames = [video_frames[i] for i in frame_idx] 63 | return np.stack([x.to_ndarray(format="rgb24") for x in frames]) 64 | 65 | 66 | def rank0_print(*args): 67 | if dist.is_initialized(): 68 | if dist.get_rank() == 0: 69 | print(f"Rank {dist.get_rank()}: ", *args) 70 | else: 71 | print(*args) 72 | 73 | 74 | def rank_print(*args): 75 | if dist.is_initialized(): 76 | print(f"Rank {dist.get_rank()}: ", *args) 77 | else: 78 | print(*args) 79 | 80 | def build_logger(logger_name, logger_filename): 81 | global handler 82 | 83 | formatter = logging.Formatter( 84 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 85 | datefmt="%Y-%m-%d %H:%M:%S", 86 | ) 87 | 88 | # Set the format of root handlers 89 | if not logging.getLogger().handlers: 90 | logging.basicConfig(level=logging.INFO) 91 | logging.getLogger().handlers[0].setFormatter(formatter) 92 | 93 | # Redirect stdout and stderr to loggers 94 | stdout_logger = logging.getLogger("stdout") 95 | stdout_logger.setLevel(logging.INFO) 96 | sl = StreamToLogger(stdout_logger, logging.INFO) 97 | sys.stdout = sl 98 | 99 | stderr_logger = logging.getLogger("stderr") 100 | stderr_logger.setLevel(logging.ERROR) 101 | sl = StreamToLogger(stderr_logger, logging.ERROR) 102 | sys.stderr = sl 103 | 104 | # Get logger 105 | logger = logging.getLogger(logger_name) 106 | logger.setLevel(logging.INFO) 107 | 108 | # Add a file handler for all loggers 109 | if handler is None: 110 | os.makedirs(LOGDIR, exist_ok=True) 111 | filename = os.path.join(LOGDIR, logger_filename) 112 | handler = logging.handlers.TimedRotatingFileHandler(filename, when="D", 
utc=True) 113 | handler.setFormatter(formatter) 114 | 115 | for name, item in logging.root.manager.loggerDict.items(): 116 | if isinstance(item, logging.Logger): 117 | item.addHandler(handler) 118 | 119 | return logger 120 | 121 | 122 | class StreamToLogger(object): 123 | """ 124 | Fake file-like stream object that redirects writes to a logger instance. 125 | """ 126 | 127 | def __init__(self, logger, log_level=logging.INFO): 128 | self.terminal = sys.stdout 129 | self.logger = logger 130 | self.log_level = log_level 131 | self.linebuf = "" 132 | 133 | def __getattr__(self, attr): 134 | return getattr(self.terminal, attr) 135 | 136 | def write(self, buf): 137 | temp_linebuf = self.linebuf + buf 138 | self.linebuf = "" 139 | for line in temp_linebuf.splitlines(True): 140 | # From the io.TextIOWrapper docs: 141 | # On output, if newline is None, any '\n' characters written 142 | # are translated to the system default line separator. 143 | # By default sys.stdout.write() expects '\n' newlines and then 144 | # translates them so this is still cross platform. 145 | if line[-1] == "\n": 146 | self.logger.log(self.log_level, line.rstrip()) 147 | else: 148 | self.linebuf += line 149 | 150 | def flush(self): 151 | if self.linebuf != "": 152 | self.logger.log(self.log_level, self.linebuf.rstrip()) 153 | self.linebuf = "" 154 | 155 | 156 | def disable_torch_init(): 157 | """ 158 | Disable the redundant torch default initialization to accelerate model creation. 159 | """ 160 | import torch 161 | 162 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 163 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 164 | 165 | 166 | def violates_moderation(text): 167 | """ 168 | Check whether the text violates OpenAI moderation API. 169 | """ 170 | url = "https://api.openai.com/v1/moderations" 171 | headers = {"Content-Type": "application/json", "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 172 | text = text.replace("\n", "") 173 | data = "{" + '"input": ' + f'"{text}"' + "}" 174 | data = data.encode("utf-8") 175 | try: 176 | ret = requests.post(url, headers=headers, data=data, timeout=5) 177 | flagged = ret.json()["results"][0]["flagged"] 178 | except requests.exceptions.RequestException as e: 179 | print(f"######################### Moderation Error: {e} #########################") 180 | flagged = False 181 | except KeyError as e: 182 | print(f"######################### Moderation Error: {e} #########################") 183 | flagged = False 184 | 185 | return flagged 186 | 187 | 188 | def pretty_print_semaphore(semaphore): 189 | if semaphore is None: 190 | return "None" 191 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 192 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | huggingface_hub 2 | pillow 3 | transformers --------------------------------------------------------------------------------
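A minimal end-to-end sketch of how the projector and resampler builders above compose with a vision tower's patch features. The config fields, feature shapes, and absolute import paths are assumptions for illustration, not guarantees about this repository; note that importing the resampler builder also pulls in einops via perceiver.py.

import torch
from types import SimpleNamespace

from llava.model.multimodal_projector.builder import build_vision_projector
from llava.model.multimodal_resampler.builder import build_vision_resampler

# Hypothetical setup: 1024-d vision features feeding a 4096-d language model.
config = SimpleNamespace(mm_hidden_size=1024, hidden_size=4096, mm_projector_type="mlp2x_gelu")
model_args = SimpleNamespace(mm_resampler_type=None)   # None -> IdentityMap (no resampling)

projector = build_vision_projector(config)     # Linear -> GELU -> Linear
resampler = build_vision_resampler(model_args)

image_features = torch.randn(2, 576, 1024)     # (batch, num_patches, mm_hidden_size)
tokens = projector(resampler(image_features))  # -> (2, 576, 4096), ready for the LM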