├── .dockerignore ├── .editorconfig ├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── cog.yaml ├── docs ├── data.png ├── mmrlhf_title.webp ├── reward_model.png └── teaser.png ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── cal_performance_mmreweard_bench.py │ ├── eval_mm_reward_bench.py │ ├── evaluate_interleave.py │ └── model_vqa.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_gemma.py │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ ├── llava_mixtral.py │ │ ├── llava_mpt.py │ │ ├── llava_qwen.py │ │ ├── llava_qwen_moe.py │ │ └── modeling_llama.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── dev_eva_clip │ │ │ ├── eva_clip │ │ │ │ ├── __init__.py │ │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ │ ├── constants.py │ │ │ │ ├── eva_vit_model.py │ │ │ │ ├── factory.py │ │ │ │ ├── hf_configs.py │ │ │ │ ├── hf_model.py │ │ │ │ ├── loss.py │ │ │ │ ├── model.py │ │ │ │ ├── model_configs │ │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ │ │ ├── modified_resnet.py │ │ │ │ ├── openai.py │ │ │ │ ├── pretrained.py │ │ │ │ ├── rope.py │ │ │ │ ├── timm_model.py │ │ │ │ ├── tokenizer.py │ │ │ │ ├── transform.py │ │ │ │ ├── transformer.py │ │ │ │ └── utils.py │ │ │ └── eva_vit.py │ │ ├── eva_clip │ │ │ ├── eva_clip_encoder.py │ │ │ ├── eva_clip_processors.py │ │ │ ├── eva_vit.py │ │ │ ├── factory.py │ │ │ └── model_configs │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ ├── EVA02-CLIP-L-14-336.json │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ └── Internal-EVA02-CLIP-10B-14.json │ │ ├── hf_vision.py │ │ ├── imagebind.py │ │ ├── open_clip_encoder.py │ │ └── siglip_encoder.py │ ├── multimodal_projector │ │ ├── builder.py │ │ └── pooler_projector.py │ ├── multimodal_resampler │ │ ├── builder.py │ │ ├── masked_drop.py │ │ ├── perceiver.py │ │ ├── qformer.py │ │ └── spatial_pool.py │ └── utils.py ├── rm_process │ ├── download_pairwise_data.py │ ├── model_generation.py │ └── reward_generate.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_multi_image.py │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── sglang_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── llava_trainer_eval.py │ ├── train.py │ ├── train_dpo.py │ └── train_mem.py └── utils.py ├── predict.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── convert_to_swift.py ├── swift.sh ├── train │ ├── README.md │ ├── critic_reward_7b.sh │ ├── dpo_ov7b.sh │ └── generate_ref_logits.sh ├── zero2.json ├── 
zero2_fused_adamw.json ├── zero2_offload.json ├── zero3.json ├── zero3_offload.json └── zero3pp.json └── trl ├── __init__.py ├── core.py ├── environment ├── __init__.py └── base_environment.py ├── extras ├── __init__.py ├── best_of_n_sampler.py └── dataset_formatting.py ├── import_utils.py ├── models ├── __init__.py ├── modeling_base.py ├── modeling_sd_base.py ├── modeling_value_head.py └── utils.py └── trainer ├── __init__.py ├── base.py ├── ddpo_config.py ├── ddpo_trainer.py ├── dpo_mix_trainer.py ├── dpo_trainer.py ├── iterative_sft_trainer.py ├── model_config.py ├── ppo_config.py ├── ppo_trainer.py ├── reward_config.py ├── reward_trainer.py ├── sft_trainer.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | 19 | # Exclude some weights 20 | /openai 21 | /liuhaotian 22 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | # Unix-style newlines with a newline ending every file 4 | [*] 5 | end_of_line = lf 6 | insert_final_newline = true 7 | trim_trailing_whitespace = true 8 | charset = utf-8 9 | 10 | # 4 space indentation 11 | [*.{py,json}] 12 | indent_style = space 13 | indent_size = 4 14 | 15 | # 2 space indentation 16 | [*.{md,sh,yaml,yml}] 17 | indent_style = space 18 | indent_size = 2 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # https://git-scm.com/docs/gitattributes 2 | 3 | # Set the default behavior, in case people don't have core.autocrlf set. 
4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion 5 | * text=auto 6 | 7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes 8 | # Source files 9 | # ============ 10 | *.pxd text diff=python 11 | *.py text diff=python 12 | *.py3 text diff=python 13 | *.pyw text diff=python 14 | *.pyx text diff=python 15 | *.pyz text diff=python 16 | *.pyi text diff=python 17 | 18 | # Binary files 19 | # ============ 20 | *.db binary 21 | *.p binary 22 | *.pkl binary 23 | *.pickle binary 24 | *.pyc binary export-ignore 25 | *.pyo binary export-ignore 26 | *.pyd binary 27 | 28 | # Jupyter notebook 29 | *.ipynb text eol=lf 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | # *.json 11 | # *.jsonl 12 | 13 | # Data 14 | !**/alpaca-data-conversation.json 15 | # Editor 16 | .idea 17 | *.swp 18 | .vscode 19 | tmp/ 20 | # Other 21 | .DS_Store 22 | wandb 23 | output 24 | scripts/alignment 25 | scripts/interleave 26 | scripts/video 27 | scripts/archived 28 | scripts/mm_rlhf 29 | scripts/video 30 | scripts/reward_model 31 | Safety 32 | scripts/eval.sh 33 | playground 34 | llavavid 35 | 36 | scripts/ref_reward_processed_data_3w_image_margin03.jsonl 37 | checkpoints 38 | upload_file_to_hub.py 39 | project_checkpoints 40 | debug_checkpoints 41 | playground/data 42 | playground/cc3m_llava34b_cap 43 | ckpts* 44 | 45 | .ipynb_checkpoints 46 | chunyl_scripts 47 | *.ipynb 48 | 49 | # DevContainer 50 | !.devcontainer/* 51 | 52 | # Demo 53 | serve_images/ 54 | notebooks/ 55 | logs 56 | scripts/dist_* 57 | logs/ 58 | submissions/ 59 | cn_scripts/ 60 | test.sh 61 | internal_project_checkpoints/ 62 | work_dirs 63 | scripts/i18n/* 64 | playground/.nfs028b000000010add00000001 65 | HIP 66 | playground/.nfs028b0000017bff2c00000012 67 | scripts/qwen 68 | scripts/vicuna 69 | scripts/mistral 70 | scripts/critic_reward_processed_data_3w_image_margin03_chunk0_of_1.jsonl 71 | scripts/baseline_rep 72 | scripts/cn_boli01_hl 73 | scripts/cn_boli01_lf 74 | scripts/cn_lf 75 | scripts/cn_lq 76 | scripts/cn_yg 77 | scripts/cn_yg_hao 78 | scripts/eva_encoder 79 | scripts/i18n 80 | scripts/i18n_higher_res 81 | scripts/multi-images 82 | scratchpad 83 | upload_file_to_hub.py 84 | build/ 85 | playground/*.json 86 | mlx_configs/ 87 | data_processing/ 88 | # demo/ 89 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | 7 | python_version: "3.11" 8 | 9 | python_packages: 10 | - "torch==2.0.1" 11 | - "accelerate==0.21.0" 12 | - "bitsandbytes==0.41.0" 13 | - "deepspeed==0.9.5" 14 | - "einops-exts==0.0.4" 15 | - "einops==0.6.1" 16 | - "gradio==3.35.2" 17 | - "gradio_client==0.2.9" 18 | - "httpx==0.24.0" 19 | - "markdown2==2.4.10" 20 | - "numpy==1.26.0" 21 | - "peft==0.4.0" 22 | - "scikit-learn==1.2.2" 23 | - "sentencepiece==0.1.99" 24 | - "shortuuid==1.0.11" 25 | - "timm==0.6.13" 26 | - "tokenizers==0.13.3" 27 | - "torch==2.0.1" 28 | - "torchvision==0.15.2" 29 | - "transformers==4.31.0" 30 | - "wandb==0.15.12" 31 | - "wavedrom==2.0.3.post3" 32 | - 
"Pygments==2.16.1" 33 | run: 34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget 35 | 36 | # predict.py defines how predictions are run on your model 37 | predict: "predict.py:Predictor" 38 | -------------------------------------------------------------------------------- /docs/data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-YuanQi/MM-RLHF/8876a5953edbc388f1daf5f9c8d2f8adb6623eb9/docs/data.png -------------------------------------------------------------------------------- /docs/mmrlhf_title.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-YuanQi/MM-RLHF/8876a5953edbc388f1daf5f9c8d2f8adb6623eb9/docs/mmrlhf_title.webp -------------------------------------------------------------------------------- /docs/reward_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-YuanQi/MM-RLHF/8876a5953edbc388f1daf5f9c8d2f8adb6623eb9/docs/reward_model.png -------------------------------------------------------------------------------- /docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-YuanQi/MM-RLHF/8876a5953edbc388f1daf5f9c8d2f8adb6623eb9/docs/teaser.png -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /llava/eval/cal_performance_mmreweard_bench.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser(description='Calculate metrics from JSONL data') 6 | parser.add_argument('input_file', type=str, help='Path to input JSONL file') 7 | args = parser.parse_args() 8 | 9 | input_file = args.input_file 10 | 11 | # Category keywords 12 | category_keywords = ["mcq", "long", "short", "safety", "video"] 13 | 14 | # Initialize statistics 15 | category_stats = {keyword: {"accuracy": 0, "acc_plus": 0, "total": 0} for keyword in category_keywords} 16 | overall_stats = {"accuracy": 0, "acc_plus": 0, "total": 0} 17 | 18 | # Store the items belonging to each id 19 | id_to_items = defaultdict(list) 20 | 21 | # Read the data and categorize it 22 | with open(input_file, "r") as infile: 23 | for line in infile: 24 | item = json.loads(line.strip()) 25 | image_path = item.get("image", "") or item.get("video", "") 26 | item_id = item.get("id", "") 27 | id_to_items[item_id].append(item) 28 | 29 | # Assign to the corresponding category 30 | for keyword in category_keywords: 31 | if keyword in image_path: 32 | category_stats[keyword]["total"] += 1 33 | break 34 | 35 | # Update the overall total 36 | overall_stats["total"] += 1 37 | 38 | # Compute accuracy and acc_plus 39 | for item_id, items in id_to_items.items(): 40 | # Check whether a single id satisfies acc+ 41 | all_correct = True 42 | for item in items: 43 | reward_0 = item["rewards"][0] 44 | reward_1 = item["rewards"][1] 45 | correct = reward_0 > reward_1 46 | 47 | # Per-category accuracy 48 | for keyword in category_keywords: 49 | if keyword in item.get("image", "") or keyword in item.get("video", ""): 50 | if correct: 51 | category_stats[keyword]["accuracy"] += 1 52 | else: 53 | all_correct = False 54 | break 55 | 56 | # Overall accuracy 57 | if correct: 58 | overall_stats["accuracy"] += 1 59 | else: 60 | all_correct = False 61 | 62 | # Update the acc+ statistics 63 | if all_correct: 64 | for keyword in category_keywords: 65 | if any(keyword in item.get("image", "") or keyword in item.get("video", "") for item in items): 66 | category_stats[keyword]["acc_plus"] += 1 67 | break 68 | overall_stats["acc_plus"] += 1 69 | 70 | # Compute accuracy and acc+ for each category 71 | for keyword, stats in category_stats.items(): 72 | if stats["total"] > 0: 73 | stats["accuracy"] = stats["accuracy"] / stats["total"] 74 | stats["acc_plus"] = stats["acc_plus"] / len( 75 | [item_id for item_id in id_to_items if any(keyword in (item.get("image", "") + item.get("video", "")) for item in id_to_items[item_id])] 76 | ) 77 | 78 | # Compute overall accuracy and acc+ 79 | if overall_stats["total"] > 0: 80 | overall_stats["accuracy"] = overall_stats["accuracy"] / overall_stats["total"] 81 | overall_stats["acc_plus"] = overall_stats["acc_plus"] / len(id_to_items) 82 | 83 | # Print the results 84 | def print_metrics(): 85 | print("\nCategory-wise Metrics:") 86 | for keyword, stats in category_stats.items(): 87 | print(f"Category: {keyword}") 88 | print(f" Accuracy: {stats['accuracy']:.2f}") 89 | print(f" ACC+: {stats['acc_plus']:.2f}") 90 | print(f" Total: {stats['total']}") 91 | 92 | print("\nOverall Metrics:") 93 | print(f"Overall Accuracy: {overall_stats['accuracy']:.2f}") 94 | print(f"Overall ACC+: {overall_stats['acc_plus']:.2f}") 95 | print(f"Total Items: {overall_stats['total']}")
96 | 97 | # Output 98 | print_metrics() 99 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | AVAILABLE_MODELS = { 4 | "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig", 5 | "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig", 6 | "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig", 7 | "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig", 8 | # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig", 9 | # Add other models as needed 10 | } 11 | 12 | for model_name, model_classes in AVAILABLE_MODELS.items(): 13 | try: 14 | exec(f"from .language_model.{model_name} import {model_classes}") 15 | except Exception as e: 16 | print(f"Failed to import {model_name} from llava.model.language_model.{model_name}. Error: {e}") 17 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava import LlavaLlamaForCausalLM 12 | 13 | 14 | def apply_delta(base_model_path, target_model_path, delta_path): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 33 | 34 | print("Saving target model") 35 | delta.save_pretrained(target_model_path) 36 | delta_tokenizer.save_pretrained(target_model_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--base-model-path", type=str, required=True) 42 | parser.add_argument("--target-model-path", type=str, required=True) 43 | parser.add_argument("--delta-path", type=str, required=True) 44 | 45 | args = parser.parse_args() 46 | 47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 48 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from
llava.model import * 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def consolidate_ckpt(src_path, dst_path): 15 | print("Loading model") 16 | auto_upgrade(src_path) 17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 19 | src_model.save_pretrained(dst_path) 20 | src_tokenizer.save_pretrained(dst_path) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--src", type=str, required=True) 26 | parser.add_argument("--dst", type=str, required=True) 27 | 28 | args = parser.parse_args() 29 | 30 | consolidate_ckpt(args.src, args.dst) 31 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_gemma.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaGemmaConfig(GemmaConfig): 31 | model_type = "llava_gemma" 32 | 33 | 34 | class LlavaGemmaModel(LlavaMetaModel, GemmaModel): 35 | config_class = LlavaGemmaConfig 36 | 37 | def __init__(self, config: GemmaConfig): 38 | super(LlavaGemmaModel, self).__init__(config) 39 | 40 | 41 | class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaGemmaConfig 43 | 44 | def __init__(self, config): 45 | super(GemmaForCausalLM, self).__init__(config) 46 | self.model = LlavaGemmaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | past_key_values: Optional[List[torch.FloatTensor]] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | labels: Optional[torch.LongTensor] = None, 64 | use_cache: Optional[bool] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | images: Optional[torch.FloatTensor] = None, 68 | image_sizes: Optional[List[List[int]]] = None, 69 | return_dict: Optional[bool] = None, 70 | cache_position: Optional[torch.LongTensor] = None, 71 | ) -> Union[Tuple, 
CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) 75 | 76 | return super().forward( 77 | input_ids=input_ids, 78 | attention_mask=attention_mask, 79 | position_ids=position_ids, 80 | past_key_values=past_key_values, 81 | inputs_embeds=inputs_embeds, 82 | labels=labels, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict, 87 | cache_position=cache_position, 88 | ) 89 | 90 | @torch.no_grad() 91 | def generate( 92 | self, 93 | inputs: Optional[torch.Tensor] = None, 94 | images: Optional[torch.Tensor] = None, 95 | image_sizes: Optional[torch.Tensor] = None, 96 | **kwargs, 97 | ) -> Union[GenerateOutput, torch.LongTensor]: 98 | position_ids = kwargs.pop("position_ids", None) 99 | attention_mask = kwargs.pop("attention_mask", None) 100 | if "inputs_embeds" in kwargs: 101 | raise NotImplementedError("`inputs_embeds` is not supported") 102 | 103 | if images is not None: 104 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) 105 | else: 106 | inputs_embeds = self.get_model().embed_tokens(inputs) 107 | 108 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 109 | 110 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 111 | images = kwargs.pop("images", None) 112 | image_sizes = kwargs.pop("image_sizes", None) 113 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 114 | if images is not None: 115 | inputs["images"] = images 116 | if image_sizes is not None: 117 | inputs["image_sizes"] = image_sizes 118 | return inputs 119 | 120 | 121 | AutoConfig.register("llava_gemma", LlavaGemmaConfig) 122 | AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM) 123 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMistralConfig(MistralConfig): 31 | model_type = "llava_mistral" 32 | temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna 33 | max_new_tokens: int = 1024 34 | do_sample: bool = False 35 | top_p: Optional[float] = None 36 | 37 | 38 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 39 | config_class = LlavaMistralConfig 40 | 41 | def __init__(self, config: MistralConfig): 42 | super(LlavaMistralModel, self).__init__(config) 43 | 44 | 45 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMistralConfig 47 | 48 | def __init__(self, config): 49 | super(MistralForCausalLM, self).__init__(config) 50 | 51 | config.model_type = "llava_mistral" 52 | config.rope_scaling = None 53 | 54 | self.model = LlavaMistralModel(config) 55 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 56 | # Initialize weights and apply final processing 57 | self.post_init() 58 | 59 | def get_model(self): 60 | return self.model 61 | 62 | def forward( 63 | self, 64 | input_ids: torch.LongTensor = None, 65 | attention_mask: Optional[torch.Tensor] = None, 66 | position_ids: Optional[torch.LongTensor] = None, 67 | past_key_values: Optional[List[torch.FloatTensor]] = None, 68 | inputs_embeds: Optional[torch.FloatTensor] = None, 69 | labels: Optional[torch.LongTensor] = None, 70 | use_cache: Optional[bool] = None, 71 | output_attentions: Optional[bool] = None, 72 | output_hidden_states: Optional[bool] = None, 73 | images: Optional[torch.FloatTensor] = None, 74 | image_sizes: Optional[List[List[int]]] = None, 75 | return_dict: Optional[bool] = None, 76 | cache_position=None, 77 | ) -> Union[Tuple, CausalLMOutputWithPast]: 78 | 79 | if inputs_embeds is None: 80 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) 81 | 82 | return super().forward( 83 | input_ids=input_ids, 84 | attention_mask=attention_mask, 85 | position_ids=position_ids, 86 | past_key_values=past_key_values, 87 | inputs_embeds=inputs_embeds, 88 | labels=labels, 89 | use_cache=use_cache, 90 | output_attentions=output_attentions, 91 | output_hidden_states=output_hidden_states, 92 | return_dict=return_dict, 93 | ) 94 | 95 | @torch.no_grad() 96 | def generate( 97 | self, 98 | inputs: Optional[torch.Tensor] = None, 99 | images: Optional[torch.Tensor] = None, 100 | image_sizes: Optional[torch.Tensor] = None, 101 | **kwargs, 102 | ) -> Union[GenerateOutput, torch.LongTensor]: 103 | position_ids = kwargs.pop("position_ids", None) 104 | attention_mask = kwargs.pop("attention_mask", None) 105 | if "inputs_embeds" in kwargs: 106 | raise NotImplementedError("`inputs_embeds` is not supported") 107 | 108 | if images is not None: 109 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, 
image_sizes=image_sizes) 110 | else: 111 | inputs_embeds = self.get_model().embed_tokens(inputs) 112 | 113 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 114 | 115 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 116 | images = kwargs.pop("images", None) 117 | image_sizes = kwargs.pop("image_sizes", None) 118 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 119 | if images is not None: 120 | inputs["images"] = images 121 | if image_sizes is not None: 122 | inputs["image_sizes"] = image_sizes 123 | return inputs 124 | 125 | 126 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 127 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 128 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mixtral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMixtralConfig(MixtralConfig): 31 | model_type = "llava_mixtral" 32 | 33 | 34 | class LlavaMixtralModel(LlavaMetaModel, MixtralModel): 35 | config_class = LlavaMixtralConfig 36 | 37 | def __init__(self, config: MixtralConfig): 38 | super(LlavaMixtralModel, self).__init__(config) 39 | 40 | 41 | class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaMixtralConfig 43 | 44 | def __init__(self, config): 45 | super(MixtralForCausalLM, self).__init__(config) 46 | 47 | config.model_type = "llava_mixtral" 48 | config.rope_scaling = None 49 | self.model = LlavaMixtralModel(config) 50 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: 
Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | modalities: Optional[List[str]] = ["image"], 72 | dpo_forward: Optional[bool] = None, 73 | cache_position=None, 74 | ) -> Union[Tuple, CausalLMOutputWithPast]: 75 | 76 | if inputs_embeds is None: 77 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) 78 | 79 | if dpo_forward: 80 | outputs = self.model( 81 | input_ids=input_ids, 82 | attention_mask=attention_mask, 83 | position_ids=position_ids, 84 | past_key_values=past_key_values, 85 | inputs_embeds=inputs_embeds, 86 | use_cache=use_cache, 87 | output_attentions=output_attentions, 88 | output_hidden_states=output_hidden_states, 89 | return_dict=return_dict, 90 | ) 91 | 92 | hidden_states = outputs[0] 93 | logits = self.lm_head(hidden_states) 94 | return logits, labels 95 | 96 | else: 97 | return super().forward( 98 | input_ids=input_ids, 99 | attention_mask=attention_mask, 100 | position_ids=position_ids, 101 | past_key_values=past_key_values, 102 | inputs_embeds=inputs_embeds, 103 | labels=labels, 104 | use_cache=use_cache, 105 | output_attentions=output_attentions, 106 | output_hidden_states=output_hidden_states, 107 | return_dict=return_dict, 108 | ) 109 | 110 | @torch.no_grad() 111 | def generate( 112 | self, 113 | inputs: Optional[torch.Tensor] = None, 114 | images: Optional[torch.Tensor] = None, 115 | image_sizes: Optional[torch.Tensor] = None, 116 | modalities: Optional[List[str]] = ["image"], 117 | **kwargs, 118 | ) -> Union[GenerateOutput, torch.LongTensor]: 119 | position_ids = kwargs.pop("position_ids", None) 120 | attention_mask = kwargs.pop("attention_mask", None) 121 | if "inputs_embeds" in kwargs: 122 | raise NotImplementedError("`inputs_embeds` is not supported") 123 | 124 | if images is not None: 125 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 126 | else: 127 | inputs_embeds = self.get_model().embed_tokens(inputs) 128 | 129 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 130 | 131 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 132 | images = kwargs.pop("images", None) 133 | image_sizes = kwargs.pop("image_sizes", None) 134 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 135 | if images is not None: 136 | inputs["images"] = images 137 | if image_sizes is not None: 138 | inputs["image_sizes"] = image_sizes 139 | return inputs 140 | 141 | 142 | AutoConfig.register("llava_mixtral", LlavaMixtralConfig) 143 | AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM) 144 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig 21 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 22 | 23 | 24 | class LlavaMptConfig(MptConfig): 25 | model_type = "llava_mpt" 26 | 27 | 28 | class LlavaMptModel(LlavaMetaModel, MptModel): 29 | config_class = LlavaMptConfig 30 | 31 | def __init__(self, config: MptConfig): 32 | config.hidden_size = config.d_model 33 | super(LlavaMptModel, self).__init__(config) 34 | 35 | def embed_tokens(self, x): 36 | return self.wte(x) 37 | 38 | 39 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 40 | config_class = LlavaMptConfig 41 | supports_gradient_checkpointing = True 42 | 43 | def __init__(self, config): 44 | super(MptForCausalLM, self).__init__(config) 45 | 46 | config.model_type = "llava_mpt" 47 | config.rope_scaling = None 48 | self.generation_config = GenerationConfig( 49 | temperature=0.0, 50 | max_new_tokens=1024, 51 | do_sample=False, 52 | top_p=None, 53 | ) 54 | 55 | self.transformer = LlavaMptModel(config) 56 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 57 | 58 | # Initialize weights and apply final processing 59 | self.post_init() 60 | 61 | def get_model(self): 62 | return self.transformer 63 | 64 | def _set_gradient_checkpointing(self, module, value=False): 65 | if isinstance(module, LlavaMptModel): 66 | module.gradient_checkpointing = value 67 | 68 | def forward( 69 | self, 70 | input_ids: Optional[torch.LongTensor] = None, 71 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 72 | attention_mask: Optional[torch.Tensor] = None, 73 | inputs_embeds: Optional[torch.Tensor] = None, 74 | labels: Optional[torch.Tensor] = None, 75 | use_cache: Optional[bool] = None, 76 | output_attentions: Optional[bool] = None, 77 | output_hidden_states: Optional[bool] = None, 78 | return_dict: Optional[bool] = None, 79 | cache_position=None, 80 | images=None, 81 | ): 82 | 83 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 84 | 85 | return super().forward( 86 | input_ids, 87 | past_key_values=past_key_values, 88 | attention_mask=attention_mask, 89 | inputs_embeds=inputs_embeds, 90 | labels=labels, 91 | use_cache=use_cache, 92 | output_attentions=output_attentions, 93 | output_hidden_states=output_hidden_states, 94 | return_dict=return_dict, 95 | ) 96 | 97 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 98 | images = kwargs.pop("images", None) 99 | _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 100 | _inputs["images"] = images 101 | return _inputs 102 | 103 | 104 | AutoConfig.register("llava_mpt", LlavaMptConfig) 105 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 106 | 
-------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .imagebind import ImageBindWrapper 4 | from .open_clip_encoder import OpenCLIPVisionTower 5 | from .hf_vision import HFVisionTower 6 | from .siglip_encoder import SigLipVisionTower 7 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 8 | 9 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower 10 | # from .dev_eva_clip.eva_vit import EvaViTWrapper 11 | 12 | 13 | def build_vision_tower(vision_tower_cfg, **kwargs): 14 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) 15 | is_absolute_path_exists = os.path.exists(vision_tower) 16 | use_s2 = getattr(vision_tower_cfg, "s2", False) 17 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" 
in vision_tower: 18 | if use_s2: 19 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 20 | else: 21 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 22 | elif "siglip" in vision_tower: 23 | return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) 24 | elif vision_tower.startswith("hf:"): 25 | return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 26 | elif vision_tower in ["imagebind_huge"]: 27 | return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 28 | elif vision_tower.startswith("open_clip_hub"): 29 | return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 30 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): 31 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 32 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: 33 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 34 | 35 | raise ValueError(f"Unknown vision tower: {vision_tower}") 36 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 6 | from .openai import load_openai_model, list_openai_models 7 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 8 | from .tokenizer import SimpleTokenizer, tokenize 9 | from .transform import image_transform 10 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-YuanQi/MM-RLHF/8876a5953edbc388f1daf5f9c8d2f8adb6623eb9/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings", 13 | }, 
14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings", 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens", 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings", 54 | }, 55 | "pooler": "mean_pooler", 56 | }, 57 | } 58 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import torch.distributed.nn 8 | from torch import distributed as dist 9 | 10 | has_distributed = True 11 | except ImportError: 12 | has_distributed = False 13 | 14 | try: 15 | import horovod.torch as hvd 16 | except ImportError: 17 | hvd = None 18 | 19 | from timm.loss import LabelSmoothingCrossEntropy 20 | 21 | 22 | def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False): 23 | assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support." 
24 | if use_horovod: 25 | assert hvd is not None, "Please install horovod" 26 | if gather_with_grad: 27 | all_image_features = hvd.allgather(image_features) 28 | all_text_features = hvd.allgather(text_features) 29 | else: 30 | with torch.no_grad(): 31 | all_image_features = hvd.allgather(image_features) 32 | all_text_features = hvd.allgather(text_features) 33 | if not local_loss: 34 | # ensure grads for local rank when all_* features don't have a gradient 35 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 36 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 37 | gathered_image_features[rank] = image_features 38 | gathered_text_features[rank] = text_features 39 | all_image_features = torch.cat(gathered_image_features, dim=0) 40 | all_text_features = torch.cat(gathered_text_features, dim=0) 41 | else: 42 | # We gather tensors from all gpus 43 | if gather_with_grad: 44 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 45 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 46 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 47 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 48 | else: 49 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 50 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 51 | dist.all_gather(gathered_image_features, image_features) 52 | dist.all_gather(gathered_text_features, text_features) 53 | if not local_loss: 54 | # ensure grads for local rank when all_* features don't have a gradient 55 | gathered_image_features[rank] = image_features 56 | gathered_text_features[rank] = text_features 57 | all_image_features = torch.cat(gathered_image_features, dim=0) 58 | all_text_features = torch.cat(gathered_text_features, dim=0) 59 | 60 | return all_image_features, all_text_features 61 | 62 | 63 | class ClipLoss(nn.Module): 64 | 65 | def __init__( 66 | self, 67 | local_loss=False, 68 | gather_with_grad=False, 69 | cache_labels=False, 70 | rank=0, 71 | world_size=1, 72 | use_horovod=False, 73 | smoothing=0.0, 74 | ): 75 | super().__init__() 76 | self.local_loss = local_loss 77 | self.gather_with_grad = gather_with_grad 78 | self.cache_labels = cache_labels 79 | self.rank = rank 80 | self.world_size = world_size 81 | self.use_horovod = use_horovod 82 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 83 | 84 | # cache state 85 | self.prev_num_logits = 0 86 | self.labels = {} 87 | 88 | def forward(self, image_features, text_features, logit_scale=1.0): 89 | device = image_features.device 90 | if self.world_size > 1: 91 | all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 92 | 93 | if self.local_loss: 94 | logits_per_image = logit_scale * image_features @ all_text_features.T 95 | logits_per_text = logit_scale * text_features @ all_image_features.T 96 | else: 97 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 98 | logits_per_text = logits_per_image.T 99 | else: 100 | logits_per_image = logit_scale * image_features @ text_features.T 101 | logits_per_text = logit_scale * text_features @ image_features.T 102 | # calculated ground-truth and cache if enabled 103 | 
num_logits = logits_per_image.shape[0] 104 | if self.prev_num_logits != num_logits or device not in self.labels: 105 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 106 | if self.world_size > 1 and self.local_loss: 107 | labels = labels + num_logits * self.rank 108 | if self.cache_labels: 109 | self.labels[device] = labels 110 | self.prev_num_logits = num_logits 111 | else: 112 | labels = self.labels[device] 113 | 114 | if self.label_smoothing_cross_entropy: 115 | total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2 116 | else: 117 | total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 118 | 119 | acc = None 120 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 121 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 122 | acc = {"i2t": i2t_acc, "t2i": t2i_acc} 123 | return total_loss, acc 124 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- 
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | 
"context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag("openai") 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = "fp32" if device == "cpu" else "fp16" 56 | 57 | if get_pretrained_url(name, "openai"): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith("amp") or precision == "fp32": 87 | model.float() 88 | elif precision == "bf16": 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == "fp32": 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | 8 | def broadcat(tensors, dim=-1): 9 | num_tensors = len(tensors) 10 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 11 | assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" 12 | shape_len = list(shape_lens)[0] 13 | dim = (dim + shape_len) if dim < 0 else dim 14 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 15 | 
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 16 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation" 17 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 18 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 19 | expanded_dims.insert(dim, (dim, dims[dim])) 20 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 21 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 22 | return torch.cat(tensors, dim=dim) 23 | 24 | 25 | def rotate_half(x): 26 | x = rearrange(x, "... (d r) -> ... d r", r=2) 27 | x1, x2 = x.unbind(dim=-1) 28 | x = torch.stack((-x2, x1), dim=-1) 29 | return rearrange(x, "... d r -> ... (d r)") 30 | 31 | 32 | class VisionRotaryEmbedding(nn.Module): 33 | def __init__( 34 | self, 35 | dim, 36 | pt_seq_len, 37 | ft_seq_len=None, 38 | custom_freqs=None, 39 | freqs_for="lang", 40 | theta=10000, 41 | max_freq=10, 42 | num_freqs=1, 43 | ): 44 | super().__init__() 45 | if custom_freqs: 46 | freqs = custom_freqs 47 | elif freqs_for == "lang": 48 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 49 | elif freqs_for == "pixel": 50 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 51 | elif freqs_for == "constant": 52 | freqs = torch.ones(num_freqs).float() 53 | else: 54 | raise ValueError(f"unknown modality {freqs_for}") 55 | 56 | if ft_seq_len is None: 57 | ft_seq_len = pt_seq_len 58 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 59 | 60 | freqs_h = torch.einsum("..., f -> ... f", t, freqs) 61 | freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) 62 | 63 | freqs_w = torch.einsum("..., f -> ... f", t, freqs) 64 | freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) 65 | 66 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) 67 | 68 | self.register_buffer("freqs_cos", freqs.cos()) 69 | self.register_buffer("freqs_sin", freqs.sin()) 70 | 71 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 72 | 73 | def forward(self, t, start_index=0): 74 | rot_dim = self.freqs_cos.shape[-1] 75 | end_index = start_index + rot_dim 76 | assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" 77 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 78 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 79 | 80 | return torch.cat((t_left, t, t_right), dim=-1) 81 | 82 | 83 | class VisionRotaryEmbeddingFast(nn.Module): 84 | def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0): 85 | super().__init__() 86 | if custom_freqs: 87 | freqs = custom_freqs 88 | elif freqs_for == "lang": 89 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 90 | elif freqs_for == "pixel": 91 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 92 | elif freqs_for == "constant": 93 | freqs = torch.ones(num_freqs).float() 94 | else: 95 | raise ValueError(f"unknown modality {freqs_for}") 96 | 97 | if ft_seq_len is None: 98 | ft_seq_len = pt_seq_len 99 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 100 | 101 | freqs = torch.einsum("..., f -> ... f", t, freqs) 102 | freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) 103 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) 104 | 105 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 106 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 107 | 108 | self.patch_dropout = patch_dropout 109 | 110 | self.register_buffer("freqs_cos", freqs_cos) 111 | self.register_buffer("freqs_sin", freqs_sin) 112 | 113 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 114 | 115 | def forward(self, t, patch_indices_keep=None): 116 | if patch_indices_keep is not None: 117 | batch = t.size()[0] 118 | batch_indices = torch.arange(batch) 119 | batch_indices = batch_indices[..., None] 120 | 121 | freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 122 | freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 123 | 124 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 125 | freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j") 126 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 127 | freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j") 128 | 129 | return t * freqs_cos + rotate_half(t) * freqs_sin 130 | 131 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin 132 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | 6 | import logging 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | try: 13 | import timm 14 | from timm.models.layers import Mlp, to_2tuple 15 | 16 | try: 17 | # old timm imports < 0.8.1 18 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 19 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 20 | except ImportError: 21 | # new timm imports >= 0.8.1 22 | from timm.layers import RotAttentionPool2d 23 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 24 | except ImportError: 25 | timm = None 26 | 27 | from .utils import freeze_batch_norm_2d 28 | 29 | 30 | class TimmModel(nn.Module): 31 | """timm model adapter 32 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 33 | """ 34 | 35 | def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False): 36 | super().__init__() 37 | if timm is None: 38 | raise RuntimeError("Please `pip install timm` to use timm models.") 39 | 40 | self.image_size = to_2tuple(image_size) 41 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 42 | feat_size = self.trunk.default_cfg.get("pool_size", None) 43 | feature_ndim = 1 if not feat_size else 2 44 | if pool in ("abs_attn", "rot_attn"): 45 | assert feature_ndim == 2 46 | # if attn pooling used, remove both classifier and default pool 47 | self.trunk.reset_classifier(0, global_pool="") 48 | else: 49 | # reset global pool if pool config set, otherwise leave as network default 50 | reset_kwargs = dict(global_pool=pool) if pool else {} 51 | self.trunk.reset_classifier(0, **reset_kwargs) 52 | prev_chs = self.trunk.num_features 53 | 54 | head_layers = OrderedDict() 55 | if pool == "abs_attn": 56 | head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 57 | prev_chs = 
embed_dim 58 | elif pool == "rot_attn": 59 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 60 | prev_chs = embed_dim 61 | else: 62 | assert proj, "projection layer needed if non-attention pooling is used." 63 | 64 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 65 | if proj == "linear": 66 | head_layers["drop"] = nn.Dropout(drop) 67 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 68 | elif proj == "mlp": 69 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 70 | 71 | self.head = nn.Sequential(head_layers) 72 | 73 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 74 | """lock modules 75 | Args: 76 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 77 | """ 78 | if not unlocked_groups: 79 | # lock full model 80 | for param in self.trunk.parameters(): 81 | param.requires_grad = False 82 | if freeze_bn_stats: 83 | freeze_batch_norm_2d(self.trunk) 84 | else: 85 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 86 | try: 87 | # FIXME import here until API stable and in an official release 88 | from timm.models.helpers import group_parameters, group_modules 89 | except ImportError: 90 | raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`") 91 | matcher = self.trunk.group_matcher() 92 | gparams = group_parameters(self.trunk, matcher) 93 | max_layer_id = max(gparams.keys()) 94 | max_layer_id = max_layer_id - unlocked_groups 95 | for group_idx in range(max_layer_id + 1): 96 | group = gparams[group_idx] 97 | for param in group: 98 | self.trunk.get_parameter(param).requires_grad = False 99 | if freeze_bn_stats: 100 | gmodules = group_modules(self.trunk, matcher, reverse=True) 101 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 102 | freeze_batch_norm_2d(self.trunk, gmodules) 103 | 104 | @torch.jit.ignore 105 | def set_grad_checkpointing(self, enable=True): 106 | try: 107 | self.trunk.set_grad_checkpointing(enable) 108 | except Exception as e: 109 | logging.warning("grad checkpointing not supported for this timm image tower, continuing without...") 110 | 111 | def forward(self, x): 112 | x = self.trunk(x) 113 | x = self.head(x) 114 | return x 115 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop 8 | 9 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 10 | 11 | 12 | class ResizeMaxSize(nn.Module): 13 | 14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): 15 | super().__init__() 16 | if not isinstance(max_size, int): 17 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 18 | self.max_size = max_size 19 | self.interpolation = interpolation 20 | self.fn = min if fn == "min" else min 21 | self.fill = fill 22 | 23 | def forward(self, img): 24 | if isinstance(img, torch.Tensor): 25 | height, width = img.shape[:2] 26 | else: 27 | width, height = img.size 28 | scale = self.max_size / float(max(height, width)) 29 | if scale != 1.0: 30 | new_size = tuple(round(dim * scale) for dim in (height, width)) 31 | img = F.resize(img, new_size, self.interpolation) 32 | pad_h = self.max_size - new_size[0] 33 | pad_w = self.max_size - new_size[1] 34 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) 35 | return img 36 | 37 | 38 | def _convert_to_rgb(image): 39 | return image.convert("RGB") 40 | 41 | 42 | # class CatGen(nn.Module): 43 | # def __init__(self, num=4): 44 | # self.num = num 45 | # def mixgen_batch(image, text): 46 | # batch_size = image.shape[0] 47 | # index = np.random.permutation(batch_size) 48 | 49 | # cat_images = [] 50 | # for i in range(batch_size): 51 | # # image mixup 52 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 53 | # # text concat 54 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 55 | # text = torch.stack(text) 56 | # return image, text 57 | 58 | 59 | def image_transform( 60 | image_size: int, 61 | is_train: bool, 62 | mean: Optional[Tuple[float, ...]] = None, 63 | std: Optional[Tuple[float, ...]] = None, 64 | resize_longest_max: bool = False, 65 | fill_color: int = 0, 66 | ): 67 | mean = mean or OPENAI_DATASET_MEAN 68 | if not isinstance(mean, (list, tuple)): 69 | mean = (mean,) * 3 70 | 71 | std = std or OPENAI_DATASET_STD 72 | if not isinstance(std, (list, tuple)): 73 | std = (std,) * 3 74 | 75 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 76 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 77 | image_size = image_size[0] 78 | 79 | normalize = Normalize(mean=mean, std=std) 80 | if is_train: 81 | return Compose( 82 | [ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ] 88 | ) 89 | else: 90 | if resize_longest_max: 91 | transforms = [ResizeMaxSize(image_size, fill=fill_color)] 92 | else: 93 | transforms = [ 94 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 95 | CenterCrop(image_size), 96 | ] 97 | transforms.extend( 98 | [ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ] 103 | ) 104 | return Compose(transforms) 105 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py: -------------------------------------------------------------------------------- 1 | # Based on EVA, BEIT, timm and DeiT code bases 2 | # https://github.com/baaivision/EVA 3 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 4 | # https://github.com/microsoft/unilm/tree/master/beit 5 | # https://github.com/facebookresearch/deit/ 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | # not tested yet 9 | import math 10 | from transformers import CLIPImageProcessor 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint as checkpoint 16 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 17 | from .eva_clip 
import create_model_and_transforms, get_model_config 18 | import torch 19 | import torchvision 20 | import time 21 | 22 | from llava.utils import rank0_print 23 | 24 | 25 | class EvaViTWrapper(nn.Module): 26 | def __init__(self, vision_tower, args, delay_load=False): 27 | super().__init__() 28 | 29 | self.is_loaded = False 30 | self.vision_tower_name = vision_tower 31 | self.pretrained = args.vision_tower_pretrained 32 | self.args = args 33 | 34 | self.select_layer = args.mm_vision_select_layer 35 | if self.select_layer < -1: 36 | self.select_layer += 1 37 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 38 | 39 | self.model_config = get_model_config(self.vision_tower_name) 40 | 41 | if not delay_load: 42 | rank0_print(f"Loading vision tower: {vision_tower}") 43 | self.load_model() 44 | elif getattr(args, "unfreeze_mm_vision_tower", False): 45 | # TODO: better detector is needed. 46 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 47 | self.load_model() 48 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 49 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 50 | self.load_model() 51 | 52 | def load_model(self): 53 | rank0_print(f"Loading: {self.vision_tower_name}") 54 | rank0_print(f"Pretrained: {self.pretrained}") 55 | time_start = time.time() 56 | model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16") 57 | time_end = time.time() 58 | rank0_print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s") 59 | self.device = next(model.parameters()).device 60 | self.dtype = next(model.parameters()).dtype 61 | if self.device.type != "meta": 62 | model = model.to("cuda") 63 | self.vision_tower = model.visual 64 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] 65 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] 66 | self.resize_transform_size = resize_transform.size 67 | self.image_processor = CLIPImageProcessor.from_pretrained( 68 | "openai/clip-vit-large-patch14", 69 | crop_size=resize_transform.size, 70 | size={"shortest_edge": resize_transform.size}, 71 | image_mean=list(normalize_transform.mean), 72 | image_std=list(normalize_transform.std), 73 | ) 74 | rank0_print(f"Loaded image processor: {self.image_processor}") 75 | self.vision_tower.requires_grad_(False) 76 | self.is_loaded = True 77 | 78 | def feature_select(self, image_features): 79 | select_feature_type = self.select_feature 80 | 81 | # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 82 | # select_every_k_layer = len(image_features) // 4 83 | # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1) 84 | # select_feature_type = select_feature_type.replace("slicefour_", "") 85 | # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: 86 | # select_layers = [-1, -4, -7, -10, 6] 87 | # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1) 88 | # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") 89 | # else: 90 | # image_features = image_features[self.select_layer] 91 | 92 | if select_feature_type == "patch": 93 | 
image_features = image_features[:, 1:] 94 | elif select_feature_type == "cls_patch": 95 | image_features = image_features 96 | else: 97 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 98 | return image_features 99 | 100 | def train(self, mode=True): 101 | self.training = mode 102 | 103 | if self.is_loaded: 104 | self.vision_tower.eval() 105 | 106 | def forward(self, images): 107 | if type(images) is list: 108 | image_features = [] 109 | for image in images: 110 | image_feature = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True) 111 | image_feature = self.feature_select(image_feature).to(self.dtype) 112 | image_features.append(image_feature) 113 | else: 114 | image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True) 115 | image_features = self.feature_select(image_features).to(self.dtype) 116 | 117 | return image_features 118 | 119 | @property 120 | def dummy_feature(self): 121 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 122 | 123 | @property 124 | def hidden_size(self): 125 | return self.model_config["vision_cfg"]["width"] 126 | 127 | @property 128 | def num_patches(self): 129 | return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2 130 | 131 | @property 132 | def num_patches_per_side(self): 133 | return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] 134 | 135 | @property 136 | def config(self): 137 | return self.model_config 138 | 139 | @property 140 | def image_size(self): 141 | return self.model_config["vision_cfg"]["image_size"] 142 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .eva_clip_processors import EvaClipImageTrainProcessor 5 | from .eva_vit import EVAEncoderWrapper 6 | from .factory import list_models, add_model_config, get_model_config 7 | 8 | from llava.utils import rank0_print 9 | 10 | 11 | class EvaClipVisionTower(nn.Module): 12 | def __init__(self, vision_tower, args, delay_load=False): 13 | super().__init__() 14 | 15 | self.is_loaded = False 16 | self.vision_tower_name = vision_tower 17 | self.vision_tower_pretrained = args.vision_tower_pretrained 18 | self.config = get_model_config(vision_tower) 19 | 20 | if not delay_load: 21 | rank0_print(f"Loading EVA ViT: {self.vision_tower_name}") 22 | self.load_model() 23 | elif getattr(args, "unfreeze_mm_vision_tower", False): 24 | # TODO: better detector is needed.
25 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 26 | self.load_model() 27 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 28 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 29 | self.load_model() 30 | else: 31 | self.cfg_only = self.config 32 | 33 | def load_model(self, device_map=None): 34 | rank0_print(f"Pretrained: {self.vision_tower_pretrained}") 35 | self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"]) 36 | self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config) 37 | rank0_print(f"Loaded image processor: {self.image_processor}") 38 | self.vision_tower.requires_grad_(False) 39 | self.is_loaded = True 40 | 41 | def forward(self, images): 42 | if type(images) is list: 43 | image_features = [] 44 | for image in images: 45 | image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype) 49 | 50 | return image_features 51 | 52 | @property 53 | def dtype(self): 54 | return self.vision_tower.dtype 55 | 56 | @property 57 | def device(self): 58 | return self.vision_tower.device 59 | 60 | @property 61 | def hidden_size(self): 62 | return self.config["vision_cfg"]["width"] 63 | 64 | @property 65 | def num_patches(self): 66 | return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2 67 | 68 | @property 69 | def num_patches_per_side(self): 70 | return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"] 71 | 72 | @property 73 | def image_size(self): 74 | return self.config["vision_cfg"]["image_size"] 75 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP 3 | """ 4 | 5 | from torchvision import transforms 6 | from torchvision.transforms.functional import InterpolationMode 7 | from transformers.image_processing_utils import BatchFeature 8 | from PIL import Image 9 | from transformers.image_transforms import convert_to_rgb 10 | 11 | 12 | class BaseProcessor: 13 | def __init__(self): 14 | self.transform = lambda x: x 15 | return 16 | 17 | def __call__(self, item): 18 | return self.transform(item) 19 | 20 | 21 | class EvaClipImageBaseProcessor(BaseProcessor): 22 | def __init__(self, mean=None, std=None): 23 | self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean 24 | self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std 25 | 26 | self.normalize = transforms.Normalize(self.mean, self.std) 27 | 28 | @property 29 | def image_mean(self): 30 | return self.mean 31 | 32 | 33 | class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor): 34 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 35 | super().__init__(mean=mean, std=std) 36 | 37 | self.transform = transforms.Compose( 38 | [ 39 | convert_to_rgb, 40 | transforms.Resize( 41 | image_size, 42 | interpolation=InterpolationMode.BICUBIC, 43 | ), 44 | transforms.CenterCrop(image_size), 45 | transforms.ToTensor(), 46 | 
self.normalize, 47 | ] 48 | ) 49 | 50 | self.image_size = image_size 51 | 52 | def preprocess(self, images, return_tensors): 53 | if isinstance(images, Image.Image): 54 | images = [images] 55 | else: 56 | assert isinstance(images, list) 57 | 58 | transformed_images = [self.transform(image).numpy() for image in images] 59 | data = {"pixel_values": transformed_images} 60 | 61 | return BatchFeature(data=data, tensor_type=return_tensors) 62 | 63 | def __call__(self, item): 64 | return self.transform(item) 65 | 66 | @property 67 | def crop_size(self): 68 | return {"height": self.image_size, "width": self.image_size} 69 | 70 | @property 71 | def size(self): 72 | return {"shortest_edge": self.image_size} 73 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Optional, Tuple, Union, Dict, Any 9 | import torch 10 | 11 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 12 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 13 | 14 | 15 | def _natural_key(string_): 16 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 17 | 18 | 19 | def _rescan_model_configs(): 20 | global _MODEL_CONFIGS 21 | 22 | config_ext = (".json",) 23 | config_files = [] 24 | for config_path in _MODEL_CONFIG_PATHS: 25 | if config_path.is_file() and config_path.suffix in config_ext: 26 | config_files.append(config_path) 27 | elif config_path.is_dir(): 28 | for ext in config_ext: 29 | config_files.extend(config_path.glob(f"*{ext}")) 30 | 31 | for cf in config_files: 32 | with open(cf, "r", encoding="utf8") as f: 33 | model_cfg = json.load(f) 34 | if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): 35 | _MODEL_CONFIGS[cf.stem] = model_cfg 36 | 37 | _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) 38 | 39 | 40 | _rescan_model_configs() # initial populate of model config registry 41 | 42 | 43 | def list_models(): 44 | """enumerate available model architectures based on config files""" 45 | return list(_MODEL_CONFIGS.keys()) 46 | 47 | 48 | def add_model_config(path): 49 | """add model config path or file and update registry""" 50 | if not isinstance(path, Path): 51 | path = Path(path) 52 | _MODEL_CONFIG_PATHS.append(path) 53 | _rescan_model_configs() 54 | 55 | 56 | def get_model_config(model_name): 57 | if model_name in _MODEL_CONFIGS: 58 | return deepcopy(_MODEL_CONFIGS[model_name]) 59 | else: 60 | return None 61 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | 
"fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 
49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/hf_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor 5 | from llava.utils import rank0_print 6 | 7 | 8 | class HFVisionTower(nn.Module): 9 | def __init__(self, vision_tower, args, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | 14 | self.vision_tower_name = vision_tower.replace("hf:", "", 1) 15 | self.select_layer = args.mm_vision_select_layer 16 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 17 | 18 | if not delay_load: 19 | self.load_model() 20 | else: 21 | self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) 22 | 23 | def load_model(self): 24 | try: 25 | self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name) 26 | except Exception as e: 27 | if "448" in self.vision_tower_name: 28 | image_size = 448 29 | # use image processor with conig 30 | self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size) 31 | else: 32 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 33 | rank0_print(f"Loaded image processor: 
{self.image_processor}") 34 | self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda") 35 | self.device = self.vision_tower.device 36 | self.dtype = self.vision_tower.dtype 37 | self.config = self.vision_tower.config 38 | 39 | if hasattr(self.vision_tower, "vision_model"): 40 | self.vision_tower = self.vision_tower.vision_model 41 | self.vision_tower.requires_grad_(False) 42 | # self.vision_tower.eval() 43 | self.is_loaded = True 44 | 45 | def feature_select(self, image_forward_outs): 46 | select_feature_type = self.select_feature 47 | 48 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 49 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4 50 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) 51 | select_feature_type = select_feature_type.replace("slicefour_", "") 52 | else: 53 | image_features = image_forward_outs.hidden_states[self.select_layer] 54 | 55 | if select_feature_type == "patch": 56 | image_features = image_features[:, 1:] 57 | elif select_feature_type == "cls_patch": 58 | image_features = image_features 59 | else: 60 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 61 | return image_features 62 | 63 | def forward(self, images): 64 | if type(images) is list: 65 | image_features = [] 66 | for image in images: 67 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 68 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 69 | image_features.append(image_feature) 70 | else: 71 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 72 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 73 | 74 | return image_features 75 | 76 | @property 77 | def dummy_feature(self): 78 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 79 | 80 | # @property 81 | # def dtype(self): 82 | # return self.vision_tower.dtype 83 | 84 | # @property 85 | # def device(self): 86 | # return self.vision_tower.device 87 | 88 | @property 89 | def hidden_size(self): 90 | try: 91 | _hidden_size = self.config.hidden_size 92 | except: 93 | _hidden_size = self.config.vision_config.hidden_size 94 | if "slicefour" in self.select_feature: 95 | _hidden_size *= 4 96 | return _hidden_size 97 | 98 | @property 99 | def num_patches(self): 100 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2 101 | if "cls_patch" in self.select_feature: 102 | _num_patches += 1 103 | return _num_patches 104 | 105 | @property 106 | def num_patches_per_side(self): 107 | return self.config.image_size // self.config.patch_size 108 | 109 | @property 110 | def image_size(self): 111 | return self.config.image_size 112 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/imagebind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPImageProcessor 5 | 6 | try: 7 | from imagebind.models import imagebind_model 8 | from imagebind.models.imagebind_model import ModalityType 9 | from imagebind.data import load_and_transform_audio_data 10 | except ImportError: 11 | pass 12 | 13 | 14 | 
class ImageBindWrapper(nn.Module): 15 | def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): 16 | super().__init__() 17 | 18 | self.is_loaded = False 19 | 20 | self.vision_tower_name = vision_tower 21 | self.select_layer = select_layer 22 | self.select_feature = select_feature 23 | 24 | if not delay_load: 25 | self.load_model() 26 | 27 | def load_model(self): 28 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 29 | self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) 30 | for p in self.vision_tower.parameters(): 31 | p.requires_grad = False 32 | self.vision_tower.eval() 33 | self.is_loaded = True 34 | 35 | def train(self, mode=True): 36 | self.training = mode 37 | 38 | if self.is_loaded: 39 | self.vision_tower.eval() 40 | 41 | @torch.no_grad() 42 | def forward(self, x): 43 | if type(x) == dict: 44 | if x["audios"] is not None: 45 | inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} 46 | embeddings = self.vision_tower(inputs) 47 | audio_embedding = embeddings[ModalityType.AUDIO] 48 | return audio_embedding.unsqueeze(1) 49 | else: 50 | inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} 51 | embeddings = self.vision_tower(inputs) 52 | vision_embedding = embeddings[ModalityType.VISION] 53 | if vision_embedding.ndim == 2: 54 | return vision_embedding.unsqueeze(1) 55 | if vision_embedding.shape[1] == 257: 56 | return vision_embedding[:, 1:] 57 | raise ValueError(f"Unexpected shape: {vision_embedding.shape}") 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.modality_preprocessors.vision.cls_token.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.modality_preprocessors.vision.cls_token.device 70 | 71 | @property 72 | def hidden_size(self): 73 | return 1024 74 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | from .pooler_projector import PoolerProjector 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": "identity"} 18 | 19 | 20 | class SimpleResBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.pre_norm = nn.LayerNorm(channels) 24 | 25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) 26 | 27 | def forward(self, x): 28 | x = self.pre_norm(x) 29 | return x + self.proj(x) 30 | 31 | 32 | def build_vision_projector(config, delay_load=False, **kwargs): 33 | projector_type = getattr(config, "mm_projector_type", "linear") 34 | 35 | if projector_type == "linear": 36 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | 38 | if projector_type == "pooler": 39 | return PoolerProjector(config, kwargs["vision_cfg"]) 40 | 41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) 42 | if mlp_gelu_match: 43 | mlp_depth = int(mlp_gelu_match.group(1)) 44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 45 | for _ in range(1, mlp_depth): 46 | modules.append(nn.GELU()) 47 | 
modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 48 | return nn.Sequential(*modules) 49 | 50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) 51 | if mlp_gelu_resnet_match: 52 | mlp_depth = int(mlp_gelu_resnet_match.group(1)) 53 | res_depth = int(mlp_gelu_resnet_match.group(2)) 54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 55 | for _ in range(1, mlp_depth): 56 | modules.append(nn.GELU()) 57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 58 | for _ in range(res_depth): 59 | modules.append(SimpleResBlock(config.hidden_size)) 60 | return nn.Sequential(*modules) 61 | 62 | if projector_type == "identity": 63 | return IdentityMap() 64 | 65 | raise ValueError(f"Unknown projector type: {projector_type}") 66 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/pooler_projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | from transformers.models.clip.modeling_clip import CLIPVisionModel 7 | 8 | 9 | class PoolerProjector(nn.Module): 10 | def __init__(self, config, vision_cfg): 11 | super().__init__() 12 | self._config = config 13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size 14 | 15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) 16 | 17 | self.proj = nn.Sequential( 18 | nn.GELU(), 19 | nn.Linear(config.hidden_size, config.hidden_size), 20 | ) 21 | 22 | def forward(self, x, *args, **kwargs): 23 | height = width = self.hw 24 | assert height * width == x.shape[1] 25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) 26 | x = self.conv_pool(x) 27 | x = x.flatten(2).transpose(1, 2) 28 | x = self.proj(x) 29 | return x 30 | 31 | @property 32 | def config(self): 33 | return {"mm_projector_type": "pooler"} 34 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .masked_drop import MaskedDrop 4 | from .spatial_pool import SpatialPool 5 | from .perceiver import PerceiverResampler 6 | from .qformer import Qformer 7 | 8 | 9 | class IdentityMap(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | @property 17 | def config(self): 18 | return {"mm_resampler_type": None} 19 | 20 | 21 | def build_vision_resampler(model_args, delay_load=False, **kwargs): 22 | resampler_type = getattr(model_args, "mm_resampler_type", None) 23 | if resampler_type == "masked_drop": 24 | return MaskedDrop(model_args) 25 | elif resampler_type == "spatial_pool": 26 | return SpatialPool(model_args, **kwargs) 27 | elif resampler_type == "perceiver": 28 | return PerceiverResampler(model_args, **kwargs) 29 | elif resampler_type == "qformer": 30 | return Qformer(model_args, **kwargs) 31 | elif resampler_type is None: 32 | return IdentityMap() 33 | 34 | raise ValueError(f"Unknown resampler type: {resampler_type}") 35 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/masked_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import random 5 | 6 | 7 | class MaskedDrop(nn.Module): 8 | def 
__init__(self, model_args): 9 | super().__init__() 10 | 11 | self.mode = model_args.mm_mask_drop_mode 12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage 13 | self.ratio = model_args.mm_mask_drop_ratio 14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper 15 | self.ratio_lower = model_args.mm_mask_drop_ratio_lower 16 | 17 | def forward(self, image_features, *args, **kwargs): 18 | 19 | if not self.training: 20 | return image_features 21 | 22 | if self.skip_percentage > random.random(): 23 | return image_features 24 | 25 | masked_features = [] 26 | 27 | for image_feature in image_features: 28 | num_tokens = image_feature.shape[0] 29 | if self.mode == "fixed": 30 | num_keep = int(num_tokens * self.ratio) 31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) 32 | elif self.mode == "range": 33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) 34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) 35 | elif self.mode == "cls_only": 36 | masked_features.append(image_feature[0:1]) 37 | else: 38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}") 39 | 40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): 41 | masked_features = torch.stack(masked_features, dim=0) 42 | 43 | return masked_features 44 | 45 | @property 46 | def config(self): 47 | return { 48 | "mm_resampler_type": "masked_drop", 49 | "mm_mask_drop_mode": self.mode, 50 | "mm_mask_drop_skip_percentage": self.skip_percentage, 51 | "mm_mask_drop_ratio": self.ratio, 52 | "mm_mask_drop_ratio_upper": self.ratio_upper, 53 | "mm_mask_drop_ratio_lower": self.ratio_lower, 54 | } 55 | 56 | def random_masking(self, x, len_keep): 57 | """ 58 | Perform per-sample random masking by per-sample shuffling. 59 | Per-sample shuffling is done by argsort random noise. 
60 | x: [N, L, D], sequence 61 | """ 62 | N, L, D = x.shape # batch, length, dim 63 | 64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 65 | 66 | # sort noise for each sample 67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 68 | ids_restore = torch.argsort(ids_shuffle, dim=1) 69 | 70 | # keep the first subset 71 | ids_keep = ids_shuffle[:, :len_keep] 72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 73 | 74 | # generate the binary mask: 0 is keep, 1 is remove 75 | mask = torch.ones([N, L], device=x.device) 76 | mask[:, :len_keep] = 0 77 | # unshuffle to get the binary mask 78 | mask = torch.gather(mask, dim=1, index=ids_restore) 79 | 80 | return x_masked, mask, ids_restore 81 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/perceiver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | try: 9 | from einops_exts import rearrange_many 10 | except: 11 | pass 12 | 13 | from torch import einsum, nn 14 | 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | 20 | def FeedForward(dim, mult=4): 21 | inner_dim = int(dim * mult) 22 | return nn.Sequential( 23 | nn.LayerNorm(dim), 24 | nn.Linear(dim, inner_dim, bias=False), 25 | nn.GELU(), 26 | nn.Linear(inner_dim, dim, bias=False), 27 | ) 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.heads = heads 35 | inner_dim = dim_head * heads 36 | 37 | self.norm_media = nn.LayerNorm(dim) 38 | self.norm_latents = nn.LayerNorm(dim) 39 | 40 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 41 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 42 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 43 | 44 | def forward(self, x, latents): 45 | """ 46 | Args: 47 | x (torch.Tensor): image features 48 | shape (b, T, n1, D) 49 | latent (torch.Tensor): latent features 50 | shape (b, T, n2, D) 51 | """ 52 | x = self.norm_media(x) 53 | latents = self.norm_latents(latents) 54 | 55 | h = self.heads 56 | 57 | q = self.to_q(latents) 58 | kv_input = torch.cat((x, latents), dim=-2) 59 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 60 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 61 | q = q * self.scale 62 | 63 | # attention 64 | sim = einsum("... i d, ... j d -> ... i j", q, k) 65 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 66 | attn = sim.softmax(dim=-1) 67 | 68 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 69 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 70 | return self.to_out(out) 71 | 72 | 73 | class PerceiverResamplerModule(nn.Module): 74 | def __init__( 75 | self, 76 | *, 77 | dim, 78 | depth=6, 79 | dim_head=64, 80 | heads=8, 81 | num_latents=64, 82 | max_num_media=None, 83 | max_num_frames=None, 84 | ff_mult=4, 85 | ): 86 | super().__init__() 87 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 88 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None 89 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None 90 | 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append( 94 | nn.ModuleList( 95 | [ 96 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 97 | FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), 98 | ] 99 | ) 100 | ) 101 | 102 | self.norm = nn.LayerNorm(dim) 103 | 104 | def forward(self, x): 105 | """ 106 | Args: 107 | x (torch.Tensor): image features 108 | shape (b, T, F, v, D) 109 | Returns: 110 | shape (b, T, n, D) where n is self.num_latents 111 | """ 112 | b, T, F, v = x.shape[:4] 113 | 114 | # frame and media time embeddings 115 | if exists(self.frame_embs): 116 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 117 | x = x + frame_embs 118 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions 119 | if exists(self.media_time_embs): 120 | x = x + self.media_time_embs[:T] 121 | 122 | # blocks 123 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 124 | for attn, ff in self.layers: 125 | latents = attn(x, latents) + latents 126 | latents = ff(latents) + latents 127 | return self.norm(latents) 128 | 129 | 130 | class PerceiverResampler(nn.Module): 131 | def __init__(self, model_args, vision_tower): 132 | super().__init__() 133 | 134 | self.depth = model_args.mm_perceiver_depth 135 | self.num_latents = model_args.mm_perceiver_latents 136 | self.ff_mult = model_args.mm_perceiver_ff_mult 137 | self.pretrained = model_args.mm_perceiver_pretrained 138 | 139 | self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) 140 | 141 | if self.pretrained is not None: 142 | self.load_state_dict(torch.load(self.pretrained)) 143 | 144 | def forward(self, image_features, *args, **kwargs): 145 | return self.perceiver(image_features[:, None, None]).squeeze(1) 146 | 147 | @property 148 | def config(self): 149 | return { 150 | "mm_resampler_type": "perceiver", 151 | "mm_perceiver_depth": self.depth, 152 | "mm_perceiver_latents": self.num_latents, 153 | "mm_perceiver_ff_mult": self.ff_mult, 154 | "mm_perceiver_pretrained": self.pretrained, 155 | } 156 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/spatial_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class SpatialPool(nn.Module): 7 | def __init__(self, model_args, vision_tower): 8 | super().__init__() 9 | 10 | self.mode = model_args.mm_spatial_pool_mode 11 | self.stride = model_args.mm_spatial_pool_stride 12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) 13 | 14 | if self.mode == "average": 15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, 
stride=self.stride) 16 | elif self.mode == "max": 17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) 18 | elif self.mode == "conv": 19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) 20 | else: 21 | raise ValueError(f"Unknown pooling mode: {self.mode}.") 22 | 23 | def forward(self, image_features, images, *args, **kwargs): 24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) 25 | ori_H = int(ori_W * images.shape[2] // images.shape[3]) 26 | 27 | B, _, F = image_features.shape 28 | 29 | image_features_spatial = image_features.view(B, ori_H, ori_W, F).permute(0, 3, 1, 2) 30 | image_features_spatial_pool = self.pool(image_features_spatial) 31 | 32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() 33 | 34 | @property 35 | def config(self): 36 | return { 37 | "mm_resampler_type": "spatial_pool", 38 | "mm_spatial_pool_stride": self.stride, 39 | "mm_spatial_pool_mode": self.mode, 40 | "mm_spatial_pool_out_channels": self.out_channels, 41 | } 42 | 43 | @property 44 | def hidden_size(self): 45 | return self.out_channels 46 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/rm_process/download_pairwise_data.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import json 3 | input_file = '/data/Alignment/llava_7b_v1_preference.json' 4 | output_file = 'tmp/llava_7b_v1_preference.jsonl' 5 | 6 | # 打开并读取文件 7 | with open(input_file, 'r') as f_in, open(output_file, 'w') as f_out: 8 | datas = json.load(f_in) 9 | for data in datas: 10 | 11 | human_prompt = next(item['value'] for item in reversed(data['conversations']) if item['from'] == 'human') 12 | 13 | # Get the chosen and rejected outputs based on preference 14 | chosen = data['output_1'] if data['preference'] == 1 else data['output_2'] 15 | rejected = data['output_2'] if data['preference'] == 1 else data['output_1'] 16 | 17 | # Create the modified data structure 18 | modified_data = { 19 | "prompt": human_prompt, 20 | "chosen": chosen["value"], 21 | "rejected": rejected["value"], 22 | "has_image": True, 23 | "image": data['image'], 24 | "id": data['id'] 25 | } 26 | 27 | # Write the modified data to the output file 28 | f_out.write(json.dumps(modified_data) + '\n') -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-YuanQi/MM-RLHF/8876a5953edbc388f1daf5f9c8d2f8adb6623eb9/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith("http") or image_file.startswith("https"): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert("RGB") 22 | else: 23 | image = Image.open(image_file).convert("RGB") 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit) 33 | 34 | if "llama-2" in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and 
conv_mode != args.conv_mode: 44 | print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ("user", "assistant") 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().cuda() 56 | 57 | while True: 58 | try: 59 | inp = input(f"{roles[0]}: ") 60 | except EOFError: 61 | inp = "" 62 | if not inp: 63 | print("exit...") 64 | break 65 | 66 | print(f"{roles[1]}: ", end="") 67 | 68 | if image is not None: 69 | # first message 70 | if model.config.mm_use_im_start_end: 71 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp 72 | else: 73 | inp = DEFAULT_IMAGE_TOKEN + "\n" + inp 74 | conv.append_message(conv.roles[0], inp) 75 | image = None 76 | else: 77 | # later messages 78 | conv.append_message(conv.roles[0], inp) 79 | conv.append_message(conv.roles[1], None) 80 | prompt = conv.get_prompt() 81 | 82 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() 83 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 84 | keywords = [stop_str] 85 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 86 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 87 | 88 | with torch.inference_mode(): 89 | output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]) 90 | 91 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip() 92 | conv.messages[-1][-1] = outputs 93 | 94 | if args.debug: 95 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 101 | parser.add_argument("--model-base", type=str, default=None) 102 | parser.add_argument("--image-file", type=str, required=True) 103 | parser.add_argument("--num-gpus", type=int, default=1) 104 | parser.add_argument("--conv-mode", type=str, default=None) 105 | parser.add_argument("--temperature", type=float, default=0.2) 106 | parser.add_argument("--max-new-tokens", type=int, default=512) 107 | parser.add_argument("--load-8bit", action="store_true") 108 | parser.add_argument("--load-4bit", action="store_true") 109 | parser.add_argument("--debug", action="store_true") 110 | args = parser.parse_args() 111 | main(args) 112 | -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-YuanQi/MM-RLHF/8876a5953edbc388f1daf5f9c8d2f8adb6623eb9/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-YuanQi/MM-RLHF/8876a5953edbc388f1daf5f9c8d2f8adb6623eb9/llava/serve/examples/waterview.jpg 
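
The CLI above stitches together load_pretrained_model, the conversation templates, and KeywordsStoppingCriteria into an interactive loop. As a quick reference, here is a minimal non-interactive sketch of the same single-turn flow; the checkpoint name, conversation mode ("llava_v1"), example image, and question are placeholder assumptions for illustration, not values fixed by this repository.

# Minimal single-turn sketch mirroring llava/serve/cli.py.
# Assumes a LLaVA-v1.5-style checkpoint; model path, conv mode, image, and prompt are placeholders.
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image

model_path = "liuhaotian/llava-v1.5-13b"  # placeholder checkpoint
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, get_model_name_from_path(model_path))

# Build a one-turn prompt with the image token prepended, as cli.py does for the first message.
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is unusual about this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

image = Image.open("llava/serve/examples/extreme_ironing.jpg").convert("RGB")
image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().cuda()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()

# Stop on the conversation separator, exactly as the interactive CLI does.
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

with torch.inference_mode():
    output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=512, use_cache=True, stopping_criteria=[stopping_criteria])
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip())
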
-------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name}) 21 | worker_addr = ret.json()["address"] 22 | print(f"worker_addr: {worker_addr}") 23 | 24 | if worker_addr == "": 25 | return 26 | 27 | conv = default_conversation.copy() 28 | conv.append_message(conv.roles[0], args.message) 29 | prompt = conv.get_prompt() 30 | 31 | headers = {"User-Agent": "LLaVA Client"} 32 | pload = { 33 | "model": args.model_name, 34 | "prompt": prompt, 35 | "max_new_tokens": args.max_new_tokens, 36 | "temperature": 0.7, 37 | "stop": conv.sep, 38 | } 39 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True) 40 | 41 | print(prompt.replace(conv.sep, "\n"), end="") 42 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 43 | if chunk: 44 | data = json.loads(chunk.decode("utf-8")) 45 | output = data["text"].split(conv.sep)[-1] 46 | print(output, end="\r") 47 | print("") 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 53 | parser.add_argument("--worker-address", type=str) 54 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 55 | parser.add_argument("--max-new-tokens", type=int, default=32) 56 | parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.") 57 | args = parser.parse_args() 58 | 59 | main() 60 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from 
transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | padding_mask: Optional[torch.Tensor] = None, 25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 26 | if output_attentions: 27 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") 28 | 29 | bsz, q_len, _ = hidden_states.size() 30 | 31 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 32 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 33 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim) 34 | 35 | kv_seq_len = key_states.shape[-2] 36 | if past_key_value is not None: 37 | kv_seq_len += past_key_value[0].shape[-2] 38 | 39 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 40 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 41 | 42 | if past_key_value is not None: 43 | # reuse k, v 44 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 45 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 46 | 47 | past_key_value = (key_states, value_states) if use_cache else None 48 | 49 | # repeat k/v heads if n_kv_heads < n_heads 50 | key_states = repeat_kv(key_states, self.num_key_value_groups) 51 | value_states = repeat_kv(value_states, self.num_key_value_groups) 52 | 53 | # Transform the data into the format required by flash attention 54 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 55 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 56 | key_padding_mask = attention_mask 57 | 58 | if key_padding_mask is None: 59 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 60 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) 61 | max_s = q_len 62 | output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 63 | output = output.view(bsz, q_len, -1) 64 | else: 65 | qkv = qkv.reshape(bsz, q_len, -1) 66 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 67 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 68 | output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 69 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 70 | output = pad_input(output_unpad, indices, bsz, q_len) 71 | 72 | return self.o_proj(output), None, past_key_value 73 | 74 | 75 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 76 | # requires the attention mask to be the same as the key_padding_mask 77 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, 
inputs_embeds, past_key_values_length): 78 | # [bsz, seq_len] 79 | return attention_mask 80 | 81 | 82 | def replace_llama_attn_with_flash_attn(): 83 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 84 | if cuda_major < 8: 85 | warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593") 86 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 87 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 88 | -------------------------------------------------------------------------------- /llava/train/llava_trainer_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | from llava.train.llava_trainer import LLaVATrainer 5 | 6 | 7 | class LLaVAEvalTrainer(LLaVATrainer): 8 | def evaluate(self, evaluate_args): 9 | cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \ 10 | --model {evaluate_args.model} \ 11 | --model_args {evaluate_args.model_args} \ 12 | --tasks {evaluate_args.task_names} \ 13 | --batch_size {evaluate_args.batch_size} \ 14 | --log_samples_suffix {evaluate_args.log_samples_suffix} \ 15 | --output_path {evaluate_args.output_path}" 16 | if evaluate_args.limit: 17 | cmd += f" --limit {evaluate_args.limit}" 18 | if evaluate_args.num_fewshot: 19 | cmd += f" --num_fewshot {evaluate_args.num_fewshot}" 20 | if evaluate_args.gen_kwargs != "": 21 | cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}" 22 | if evaluate_args.log_samples: 23 | cmd += f" --log_samples" 24 | else: 25 | assert False, "Please log samples so that the result can be parsed" 26 | results = subprocess.run([cmd], shell=True, capture_output=True, text=True) 27 | try: 28 | result_file_index_start = results.stdout.index("Saved samples to ") 29 | result_file_index_end = results.stdout.index(f".json") 30 | result_file_index_start += len("Saved samples to ") 31 | file = results.stdout[result_file_index_start:result_file_index_end] 32 | except: 33 | result_file_index_start = results.stderr.index("Saved samples to ") 34 | result_file_index_end = results.stderr.index(f".json") 35 | result_file_index_start += len("Saved samples to ") 36 | file = results.stderr[result_file_index_start:result_file_index_end] 37 | file = file.split("/")[:-1] 38 | file = "/".join(file) + "/results.json" 39 | with open(file, "r") as f: 40 | lmms_eval_results = json.load(f) 41 | result_dict = {} 42 | tasks_list = evaluate_args.task_names.split(",") 43 | for task in tasks_list: 44 | task_results = lmms_eval_results["results"][task] 45 | for k, v in task_results.items(): 46 | if k != "alias" and "stderr" not in k: 47 | metric = k.split(",")[0] 48 | result_dict[f"{task}_{metric}"] = v 49 | return result_dict 50 | 51 | """def evaluate(self, evaluate_args): 52 | initialize_tasks() 53 | tasks_list = evaluate_args.task_names.split(",") 54 | result_dict = {} 55 | results = evaluator.simple_evaluate( 56 | model=evaluate_args.model, 57 | model_args=evaluate_args.model_args, 58 | tasks=tasks_list, 59 | num_fewshot=evaluate_args.num_fewshot, 60 | batch_size=evaluate_args.batch_size, 61 | device=evaluate_args.device, 62 | limit=evaluate_args.limit, 63 | check_integrity=evaluate_args.check_integrity, 64 | show_task_to_terminal=evaluate_args.show_task_to_terminal, 65 | log_samples=evaluate_args.log_samples, 66 | 
gen_kwargs=evaluate_args.gen_kwargs, 67 | cli_args=evaluate_args, 68 | ) 69 | for task in tasks_list: 70 | task_results = results["results"][task] 71 | for k,v in task_results.items(): 72 | if k != "alias" and "stderr" not in k: 73 | metric = k.split(",")[0] 74 | result_dict[f"{task}_{metric}"] = v 75 | 76 | return result_dict""" 77 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train import train 2 | 3 | if __name__ == "__main__": 4 | train() 5 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 4 | from llava.conversation import conv_templates, SeparatorStyle 5 | from llava.model.builder import load_pretrained_model 6 | from llava.utils import disable_torch_init 7 | from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria 8 | from transformers.generation.streamers import TextIteratorStreamer 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from io import BytesIO 14 | 15 | from cog import BasePredictor, Input, Path, ConcatenateIterator 16 | import time 17 | import subprocess 18 | from threading import Thread 19 | 20 | import os 21 | os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights" 22 | 23 | # url for the weights mirror 24 | REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default" 25 | # files to download from the weights mirrors 26 | weights = [ 27 | { 28 | "dest": "liuhaotian/llava-v1.5-13b", 29 | # git commit hash from huggingface 30 | "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8", 31 | "files": [ 32 | "config.json", 33 | "generation_config.json", 34 | "pytorch_model-00001-of-00003.bin", 35 | "pytorch_model-00002-of-00003.bin", 36 | "pytorch_model-00003-of-00003.bin", 37 | "pytorch_model.bin.index.json", 38 | "special_tokens_map.json", 39 | "tokenizer.model", 40 | "tokenizer_config.json", 41 | ], 42 | }, 43 | { 44 | "dest": "openai/clip-vit-large-patch14-336", 45 | "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1", 46 | "files": ["config.json", "preprocessor_config.json", "pytorch_model.bin"], 47 | }, 48 | ] 49 | 50 | 51 | def download_json(url: str, dest: Path): 52 | res = requests.get(url, allow_redirects=True) 53 | if res.status_code == 200 and res.content: 54 | with dest.open("wb") as f: 55 | f.write(res.content) 56 | else: 57 | print(f"Failed to download {url}. 
Status code: {res.status_code}") 58 | 59 | def download_weights(baseurl: str, basedest: str, files: list[str]): 60 | basedest = Path(basedest) 61 | start = time.time() 62 | print("downloading to: ", basedest) 63 | basedest.mkdir(parents=True, exist_ok=True) 64 | for f in files: 65 | dest = basedest / f 66 | url = os.path.join(REPLICATE_WEIGHTS_URL, baseurl, f) 67 | if not dest.exists(): 68 | print("downloading url: ", url) 69 | if dest.suffix == ".json": 70 | download_json(url, dest) 71 | else: 72 | subprocess.check_call(["pget", url, str(dest)], close_fds=False) 73 | print("downloading took: ", time.time() - start) 74 | 75 | class Predictor(BasePredictor): 76 | def setup(self) -> None: 77 | """Load the model into memory to make running multiple predictions efficient""" 78 | for weight in weights: 79 | download_weights(weight["src"], weight["dest"], weight["files"]) 80 | disable_torch_init() 81 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model("liuhaotian/llava-v1.5-13b", model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False) 82 | 83 | def predict( 84 | self, 85 | image: Path = Input(description="Input image"), 86 | prompt: str = Input(description="Prompt to use for text generation"), 87 | top_p: float = Input(description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", ge=0.0, le=1.0, default=1.0), 88 | temperature: float = Input(description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", default=0.2, ge=0.0), 89 | max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0), 90 | ) -> ConcatenateIterator[str]: 91 | """Run a single prediction on the model""" 92 | 93 | conv_mode = "llava_v1" 94 | conv = conv_templates[conv_mode].copy() 95 | 96 | image_data = load_image(str(image)) 97 | image_tensor = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"].half().cuda() 98 | 99 | # loop start 100 | 101 | # just one turn, always prepend image token 102 | inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt 103 | conv.append_message(conv.roles[0], inp) 104 | 105 | conv.append_message(conv.roles[1], None) 106 | prompt = conv.get_prompt() 107 | 108 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() 109 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 110 | keywords = [stop_str] 111 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) 112 | streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0) 113 | 114 | with torch.inference_mode(): 115 | thread = Thread( 116 | target=self.model.generate, 117 | kwargs=dict(inputs=input_ids, images=image_tensor, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=max_tokens, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]), 118 | ) 119 | thread.start() 120 | # workaround: second-to-last token is always " " 121 | # but we want to keep it if it's not the second-to-last token 122 | prepend_space = False 123 | for new_text in streamer: 124 | if new_text == " ": 125 | prepend_space = True 126 | continue 127 | if new_text.endswith(stop_str): 128 | new_text = new_text[: -len(stop_str)].strip() 129 | prepend_space = False 130 | elif prepend_space: 131 | new_text = " " + new_text 132 | prepend_space = False 133 | if len(new_text): 134 
| yield new_text 135 | if prepend_space: 136 | yield " " 137 | thread.join() 138 | 139 | 140 | def load_image(image_file): 141 | if image_file.startswith("http") or image_file.startswith("https"): 142 | response = requests.get(image_file) 143 | image = Image.open(BytesIO(response.content)).convert("RGB") 144 | else: 145 | image = Image.open(image_file).convert("RGB") 146 | return image 147 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 240 3 | 4 | [build-system] 5 | requires = ["setuptools>=61.0"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "llava" 10 | version = "1.7.0.dev0" 11 | description = "LLaVA OneVision: The Next Generation of LLaVA with Better Image and Video Understanding Capabilities" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | ] 18 | 19 | [project.optional-dependencies] 20 | standalone = [ 21 | "shortuuid", 22 | "httpx==0.24.0", 23 | "einops", 24 | "ftfy", 25 | ] 26 | 27 | 28 | train = [ 29 | "llava[standalone]", 30 | "numpy==1.26.1", 31 | "open_clip_torch", 32 | "fastapi", 33 | "markdown2[all]", 34 | "numpy", 35 | "requests", 36 | "sentencepiece", 37 | "torch==2.1.2", 38 | "torchvision==0.16.2", 39 | "uvicorn", 40 | "wandb", 41 | "deepspeed==0.14.4", 42 | "peft==0.4.0", 43 | "accelerate>=0.29.1", 44 | "tokenizers~=0.15.2", 45 | "transformers@git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4", 46 | "bitsandbytes==0.41.0", 47 | "scikit-learn==1.2.2", 48 | "sentencepiece~=0.1.99", 49 | "einops==0.6.1", 50 | "einops-exts==0.0.4", 51 | "gradio_client==0.2.9", 52 | "urllib3<=2.0.0", 53 | "datasets==2.16.1", 54 | "pydantic==1.10.8", 55 | "timm", 56 | "hf_transfer", 57 | "opencv-python", 58 | "av", 59 | "decord", 60 | "tyro", 61 | "scipy", 62 | ] 63 | 64 | [project.urls] 65 | "Homepage" = "https://llava-vl.github.io" 66 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues" 67 | 68 | [tool.setuptools.packages.find] 69 | include = ["llava*", "trl*"] 70 | exclude = [ 71 | "assets*", 72 | "benchmark*", 73 | "docs", 74 | "dist*", 75 | "playground*", 76 | "scripts*", 77 | "tests*", 78 | "checkpoints*", 79 | "project_checkpoints*", 80 | "debug_checkpoints*", 81 | "mlx_configs*", 82 | "wandb*", 83 | "notebooks*", 84 | ] 85 | 86 | [tool.wheel] 87 | exclude = [ 88 | "assets*", 89 | "benchmark*", 90 | "docs", 91 | "dist*", 92 | "playground*", 93 | "scripts*", 94 | "tests*", 95 | "checkpoints*", 96 | "project_checkpoints*", 97 | "debug_checkpoints*", 98 | "mlx_configs*", 99 | "wandb*", 100 | "notebooks*", 101 | ] 102 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Babel==2.14.0 2 | DataProperty==1.0.1 3 | Deprecated==1.2.14 4 | GitPython==3.1.43 5 | Jinja2==3.1.3 6 | Levenshtein==0.25.1 7 | MarkupSafe==2.1.5 8 | PyJWT==2.8.0 9 | PyYAML==6.0.1 10 | Pygments==2.17.2 11 | QtPy==2.4.1 12 | Send2Trash==1.8.3 13 | absl-py==2.1.0 14 | accelerate==0.29.3 15 | aiofiles==22.1.0 16 | aiohttp==3.9.5 17 | aiosignal==1.3.1 18 | aiosqlite==0.20.0 19 | altair==5.3.0 20 | anyio==4.3.0 21 | appdirs==1.4.4 22 | argon2-cffi-bindings==21.2.0 23 | argon2-cffi==23.1.0 24 | arrow==1.3.0 25 | 
asttokens==2.4.1 26 | async-timeout==4.0.3 27 | attrs==23.1.0 28 | beautifulsoup4==4.12.3 29 | bidict==0.23.1 30 | bitsandbytes==0.41.0 31 | black==24.1.0 32 | bleach==6.1.0 33 | certifi==2024.2.2 34 | cffi==1.16.0 35 | cfgv==3.4.0 36 | chardet==5.2.0 37 | charset-normalizer==3.3.2 38 | click==8.1.7 39 | colorama==0.4.6 40 | comm==0.2.2 41 | contourpy==1.2.1 42 | crcmod==1.7 43 | cryptography==38.0.4 44 | cycler==0.12.1 45 | datasets==2.16.1 46 | debugpy==1.8.1 47 | decorator==5.1.1 48 | decord==0.6.0 49 | deepspeed==0.12.2 50 | defusedxml==0.7.1 51 | dill==0.3.7 52 | distlib==0.3.8 53 | distro==1.9.0 54 | dnspython==2.6.1 55 | docker-pycreds==0.4.0 56 | docstring_parser==0.16 57 | einops-exts==0.0.4 58 | einops==0.6.1 59 | entrypoints==0.4 60 | et-xmlfile==1.1.0 61 | eval_type_backport==0.2.0 62 | evaluate==0.4.1 63 | exceptiongroup==1.2.1 64 | executing==2.0.1 65 | fastapi==0.110.2 66 | fastjsonschema==2.19.1 67 | ffmpy==0.3.2 68 | filelock==3.13.4 69 | flash-attn==2.5.7 70 | fonttools==4.51.0 71 | fqdn==1.5.1 72 | frozenlist==1.4.1 73 | fsspec==2023.10.0 74 | ftfy==6.2.0 75 | gitdb==4.0.11 76 | gradio==3.35.2 77 | gradio_client==0.2.9 78 | grpcio==1.62.2 79 | h11==0.14.0 80 | hf_transfer==0.1.6 81 | hjson==3.1.0 82 | httpcore==0.17.3 83 | httpx==0.24.0 84 | huggingface-hub==0.22.2 85 | identify==2.5.36 86 | idna==3.7 87 | importlib_metadata==7.1.0 88 | importlib_resources==6.4.0 89 | iniconfig==2.0.0 90 | ipaddress==1.0.23 91 | ipykernel==6.29.4 92 | ipython-genutils==0.2.0 93 | ipython==8.18.1 94 | ipywidgets==8.1.2 95 | isoduration==20.11.0 96 | jedi==0.19.1 97 | joblib==1.4.0 98 | json5==0.9.25 99 | jsonlines==4.0.0 100 | jsonpointer==2.4 101 | jsonschema-specifications==2023.12.1 102 | jsonschema==4.21.1 103 | 104 | kiwisolver==1.4.5 105 | linkify-it-py==2.0.3 106 | llava==1.7.0.dev0 107 | llava==1.7.0.dev0 108 | lmms_eval==0.1.1 109 | lxml==5.2.1 110 | markdown-it-py==2.2.0 111 | markdown2==2.4.13 112 | matplotlib-inline==0.1.7 113 | matplotlib==3.8.4 114 | mbstrdecoder==1.1.3 115 | mdit-py-plugins==0.3.3 116 | mdurl==0.1.2 117 | mistune==3.0.2 118 | mpmath==1.3.0 119 | msgpack==1.0.8 120 | multidict==6.0.5 121 | multiprocess==0.70.15 122 | mypy-extensions==1.0.0 123 | nbclassic==1.0.0 124 | nbclient==0.10.0 125 | nbconvert==7.16.3 126 | nbformat==5.10.4 127 | nest-asyncio==1.6.0 128 | networkx==3.2.1 129 | ninja==1.11.1.1 130 | nltk==3.8.1 131 | nodeenv==1.8.0 132 | notebook==6.5.6 133 | notebook_shim==0.2.4 134 | numexpr==2.10.0 135 | numpy==1.26.4 136 | nvidia-cublas-cu12==12.1.3.1 137 | nvidia-cuda-cupti-cu12==12.1.105 138 | nvidia-cuda-nvrtc-cu12==12.1.105 139 | nvidia-cuda-runtime-cu12==12.1.105 140 | nvidia-cudnn-cu12==8.9.2.26 141 | nvidia-cufft-cu12==11.0.2.54 142 | nvidia-curand-cu12==10.3.2.106 143 | nvidia-cusolver-cu12==11.4.5.107 144 | nvidia-cusparse-cu12==12.1.0.106 145 | nvidia-nccl-cu12==2.18.1 146 | nvidia-nvjitlink-cu12==12.4.127 147 | nvidia-nvtx-cu12==12.1.105 148 | open-clip-torch==2.24.0 149 | openai==1.23.6 150 | opencv-python-headless==4.9.0.80 151 | openpyxl==3.1.2 152 | orjson==3.10.1 153 | overrides==7.7.0 154 | packaging==24.0 155 | pandas==2.2.2 156 | pandocfilters==1.5.1 157 | parso==0.8.4 158 | pathlib2==2.3.7.post1 159 | pathspec==0.12.1 160 | pathtools==0.1.2 161 | pathvalidate==3.2.0 162 | peft==0.4.0 163 | pexpect==4.8.0 164 | pillow==10.3.0 165 | platformdirs==4.2.1 166 | pluggy==1.5.0 167 | ply==3.11 168 | portalocker==2.8.2 169 | pre-commit==3.7.0 170 | prometheus_client==0.20.0 171 | promise==2.3 172 | prompt-toolkit==3.0.43 173 | 
protobuf==3.20.3 174 | psutil==5.9.8 175 | ptyprocess==0.7.0 176 | pure-eval==0.2.2 177 | py-cpuinfo==9.0.0 178 | py-spy==0.3.14 179 | py==1.11.0 180 | pyOpenSSL==22.1.0 181 | pyarrow-hotfix==0.6 182 | pyarrow==16.0.0 183 | pybind11==2.12.0 184 | pycocoevalcap==1.2 185 | pycocotools==2.0.7 186 | pycparser==2.22 187 | pycryptodomex==3.20.0 188 | pydantic==1.10.8 189 | pydub==0.25.1 190 | pynvml==11.5.0 191 | pyparsing==3.1.2 192 | pytablewriter==1.2.0 193 | pytest==6.2.5 194 | python-consul==1.1.0 195 | python-dateutil==2.9.0.post0 196 | python-engineio==4.9.0 197 | python-etcd==0.4.5 198 | python-json-logger==2.0.7 199 | python-multipart==0.0.9 200 | python-socketio==5.11.2 201 | pytz==2024.1 202 | pyzmq==24.0.1 203 | qtconsole==5.5.1 204 | rapidfuzz==3.8.1 205 | referencing==0.35.0 206 | regex==2024.4.16 207 | requests==2.31.0 208 | responses==0.18.0 209 | rfc3339-validator==0.1.4 210 | rfc3986-validator==0.1.1 211 | rich==13.7.1 212 | rouge-score==0.1.2 213 | rpds-py==0.18.0 214 | sacrebleu==2.4.2 215 | safetensors==0.4.3 216 | schedule==1.2.1 217 | scikit-learn==1.2.2 218 | scipy==1.13.0 219 | semantic-version==2.10.0 220 | sentencepiece==0.1.99 221 | sentry-sdk==2.0.0 222 | setproctitle==1.3.3 223 | setuptools==68.2.2 224 | shortuuid==1.0.13 225 | shtab==1.7.1 226 | simple-websocket==1.0.0 227 | six==1.16.0 228 | smmap==5.0.1 229 | sniffio==1.3.1 230 | soupsieve==2.5 231 | sqlitedict==2.1.0 232 | stack-data==0.6.3 233 | starlette==0.37.2 234 | svgwrite==1.4.3 235 | sympy==1.12 236 | tabledata==1.3.3 237 | tabulate==0.9.0 238 | tcolorpy==0.1.4 239 | tenacity==8.2.3 240 | terminado==0.18.1 241 | threadpoolctl==3.4.0 242 | thriftpy2==0.4.20 243 | tiktoken==0.6.0 244 | timm==0.9.16 245 | tinycss2==1.3.0 246 | tokenizers==0.15.2 247 | toml==0.10.2 248 | tomli==2.0.1 249 | toolz==0.12.1 250 | torch==2.1.2 251 | torchvision==0.16.2 252 | tornado==6.4 253 | tox==3.28.0 254 | tqdm-multiprocess==0.0.11 255 | tqdm==4.66.2 256 | traitlets==5.14.3 257 | transformers-stream-generator==0.0.5 258 | triton==2.1.0 259 | typepy==1.3.2 260 | types-python-dateutil==2.9.0.20240316 261 | typing_extensions==4.11.0 262 | tyro==0.8.3 263 | tzdata==2024.1 264 | uc-micro-py==1.0.3 265 | uri-template==1.3.0 266 | urllib3==2.2.1 267 | uvicorn==0.29.0 268 | virtualenv==20.26.0 269 | wandb==0.16.5 270 | watchdog==4.0.0 271 | wavedrom==2.0.3.post3 272 | wcwidth==0.2.13 273 | webcolors==1.13 274 | webencodings==0.5.1 275 | websocket-client==1.8.0 276 | websockets==12.0 277 | wheel==0.41.2 278 | widgetsnbextension==4.0.10 279 | wrapt==1.16.0 280 | wsproto==1.2.0 281 | xxhash==3.4.1 282 | y-py==0.6.2 283 | yarl==1.9.4 284 | ypy-websocket==0.8.4 285 | zipp==3.18.1 286 | zstandard==0.22.0 -------------------------------------------------------------------------------- /scripts/convert_to_swift.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | # Input file path 5 | input_file = '{your_path_to_mmrlhf}/MM-RLHF/dpo_pairs.jsonl' 6 | 7 | # Output file paths 8 | video_output_file = 'tmp/mmrlhf_v1_video.jsonl' 9 | image_output_file = 'tmp/mmrlhf_v1_image.jsonl' 10 | 11 | # Directories for image and video 12 | image_dir = '{your_path_to_mmrlhf}/MM-RLHF/' 13 | video_dir = '{your_path_to_mmrlhf}/MM-RLHF/' 14 | 15 | # Open input file for reading 16 | with open(input_file, 'r') as infile: 17 | # Open output files for writing 18 | with open(video_output_file, 'w') as video_file, open(image_output_file, 'w') as image_file: 19 | for line in infile: 20 | item = 
json.loads(line.strip()) # Read each line and parse as JSON 21 | 22 | if 'video' in item: # If it contains a "video" field 23 | if video_dir not in item['video']: 24 | item['video'] = os.path.join(video_dir, item['video']) 25 | assert os.path.exists(item['video']) 26 | 27 | # Retain only necessary elements 28 | item['question'] = item['prompt'] 29 | item['response'] = item['chosen'] 30 | item['rejected_response'] = item['rejected'] 31 | # Delete unnecessary fields 32 | del item['prompt'], item['chosen'], item['rejected'] 33 | 34 | # Remove all other fields except the ones we want to retain 35 | for key in list(item.keys()): 36 | if key not in ['video', 'question', 'response', 'rejected_response']: 37 | del item[key] 38 | 39 | # Write the filtered item to the video output file 40 | json.dump(item, video_file) 41 | video_file.write('\n') 42 | 43 | else: # Otherwise save as image file 44 | if image_dir not in item['image']: 45 | item['image'] = os.path.join(image_dir, item['image']) 46 | assert os.path.exists(item['image']) 47 | 48 | # Retain only necessary elements 49 | item['question'] = item['prompt'] 50 | item['response'] = item['chosen'] 51 | item['rejected_response'] = item['rejected'] 52 | # Delete unnecessary fields 53 | del item['prompt'], item['chosen'], item['rejected'] 54 | 55 | # Remove all other fields except the ones we want to retain 56 | for key in list(item.keys()): 57 | if key not in ['image', 'question', 'response', 'rejected_response']: 58 | del item[key] 59 | 60 | # Write the filtered item to the image output file 61 | json.dump(item, image_file) 62 | image_file.write('\n') 63 | -------------------------------------------------------------------------------- /scripts/swift.sh: -------------------------------------------------------------------------------- 1 | # after you install the swift package, you can use the swift command to run swift code 2 | 3 | VIDEO_SEGMENTS=8 \ 4 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 5 | NPROC_PER_NODE=8 \ 6 | swift rlhf \ 7 | --rlhf_type dpo \ 8 | --model OpenGVLab/InternVL2-2B \ 9 | --dataset tmp/mmrlhf_v1_video.jsonl tmp/mmrlhf_v1_image.jsonl \ 10 | --beta 0.2 \ 11 | --rpo_alpha $rpo_alpha \ 12 | --train_type full \ 13 | --deepspeed zero2 \ 14 | --torch_dtype bfloat16 \ 15 | --num_train_epochs 1 \ 16 | --per_device_train_batch_size 1 \ 17 | --gradient_accumulation_steps 32 \ 18 | --learning_rate $learning_rate \ 19 | --freeze_vit true \ 20 | --eval_steps 1000 \ 21 | --save_steps 1000 \ 22 | --save_total_limit 5 \ 23 | --logging_steps 5 \ 24 | --max_length 32768 \ 25 | --output_dir $output_dir \ 26 | --warmup_ratio 0.05 \ 27 | --dataloader_num_workers 4 \ 28 | --report_to wandb -------------------------------------------------------------------------------- /scripts/train/critic_reward_7b.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | # export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE} 5 | export NCCL_SOCKET_IFNAME=eth0 6 | export NCCL_DEBUG=INFO 7 | 8 | VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" 9 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 10 | 11 | # DPO Stage 12 | PROMPT_VERSION="qwen_1_5" 13 | SFT_MODEL="lmms-lab/llava-onevision-qwen2-7b-ov" 14 | EPOCH=1 15 | DPO_RUN_NAME="llava-ov-reward-qwen2-7b-ov_mmrlhf-epoch${EPOCH}" 16 | DPO_CLEAN_NAME="${DPO_RUN_NAME##*/}" 17 | OUTPUT_DIR="/${DPO_CLEAN_NAME}" 18 | DATA_PATH="" 19 | 20 | echo $DPO_RUN_NAME 21 | 22 | ACCELERATE_CPU_AFFINITY=1 
torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ 23 | llava/train/train_dpo.py \ 24 | --deepspeed scripts/zero3.json \ 25 | --model_name_or_path=${SFT_MODEL} \ 26 | --version $PROMPT_VERSION \ 27 | --data_path=$DATA_PATH \ 28 | --image_folder "" \ 29 | --video_folder "" \ 30 | --mm_tunable_parts="mm_mlp_adapter,mm_language_model" \ 31 | --vision_tower ${VISION_MODEL_VERSION} \ 32 | --mm_projector_type mlp2x_gelu \ 33 | --mm_vision_select_layer -2 \ 34 | --mm_spatial_pool_mode bilinear \ 35 | --mm_use_im_start_end False \ 36 | --mm_use_im_patch_token False \ 37 | --group_by_modality_length True \ 38 | --image_aspect_ratio anyres_max_9 \ 39 | --image_grid_pinpoints "(1x1),...,(6x6)" \ 40 | --mm_patch_merge_type spatial_unpad \ 41 | --bf16 True \ 42 | --run_name $DPO_CLEAN_NAME \ 43 | --output_dir $OUTPUT_DIR \ 44 | --num_train_epochs $EPOCH \ 45 | --per_device_train_batch_size 1 \ 46 | --per_device_eval_batch_size 1 \ 47 | --gradient_accumulation_steps 24 \ 48 | --evaluation_strategy "no" \ 49 | --save_strategy "steps" \ 50 | --save_steps 1000 \ 51 | --save_total_limit 1 \ 52 | --learning_rate 1e-6 \ 53 | --weight_decay 0. \ 54 | --warmup_ratio 0.1 \ 55 | --lr_scheduler_type "cosine" \ 56 | --logging_steps 1 \ 57 | --tf32 True \ 58 | --model_max_length 32768 \ 59 | --gradient_checkpointing True \ 60 | --dataloader_num_workers 4 \ 61 | --lazy_preprocess True \ 62 | --report_to wandb \ 63 | --dataloader_drop_last True \ 64 | --attn_implementation flash_attention_2 \ 65 | --is_rm True \ 66 | --critic_rewards_weight 1.0 \ 67 | --float_rewards_weight 1.0 \ 68 | --center_rewards_coefficient 0.01 \ 69 | -------------------------------------------------------------------------------- /scripts/train/dpo_ov7b.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | export NCCL_IB_DISABLE=1 5 | 6 | # export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE} 7 | export NCCL_SOCKET_IFNAME=eth0 8 | export NCCL_DEBUG=INFO 9 | 10 | VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" 11 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 12 | 13 | # DPO Stage 14 | PROMPT_VERSION="qwen_1_5" 15 | SFT_MODEL="lmms-lab/llava-onevision-qwen2-7b-ov" 16 | EPOCH=1 17 | beta=0.1 18 | ls_factor_weight=0.1 19 | DPO_RUN_NAME="llava-onevision-qwen2-7b-ov_mmrlhf-w${ls_factor_weight}-beta${beta}-epoch${EPOCH}" 20 | DPO_CLEAN_NAME="${DPO_RUN_NAME##*/}" 21 | OUTPUT_DIR="/${DPO_CLEAN_NAME}" 22 | DATA_PATH="" 23 | 24 | echo $DPO_RUN_NAME 25 | 26 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ 27 | llava/train/train_dpo.py \ 28 | --deepspeed scripts/zero3.json \ 29 | --model_name_or_path=${SFT_MODEL} \ 30 | --dpo_alpha 1.0 --beta 0.1 --gamma 0.25 \ 31 | --ls_factor_weight $ls_factor_weight \ 32 | --version $PROMPT_VERSION \ 33 | --data_path=$DATA_PATH \ 34 | --image_folder "" \ 35 | --video_folder "" \ 36 | --mm_tunable_parts="mm_mlp_adapter,mm_language_model" \ 37 | --vision_tower ${VISION_MODEL_VERSION} \ 38 | --mm_projector_type mlp2x_gelu \ 39 | --mm_vision_select_layer -2 \ 40 | --mm_spatial_pool_mode bilinear \ 41 | --mm_use_im_start_end False \ 42 | --mm_use_im_patch_token False \ 43 | --group_by_modality_length True \ 44 | --image_aspect_ratio anyres_max_9 \ 45 | --image_grid_pinpoints "(1x1),...,(6x6)" \ 46 | 
--mm_patch_merge_type spatial_unpad \ 47 | --bf16 True \ 48 | --run_name $DPO_CLEAN_NAME \ 49 | --output_dir $OUTPUT_DIR \ 50 | --num_train_epochs $EPOCH \ 51 | --per_device_train_batch_size 1 \ 52 | --per_device_eval_batch_size 1 \ 53 | --gradient_accumulation_steps 12 \ 54 | --evaluation_strategy "no" \ 55 | --save_strategy "steps" \ 56 | --save_steps 1000 \ 57 | --save_total_limit 1 \ 58 | --learning_rate 1e-6 \ 59 | --weight_decay 0. \ 60 | --warmup_ratio 0.1 \ 61 | --lr_scheduler_type "cosine" \ 62 | --logging_steps 1 \ 63 | --tf32 True \ 64 | --model_max_length 32768 \ 65 | --gradient_checkpointing True \ 66 | --dataloader_num_workers 4 \ 67 | --lazy_preprocess True \ 68 | --report_to wandb \ 69 | --dataloader_drop_last True \ 70 | --attn_implementation sdpa 71 | -------------------------------------------------------------------------------- /scripts/train/generate_ref_logits.sh: -------------------------------------------------------------------------------- 1 | export OMP_NUM_THREADS=8 2 | export NCCL_IB_DISABLE=0 3 | export NCCL_IB_GID_INDEX=3 4 | # export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE} 5 | export NCCL_SOCKET_IFNAME=eth0 6 | export NCCL_DEBUG=INFO 7 | 8 | VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" 9 | VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" 10 | 11 | # DPO Stage 12 | PROMPT_VERSION="qwen_1_5" 13 | SFT_MODEL="lmms-lab/llava-onevision-qwen2-7b-ov" 14 | EPOCH=1 15 | beta=0.1 16 | ls_factor_weight=0.1 17 | DPO_RUN_NAME="llava-onevision-qwen2-7b-ov_mmrlhf-w${ls_factor_weight}-beta${beta}-epoch${EPOCH}" 18 | DPO_CLEAN_NAME="${DPO_RUN_NAME##*/}" 19 | OUTPUT_DIR="/${DPO_CLEAN_NAME}" 20 | DATA_PATH="" 21 | OUTPUT_DATA_PATH="" 22 | 23 | echo $DPO_RUN_NAME 24 | 25 | ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ 26 | llava/train/train_dpo.py \ 27 | --deepspeed scripts/zero2.json \ 28 | --model_name_or_path=${SFT_MODEL} \ 29 | --dpo_alpha 1.0 --beta 0.1 --gamma 0.25 \ 30 | --ls_factor_weight $ls_factor_weight \ 31 | --version $PROMPT_VERSION \ 32 | --data_path=$DATA_PATH \ 33 | --image_folder "" \ 34 | --video_folder "" \ 35 | --mm_tunable_parts="mm_mlp_adapter,mm_language_model" \ 36 | --vision_tower ${VISION_MODEL_VERSION} \ 37 | --mm_projector_type mlp2x_gelu \ 38 | --mm_vision_select_layer -2 \ 39 | --mm_spatial_pool_mode bilinear \ 40 | --mm_use_im_start_end False \ 41 | --mm_use_im_patch_token False \ 42 | --group_by_modality_length True \ 43 | --image_aspect_ratio anyres_max_9 \ 44 | --image_grid_pinpoints "(1x1),...,(6x6)" \ 45 | --mm_patch_merge_type spatial_unpad \ 46 | --bf16 True \ 47 | --run_name $DPO_CLEAN_NAME \ 48 | --output_dir $OUTPUT_DIR \ 49 | --num_train_epochs $EPOCH \ 50 | --per_device_train_batch_size 1 \ 51 | --per_device_eval_batch_size 1 \ 52 | --gradient_accumulation_steps 12 \ 53 | --evaluation_strategy "no" \ 54 | --save_strategy "steps" \ 55 | --save_steps 1000 \ 56 | --save_total_limit 1 \ 57 | --learning_rate 1e-6 \ 58 | --weight_decay 0. 
\ 59 | --warmup_ratio 0.1 \ 60 | --lr_scheduler_type "cosine" \ 61 | --logging_steps 1 \ 62 | --tf32 True \ 63 | --model_max_length 32768 \ 64 | --gradient_checkpointing True \ 65 | --dataloader_num_workers 4 \ 66 | --lazy_preprocess True \ 67 | --report_to wandb \ 68 | --dataloader_drop_last True \ 69 | --attn_implementation sdpa \ 70 | --ref_data $OUTPUT_DATA_PATH \ 71 | --precompute_ref_log_probs True \ 72 | 73 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero2_fused_adamw.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": true, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "offload_optimizer": { 19 | "device": "cpu", 20 | "pin_memory": true 21 | }, 22 | "offload_param": { 23 | "device": "cpu", 24 | "pin_memory": true 25 | }, 26 | "overlap_comm": true, 27 | "contiguous_gradients": true, 28 | "sub_group_size": 1e9, 29 | "reduce_bucket_size": "auto" 
30 | } 31 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "steps_per_print": 1e5, 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /scripts/zero3pp.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "none", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "none", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "zero_quantized_weights": true, 36 | "zero_hpz_partition_size": 16, 37 | "zero_quantized_gradients": true, 
38 | "sub_group_size": 1e9, 39 | "reduce_bucket_size": "auto", 40 | "stage3_prefetch_bucket_size": "auto", 41 | "stage3_param_persistence_threshold": "auto", 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | "stage3_gather_16bit_weights_on_model_save": true 45 | }, 46 | 47 | "gradient_accumulation_steps": "auto", 48 | "gradient_clipping": "auto", 49 | "steps_per_print": 100, 50 | "train_batch_size": "auto", 51 | "train_micro_batch_size_per_gpu": "auto", 52 | "wall_clock_breakdown": false 53 | } -------------------------------------------------------------------------------- /trl/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | __version__ = "0.7.11.dev0" 4 | 5 | from .core import set_seed 6 | from .environment import TextEnvironment, TextHistory 7 | from .extras import BestOfNSampler 8 | from .import_utils import ( 9 | is_bitsandbytes_available, 10 | is_diffusers_available, 11 | is_npu_available, 12 | is_peft_available, 13 | is_wandb_available, 14 | is_xpu_available, 15 | ) 16 | from .models import ( 17 | AutoModelForCausalLMWithValueHead, 18 | AutoModelForSeq2SeqLMWithValueHead, 19 | PreTrainedModelWrapper, 20 | create_reference_model, 21 | setup_chat_format, 22 | ) 23 | from .trainer import ( 24 | DataCollatorForCompletionOnlyLM, 25 | DPOTrainer, 26 | IterativeSFTTrainer, 27 | ModelConfig, 28 | PPOConfig, 29 | PPOTrainer, 30 | RewardConfig, 31 | RewardTrainer, 32 | SFTTrainer, 33 | ) 34 | from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config 35 | 36 | 37 | if is_diffusers_available(): 38 | from .models import ( 39 | DDPOPipelineOutput, 40 | DDPOSchedulerOutput, 41 | DDPOStableDiffusionPipeline, 42 | DefaultDDPOStableDiffusionPipeline, 43 | ) 44 | from .trainer import DDPOConfig, DDPOTrainer 45 | -------------------------------------------------------------------------------- /trl/environment/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from .base_environment import TextEnvironment, TextHistory 4 | -------------------------------------------------------------------------------- /trl/extras/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | from .best_of_n_sampler import BestOfNSampler 17 | -------------------------------------------------------------------------------- /trl/extras/best_of_n_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional, Union 2 | 3 | import torch 4 | from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast 5 | 6 | from ..core import set_seed 7 | from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper 8 | 9 | 10 | class BestOfNSampler(object): 11 | def __init__( 12 | self, 13 | model: PreTrainedModelWrapper, 14 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 15 | queries_to_scores: Callable[[List[str]], List[float]], 16 | length_sampler: Any, 17 | sample_size: int = 4, 18 | seed: Optional[int] = None, 19 | n_candidates: int = 1, 20 | generation_config: Optional[GenerationConfig] = None, 21 | ) -> None: 22 | r""" 23 | Initialize the sampler for best-of-n generation 24 | 25 | Args: 26 | model (`PreTrainedModelWrapper`): 27 | The pretrained model to use for generation 28 | tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`): 29 | Tokenizer associated with the pretrained model 30 | queries_to_scores (`Callable[[List[str]], List[float]]`): 31 | Callable that takes a list of generated texts and returns the associated reward scores 32 | length_sampler (`Any`): 33 | Sampler used to sample the length of the generated text 34 | sample_size (`int`): 35 | Number of samples to generate for each query 36 | seed (`int`, *optional*): 37 | Random seed used to control generation 38 | n_candidates (`int`): 39 | Number of candidates to return for each query 40 | generation_config (`GenerationConfig`, *optional*): 41 | Generation config passed to the underlying model's `generate` method. 
42 | See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details 43 | """ 44 | if seed is not None: 45 | set_seed(seed) 46 | 47 | if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): 48 | raise ValueError(f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}") 49 | if not isinstance(model, (SUPPORTED_ARCHITECTURES)): 50 | raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}") 51 | 52 | self.model = model 53 | self.tokenizer = tokenizer 54 | 55 | self.queries_to_scores = queries_to_scores 56 | self.length_sampler = length_sampler 57 | self.gen_config = generation_config 58 | self.sample_size = sample_size 59 | self.n_candidates = n_candidates 60 | 61 | def generate( 62 | self, 63 | tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]], 64 | skip_special_tokens: bool = True, 65 | device: Optional[Union[str, torch.device]] = None, 66 | **generation_kwargs, 67 | ) -> List[List[str]]: 68 | r""" 69 | Generate the best of n samples for input queries 70 | 71 | Args: 72 | tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[int]`): 73 | represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers) 74 | skip_special_tokens (`bool`): 75 | Whether to remove the special tokens from the output 76 | device (`str` or `torch.device`, *optional*): 77 | The device on which the model will be loaded 78 | **generation_kwargs (`dict`, *optional*): 79 | Additional keyword arguments passed along to the underlying model's `generate` method. 
80 | This is used to override generation config 81 | 82 | Returns: 83 | List[List[str]]: A list of lists of generated texts 84 | """ 85 | queries = None 86 | 87 | if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1: 88 | queries = tokenized_query.unsqueeze(0) 89 | elif isinstance(tokenized_query, List): 90 | element_type = type(tokenized_query[0]) 91 | if element_type == int: 92 | queries = torch.tensor(tokenized_query).unsqueeze(0) 93 | elif element_type == torch.Tensor: 94 | queries = [tensor.reshape((1, -1)) for tensor in tokenized_query] 95 | else: 96 | queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query] 97 | 98 | result = [] 99 | 100 | for query in queries: 101 | queries = query.repeat((self.sample_size, 1)) 102 | output = self.model.generate( 103 | queries.to(device), 104 | max_new_tokens=self.length_sampler(), 105 | generation_config=self.gen_config, 106 | **generation_kwargs, 107 | ).squeeze() 108 | output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens) 109 | scores = torch.tensor(self.queries_to_scores(output)) 110 | output = [output[i] for i in scores.topk(self.n_candidates).indices] 111 | result.append(output) 112 | 113 | return result 114 | -------------------------------------------------------------------------------- /trl/extras/dataset_formatting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable, Literal, Optional, Union 3 | 4 | from datasets import Dataset, Value 5 | from transformers import AutoTokenizer 6 | 7 | from ..trainer.utils import ConstantLengthDataset 8 | 9 | 10 | FORMAT_MAPPING = { 11 | "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}], 12 | "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, 13 | } 14 | 15 | 16 | def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]): 17 | r""" 18 | return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer 19 | apply chat template to the dataset 20 | """ 21 | 22 | def format_dataset(examples): 23 | if isinstance(examples[messages_field][0], list): 24 | output_texts = [] 25 | for i in range(len(examples[messages_field])): 26 | output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False)) 27 | return output_texts 28 | else: 29 | return tokenizer.apply_chat_template(examples[messages_field], tokenize=False) 30 | 31 | return format_dataset 32 | 33 | 34 | def instructions_formatting_function(tokenizer: AutoTokenizer): 35 | r""" 36 | return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer 37 | apply chat template to the dataset 38 | """ 39 | 40 | def format_dataset(examples): 41 | if isinstance(examples["prompt"], list): 42 | output_texts = [] 43 | for i in range(len(examples["prompt"])): 44 | converted_sample = [ 45 | {"role": "user", "content": examples["prompt"][i]}, 46 | {"role": "assistant", "content": examples["completion"][i]}, 47 | ] 48 | output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False)) 49 | return output_texts 50 | else: 51 | converted_sample = [ 52 | {"role": "user", "content": examples["prompt"]}, 53 | {"role": "assistant", "content": examples["completion"]}, 54 | ] 55 | return 
tokenizer.apply_chat_template(converted_sample, tokenize=False) 56 | 57 | return format_dataset 58 | 59 | 60 | def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]: 61 | r""" 62 | Finds the correct formatting function based on the dataset structure. Currently supported datasets are: 63 | - `ChatML` with [{"role": str, "content": str}] 64 | - `instruction` with [{"prompt": str, "completion": str}] 65 | 66 | Args: 67 | dataset (Dataset): User dataset 68 | tokenizer (AutoTokenizer): Tokenizer used for formatting 69 | 70 | Returns: 71 | Callable: Formatting function if the dataset format is supported else None 72 | """ 73 | if isinstance(dataset, Dataset): 74 | if "messages" in dataset.features: 75 | if dataset.features["messages"] == FORMAT_MAPPING["chatml"]: 76 | logging.info("Formatting dataset with chatml format") 77 | return conversations_formatting_function(tokenizer, "messages") 78 | if "conversations" in dataset.features: 79 | if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]: 80 | logging.info("Formatting dataset with chatml format") 81 | return conversations_formatting_function(tokenizer, "conversations") 82 | elif dataset.features == FORMAT_MAPPING["instruction"]: 83 | logging.info("Formatting dataset with instruction format") 84 | return instructions_formatting_function(tokenizer) 85 | 86 | return None 87 | -------------------------------------------------------------------------------- /trl/import_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import importlib 15 | import sys 16 | 17 | 18 | if sys.version_info < (3, 8): 19 | _is_python_greater_3_8 = False 20 | else: 21 | _is_python_greater_3_8 = True 22 | 23 | 24 | def is_peft_available() -> bool: 25 | return importlib.util.find_spec("peft") is not None 26 | 27 | 28 | def is_unsloth_available() -> bool: 29 | return importlib.util.find_spec("unsloth") is not None 30 | 31 | 32 | def is_accelerate_greater_20_0() -> bool: 33 | if _is_python_greater_3_8: 34 | from importlib.metadata import version 35 | 36 | accelerate_version = version("accelerate") 37 | else: 38 | import pkg_resources 39 | 40 | accelerate_version = pkg_resources.get_distribution("accelerate").version 41 | return accelerate_version >= "0.20.0" 42 | 43 | 44 | def is_transformers_greater_than(version: str) -> bool: 45 | _transformers_version = importlib.metadata.version("transformers") 46 | return _transformers_version > version 47 | 48 | 49 | def is_torch_greater_2_0() -> bool: 50 | if _is_python_greater_3_8: 51 | from importlib.metadata import version 52 | 53 | torch_version = version("torch") 54 | else: 55 | import pkg_resources 56 | 57 | torch_version = pkg_resources.get_distribution("torch").version 58 | return torch_version >= "2.0" 59 | 60 | 61 | def is_diffusers_available() -> bool: 62 | return importlib.util.find_spec("diffusers") is not None 63 | 64 | 65 | def is_bitsandbytes_available() -> bool: 66 | import torch 67 | 68 | # bnb can be imported without GPU but is not usable. 69 | return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available() 70 | 71 | 72 | def is_torchvision_available() -> bool: 73 | return importlib.util.find_spec("torchvision") is not None 74 | 75 | 76 | def is_rich_available() -> bool: 77 | return importlib.util.find_spec("rich") is not None 78 | 79 | 80 | def is_wandb_available() -> bool: 81 | return importlib.util.find_spec("wandb") is not None 82 | 83 | 84 | def is_xpu_available() -> bool: 85 | if is_accelerate_greater_20_0(): 86 | import accelerate 87 | 88 | return accelerate.utils.is_xpu_available() 89 | else: 90 | if importlib.util.find_spec("intel_extension_for_pytorch") is None: 91 | return False 92 | try: 93 | import torch 94 | 95 | return hasattr(torch, "xpu") and torch.xpu.is_available() 96 | except RuntimeError: 97 | return False 98 | 99 | 100 | def is_npu_available() -> bool: 101 | """Checks if `torch_npu` is installed and potentially if a NPU is in the environment""" 102 | if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: 103 | return False 104 | 105 | import torch 106 | import torch_npu # noqa: F401 107 | 108 | return hasattr(torch, "npu") and torch.npu.is_available() 109 | -------------------------------------------------------------------------------- /trl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from .modeling_base import PreTrainedModelWrapper, create_reference_model 17 | from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead 18 | from .utils import setup_chat_format 19 | 20 | 21 | SUPPORTED_ARCHITECTURES = ( 22 | AutoModelForCausalLMWithValueHead, 23 | AutoModelForSeq2SeqLMWithValueHead, 24 | ) 25 | 26 | from ..import_utils import is_diffusers_available 27 | 28 | 29 | if is_diffusers_available(): 30 | from .modeling_sd_base import ( 31 | DDPOPipelineOutput, 32 | DDPOSchedulerOutput, 33 | DDPOStableDiffusionPipeline, 34 | DefaultDDPOStableDiffusionPipeline, 35 | ) 36 | -------------------------------------------------------------------------------- /trl/models/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, Optional, Tuple 3 | 4 | from transformers import PreTrainedModel, PreTrainedTokenizer 5 | 6 | 7 | # TODO: Add Abstract Base Class if more formats are added 8 | @dataclass 9 | class ChatMlSpecialTokens: 10 | """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens.""" 11 | 12 | bos_token: str = "<|im_start|>" 13 | eos_token: str = "<|im_end|>" 14 | pad_token: str = "<|im_end|>" 15 | 16 | @property 17 | def system(self): 18 | return f"{self.bos_token}system" 19 | 20 | @property 21 | def user(self): 22 | return f"{self.bos_token}user" 23 | 24 | @property 25 | def assistant(self): 26 | return f"{self.bos_token}assistant" 27 | 28 | @property 29 | def chat_template(self): 30 | return ( 31 | "{% for message in messages %}" 32 | f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}" 33 | "{% endfor %}" 34 | "{% if add_generation_prompt %}" 35 | f"{{{{ '{self.assistant}\n' }}}}" 36 | "{% endif %}" 37 | ) 38 | 39 | 40 | FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens} 41 | 42 | 43 | def setup_chat_format( 44 | model: PreTrainedModel, 45 | tokenizer: PreTrainedTokenizer, 46 | format: Optional[Literal["chatml"]] = "chatml", 47 | resize_to_multiple_of: Optional[int] = None, 48 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: 49 | """ 50 | Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. 51 | 52 | Args: 53 | model (`~transformers.PreTrainedModel`): The model to be modified. 54 | tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified. 55 | format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml". 56 | resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None. 57 | Returns: 58 | model (`~transformers.PreTrainedModel`): The modified model. 59 | tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer. 60 | """ 61 | # check if format available and retrieve 62 | if format not in FORMAT_MAPPING: 63 | raise ValueError(f"Format {format} not available. 
Please use one of {FORMAT_MAPPING.keys()}") 64 | 65 | chat_format = FORMAT_MAPPING[format]() 66 | 67 | # set special tokens and them 68 | tokenizer.eos_token = chat_format.eos_token 69 | tokenizer.pad_token = chat_format.pad_token 70 | tokenizer.bos_token = chat_format.bos_token 71 | tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]}) 72 | # set chat format for tokenizer 73 | tokenizer.chat_template = chat_format.chat_template 74 | 75 | # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377 76 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None) 77 | # Make sure to update the generation config to use the new eos & bos token 78 | if getattr(model, "generation_config", None) is not None: 79 | model.generation_config.bos_token_id = tokenizer.bos_token_id 80 | model.generation_config.eos_token_id = tokenizer.eos_token_id 81 | model.generation_config.pad_token_id = tokenizer.pad_token_id 82 | 83 | return model, tokenizer 84 | -------------------------------------------------------------------------------- /trl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # There is a circular import in the PPOTrainer if we let isort sort these 18 | # isort: off 19 | from .utils import ( 20 | AdaptiveKLController, 21 | FixedKLController, 22 | ConstantLengthDataset, 23 | DataCollatorForCompletionOnlyLM, 24 | RunningMoments, 25 | disable_dropout_in_model, 26 | peft_module_casting_to_bf16, 27 | ) 28 | 29 | # isort: on 30 | 31 | from ..import_utils import is_diffusers_available 32 | from .base import BaseTrainer 33 | from .ddpo_config import DDPOConfig 34 | 35 | 36 | if is_diffusers_available(): 37 | from .ddpo_trainer import DDPOTrainer 38 | 39 | from .dpo_trainer import DPOTrainer 40 | from .iterative_sft_trainer import IterativeSFTTrainer 41 | from .model_config import ModelConfig 42 | from .ppo_config import PPOConfig 43 | from .ppo_trainer import PPOTrainer 44 | from .reward_config import RewardConfig 45 | from .reward_trainer import RewardTrainer, compute_accuracy 46 | from .sft_trainer import SFTTrainer 47 | from .dpo_mix_trainer import DPOMixTrainer 48 | -------------------------------------------------------------------------------- /trl/trainer/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from huggingface_hub import PyTorchModelHubMixin 16 | 17 | 18 | class BaseTrainer(PyTorchModelHubMixin): 19 | r""" 20 | Base class for all trainers - this base class implements the basic functions that we 21 | need for a trainer. 22 | 23 | The trainer needs to have the following functions: 24 | - step: takes in a batch of data and performs a step of training 25 | - loss: takes in a batch of data and returns the loss 26 | - compute_rewards: takes in a batch of data and returns the rewards 27 | - _build_models_and_tokenizer: builds the models and tokenizer 28 | - _build_dataset: builds the dataset 29 | Each user is expected to implement their own trainer class that inherits from this base 30 | if they want to use a new training algorithm. 31 | """ 32 | 33 | def __init__(self, config): 34 | self.config = config 35 | 36 | def step(self, *args): 37 | raise NotImplementedError("Not implemented") 38 | 39 | def loss(self, *args): 40 | raise NotImplementedError("Not implemented") 41 | 42 | def compute_rewards(self, *args): 43 | raise NotImplementedError("Not implemented") 44 | 45 | def _save_pretrained(self, save_directory): 46 | raise NotImplementedError("Not implemented") 47 | -------------------------------------------------------------------------------- /trl/trainer/ddpo_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | from dataclasses import dataclass, field 5 | from typing import Literal, Optional 6 | 7 | from ..core import flatten_dict 8 | from ..import_utils import is_bitsandbytes_available, is_torchvision_available 9 | 10 | 11 | @dataclass 12 | class DDPOConfig: 13 | """ 14 | Configuration class for DDPOTrainer 15 | """ 16 | 17 | # common parameters 18 | exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] 19 | """the name of this experiment (by default is the file name without the extension name)""" 20 | run_name: Optional[str] = "" 21 | """Run name for wandb logging and checkpoint saving.""" 22 | seed: int = 0 23 | """Seed value for random generations""" 24 | log_with: Optional[Literal["wandb", "tensorboard"]] = None 25 | """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" 26 | tracker_kwargs: dict = field(default_factory=dict) 27 | """Keyword arguments for the tracker (e.g. wandb_project)""" 28 | accelerator_kwargs: dict = field(default_factory=dict) 29 | """Keyword arguments for the accelerator""" 30 | project_kwargs: dict = field(default_factory=dict) 31 | """Keyword arguments for the accelerator project config (e.g. 
`logging_dir`)""" 32 | tracker_project_name: str = "trl" 33 | """Name of project to use for tracking""" 34 | logdir: str = "logs" 35 | """Top-level logging directory for checkpoint saving.""" 36 | 37 | # hyperparameters 38 | num_epochs: int = 100 39 | """Number of epochs to train.""" 40 | save_freq: int = 1 41 | """Number of epochs between saving model checkpoints.""" 42 | num_checkpoint_limit: int = 5 43 | """Number of checkpoints to keep before overwriting old ones.""" 44 | mixed_precision: str = "fp16" 45 | """Mixed precision training.""" 46 | allow_tf32: bool = True 47 | """Allow tf32 on Ampere GPUs.""" 48 | resume_from: Optional[str] = "" 49 | """Resume training from a checkpoint.""" 50 | sample_num_steps: int = 50 51 | """Number of sampler inference steps.""" 52 | sample_eta: float = 1.0 53 | """Eta parameter for the DDIM sampler.""" 54 | sample_guidance_scale: float = 5.0 55 | """Classifier-free guidance weight.""" 56 | sample_batch_size: int = 1 57 | """Batch size (per GPU!) to use for sampling.""" 58 | sample_num_batches_per_epoch: int = 2 59 | """Number of batches to sample per epoch.""" 60 | train_batch_size: int = 1 61 | """Batch size (per GPU!) to use for training.""" 62 | train_use_8bit_adam: bool = False 63 | """Whether to use the 8bit Adam optimizer from bitsandbytes.""" 64 | train_learning_rate: float = 3e-4 65 | """Learning rate.""" 66 | train_adam_beta1: float = 0.9 67 | """Adam beta1.""" 68 | train_adam_beta2: float = 0.999 69 | """Adam beta2.""" 70 | train_adam_weight_decay: float = 1e-4 71 | """Adam weight decay.""" 72 | train_adam_epsilon: float = 1e-8 73 | """Adam epsilon.""" 74 | train_gradient_accumulation_steps: int = 1 75 | """Number of gradient accumulation steps.""" 76 | train_max_grad_norm: float = 1.0 77 | """Maximum gradient norm for gradient clipping.""" 78 | train_num_inner_epochs: int = 1 79 | """Number of inner epochs per outer epoch.""" 80 | train_cfg: bool = True 81 | """Whether or not to use classifier-free guidance during training.""" 82 | train_adv_clip_max: float = 5 83 | """Clip advantages to the range.""" 84 | train_clip_range: float = 1e-4 85 | """The PPO clip range.""" 86 | train_timestep_fraction: float = 1.0 87 | """The fraction of timesteps to train on.""" 88 | per_prompt_stat_tracking: bool = False 89 | """Whether to track statistics for each prompt separately.""" 90 | per_prompt_stat_tracking_buffer_size: int = 16 91 | """Number of reward values to store in the buffer for each prompt.""" 92 | per_prompt_stat_tracking_min_count: int = 16 93 | """The minimum number of reward values to store in the buffer.""" 94 | async_reward_computation: bool = False 95 | """Whether to compute rewards asynchronously.""" 96 | max_workers: int = 2 97 | """The maximum number of workers to use for async reward computation.""" 98 | negative_prompts: Optional[str] = "" 99 | """Comma-separated list of prompts to use as negative examples.""" 100 | 101 | def to_dict(self): 102 | output_dict = {} 103 | for key, value in self.__dict__.items(): 104 | output_dict[key] = value 105 | return flatten_dict(output_dict) 106 | 107 | def __post_init__(self): 108 | if self.log_with not in ["wandb", "tensorboard"]: 109 | warnings.warn(("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'.")) 110 | 111 | if self.log_with == "wandb" and not is_torchvision_available(): 112 | warnings.warn("Wandb image logging requires torchvision to be installed") 113 | 114 | if self.train_use_8bit_adam and not is_bitsandbytes_available(): 115 | raise 
ImportError("You need to install bitsandbytes to use 8bit Adam. " "You can install it with `pip install bitsandbytes`.") 116 | -------------------------------------------------------------------------------- /trl/trainer/model_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Optional 3 | 4 | from ..core import flatten_dict 5 | 6 | 7 | @dataclass 8 | class ModelConfig: 9 | """ 10 | Arguments which define the model and tokenizer to load. 11 | """ 12 | 13 | model_name_or_path: Optional[str] = field( 14 | default=None, 15 | metadata={"help": ("The model checkpoint for weights initialization.")}, 16 | ) 17 | model_revision: str = field( 18 | default="main", 19 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 20 | ) 21 | torch_dtype: Optional[str] = field( 22 | default=None, 23 | metadata={ 24 | "help": ("Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " "dtype will be automatically derived from the model's weights."), 25 | "choices": ["auto", "bfloat16", "float16", "float32"], 26 | }, 27 | ) 28 | trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) 29 | attn_implementation: Optional[str] = field( 30 | default=None, 31 | metadata={"help": ("Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`")}, 32 | ) 33 | use_peft: bool = field( 34 | default=False, 35 | metadata={"help": ("Whether to use PEFT or not for training.")}, 36 | ) 37 | lora_r: Optional[int] = field( 38 | default=16, 39 | metadata={"help": ("LoRA R value.")}, 40 | ) 41 | lora_alpha: Optional[int] = field( 42 | default=32, 43 | metadata={"help": ("LoRA alpha.")}, 44 | ) 45 | lora_dropout: Optional[float] = field( 46 | default=0.05, 47 | metadata={"help": ("LoRA dropout.")}, 48 | ) 49 | lora_target_modules: Optional[List[str]] = field( 50 | default=None, 51 | metadata={"help": ("LoRA target modules.")}, 52 | ) 53 | lora_modules_to_save: Optional[List[str]] = field( 54 | default=None, 55 | metadata={"help": ("Model layers to unfreeze & train")}, 56 | ) 57 | load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}) 58 | load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}) 59 | 60 | bnb_4bit_quant_type: Optional[str] = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}) 61 | use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) 62 | 63 | def to_dict(self): 64 | output_dict = {} 65 | for key, value in self.__dict__.items(): 66 | output_dict[key] = value 67 | return flatten_dict(output_dict) 68 | 69 | def __post_init__(self): 70 | if self.load_in_8bit and self.load_in_4bit: 71 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time") 72 | -------------------------------------------------------------------------------- /trl/trainer/reward_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class RewardConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`RewardTrainer`]. 25 | 26 | Using [`~transformers.HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | max_length (`Optional[int]`, *optional*, defaults to `None`): 32 | Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want 33 | to use the default data collator. 34 | dataset_num_proc (`int`, *optional*, defaults to `None`): 35 | Number of processes to use for processing the dataset. 36 | center_rewards_coefficient (`float`, *optional*, defaults to `0.01`): 37 | Coefficient to incentivize the reward model to output mean-zero rewards (proposed by 38 | https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. 39 | remove_unused_columns (`bool`, *optional*, defaults to `False`): 40 | Whether or not to remove the columns that are not used by the model's forward pass. Can be `True` only if 41 | the dataset is pretokenized. 42 | """ 43 | 44 | max_length: Optional[int] = None 45 | dataset_num_proc: Optional[int] = None 46 | center_rewards_coefficient: Optional[float] = 0.01 47 | remove_unused_columns: bool = False --------------------------------------------------------------------------------
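
RewardConfig above is a thin extension of transformers.TrainingArguments: max_length controls truncation of the concatenated prompt/completion pairs, dataset_num_proc parallelizes dataset preprocessing, and center_rewards_coefficient weights the auxiliary loss that pushes rewards toward zero mean. A hedged construction sketch follows; the values are illustrative placeholders, not this repository's training settings.

from trl import RewardConfig

reward_args = RewardConfig(
    output_dir="./reward_model",        # required by TrainingArguments; placeholder path
    max_length=4096,                    # truncate prompt + completion pairs to this length
    dataset_num_proc=8,                 # parallel workers for dataset preprocessing
    center_rewards_coefficient=0.01,    # weight of the mean-zero reward penalty
    remove_unused_columns=False,        # keep the pairwise columns the collator needs
    per_device_train_batch_size=1,
    bf16=True,
)
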
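For trl/extras/best_of_n_sampler.py, the sampler draws sample_size completions per query and keeps the n_candidates highest-scoring ones according to the user-supplied queries_to_scores callable. A minimal usage sketch, assuming a placeholder base checkpoint ("gpt2") and a toy scoring function standing in for a real reward model:

from transformers import AutoTokenizer

from trl import AutoModelForCausalLMWithValueHead, BestOfNSampler

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")   # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def toy_reward(texts):
    # Placeholder scorer that prefers longer completions; a real setup would
    # query a trained reward model here.
    return [float(len(t)) for t in texts]

sampler = BestOfNSampler(
    model,
    tokenizer,
    queries_to_scores=toy_reward,
    length_sampler=lambda: 32,   # fixed budget of new tokens per sample
    sample_size=4,               # completions drawn per query
    n_candidates=1,              # best completions kept per query
    seed=0,
)

query_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids.squeeze(0)
best = sampler.generate(query_ids, device="cpu")   # List[List[str]], one inner list per query
print(best[0][0])
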
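setup_chat_format in trl/models/utils.py registers the ChatML special tokens, installs the matching chat template, and resizes the token embeddings. A short sketch of the intended call pattern, again using a placeholder base checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import setup_chat_format

model = AutoModelForCausalLM.from_pretrained("gpt2")      # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Adds <|im_start|>/<|im_end|>, sets the ChatML template, and resizes the
# embeddings (optionally padding the vocab to a multiple of resize_to_multiple_of).
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Describe this image."}],
    tokenize=False,
    add_generation_prompt=True,
)
# prompt == "<|im_start|>user\nDescribe this image.<|im_end|>\n<|im_start|>assistant\n"
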
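get_formatting_func_from_dataset in trl/extras/dataset_formatting.py picks a formatter by inspecting the dataset schema: a messages/conversations column matching the ChatML feature spec, or flat prompt/completion string columns for the instruction format. A sketch under the assumption that the schema comparison matches as described and that the tokenizer already carries a chat template (here prepared with setup_chat_format on a placeholder model):

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import setup_chat_format
from trl.extras.dataset_formatting import get_formatting_func_from_dataset

model = AutoModelForCausalLM.from_pretrained("gpt2")      # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model, tokenizer = setup_chat_format(model, tokenizer)    # gives the tokenizer a chat template

# Toy dataset matching the "instruction" schema (string prompt/completion columns).
toy = Dataset.from_dict({
    "prompt": ["What color is the sky?"],
    "completion": ["Blue, under clear daytime conditions."],
})

formatting_func = get_formatting_func_from_dataset(toy, tokenizer)   # instruction formatter
print(formatting_func(toy[:]))   # one ChatML-formatted string per row
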
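The JSON files under scripts/ (zero2.json, zero2_fused_adamw.json, zero2_offload.json, zero3.json, zero3_offload.json, zero3pp.json) are DeepSpeed ZeRO stage-2 and stage-3 configurations; the *_offload variants move optimizer state (and, for stage 3, parameters) to CPU, and zero3pp.json additionally turns on the ZeRO++ quantized-weight/gradient options. Most fields are set to "auto" so DeepSpeed can fill them from the Hugging Face training arguments at launch. A minimal sketch of that wiring, with placeholder values rather than this repository's actual settings:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints",          # placeholder output directory
    per_device_train_batch_size=1,       # fills train_micro_batch_size_per_gpu: "auto"
    gradient_accumulation_steps=8,       # fills gradient_accumulation_steps: "auto"
    learning_rate=5e-7,                  # fills the optimizer's lr: "auto"
    weight_decay=0.0,                    # fills the optimizer's weight_decay: "auto"
    bf16=True,                           # resolves bf16.enabled: "auto"
    gradient_checkpointing=True,
    deepspeed="scripts/zero3.json",      # swap in any of the other scripts/zero*.json files
)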