├── llava ├── serve │ ├── __init__.py │ ├── examples │ │ ├── waterview.jpg │ │ └── extreme_ironing.jpg │ ├── register_worker.py │ ├── test_message.py │ └── cli.py ├── __init__.py ├── train │ ├── train_mem.py │ ├── llava_trainer_eval.py │ └── llama_flash_attn_monkey_patch.py ├── model │ ├── multimodal_encoder │ │ ├── dev_eva_clip │ │ │ ├── eva_clip │ │ │ │ ├── constants.py │ │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ │ ├── model_configs │ │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ │ ├── Internal-EVA02-CLIP-10B-14.json │ │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ │ └── EVA02-CLIP-L-14-336.json │ │ │ │ ├── __init__.py │ │ │ │ ├── hf_configs.py │ │ │ │ ├── transform.py │ │ │ │ ├── timm_model.py │ │ │ │ ├── rope.py │ │ │ │ ├── openai.py │ │ │ │ └── loss.py │ │ │ └── eva_vit.py │ │ ├── eva_clip │ │ │ ├── model_configs │ │ │ │ ├── EVA01-CLIP-B-16.json │ │ │ │ ├── EVA01-CLIP-g-14.json │ │ │ │ ├── EVA01-CLIP-g-14-plus.json │ │ │ │ ├── EVA02-CLIP-bigE-14.json │ │ │ │ ├── Internal-EVA02-CLIP-10B-14.json │ │ │ │ ├── EVA02-CLIP-bigE-14-plus.json │ │ │ │ ├── Internal-EVA02-CLIP-10B-14-448.json │ │ │ │ ├── EVA-CLIP-8B.json │ │ │ │ ├── EVA-CLIP-18B.json │ │ │ │ ├── EVA-CLIP-8B-plus.json │ │ │ │ ├── EVA02-CLIP-B-16.json │ │ │ │ ├── EVA02-CLIP-L-14.json │ │ │ │ └── EVA02-CLIP-L-14-336.json │ │ │ ├── factory.py │ │ │ ├── eva_clip_processors.py │ │ │ └── eva_clip_encoder.py │ │ ├── builder.py │ │ ├── imagebind.py │ │ └── hf_vision.py │ ├── __init__.py │ ├── utils.py │ ├── consolidate.py │ ├── multimodal_projector │ │ ├── pooler_projector.py │ │ └── builder.py │ ├── multimodal_resampler │ │ ├── builder.py │ │ ├── spatial_pool.py │ │ ├── masked_drop.py │ │ └── perceiver.py │ ├── apply_delta.py │ ├── make_delta.py │ └── language_model │ │ ├── llava_mpt.py │ │ ├── llava_gemma.py │ │ ├── llava_mistral.py │ │ └── llava_mixtral.py └── constants.py ├── images ├── method.png ├── overview.png └── benchmarks.png ├── trl ├── environment │ └── __init__.py ├── extras │ ├── __init__.py │ ├── dataset_formatting.py │ └── best_of_n_sampler.py ├── __init__.py ├── models │ ├── __init__.py │ └── utils.py ├── trainer │ ├── __init__.py │ ├── reward_config.py │ ├── base.py │ ├── model_config.py │ └── ddpo_config.py └── import_utils.py ├── scripts ├── inference_demo │ ├── codalm │ │ ├── codalm1.png │ │ ├── codalm2.png │ │ └── codalm3.png │ ├── bddx │ │ ├── 1 │ │ │ ├── testing_1f0fff77-a50aae97_23664_0.png │ │ │ ├── testing_1f0fff77-a50aae97_23664_11.png │ │ │ ├── testing_1f0fff77-a50aae97_23664_15.png │ │ │ ├── testing_1f0fff77-a50aae97_23664_3.png │ │ │ └── testing_1f0fff77-a50aae97_23664_7.png │ │ └── 2 │ │ │ ├── testing_1f13b7b2-e98c7699_23665_0.png │ │ │ ├── testing_1f13b7b2-e98c7699_23665_11.png │ │ │ ├── testing_1f13b7b2-e98c7699_23665_15.png │ │ │ ├── testing_1f13b7b2-e98c7699_23665_3.png │ │ │ └── testing_1f13b7b2-e98c7699_23665_7.png │ ├── lingoqa │ │ ├── 2a469a9042a47e4c68cadfaa7bdb4519 │ │ │ ├── 0.jpg │ │ │ ├── 1.jpg │ │ │ ├── 2.jpg │ │ │ ├── 3.jpg │ │ │ └── 4.jpg │ │ └── ab4845470b41f0e123da50c996c35745 │ │ │ ├── 0.jpg │ │ │ ├── 1.jpg │ │ │ ├── 2.jpg │ │ │ ├── 3.jpg │ │ │ └── 4.jpg │ ├── drivelm │ │ ├── n008-2018-08-30-10-33-52-0400__CAM_BACK__1535639717637558.jpg │ │ ├── 
n008-2018-08-30-10-33-52-0400__CAM_FRONT__1535639717612404.jpg │ │ ├── n008-2018-08-30-10-33-52-0400__CAM_BACK_LEFT__1535639717647405.jpg │ │ ├── n008-2018-08-30-10-33-52-0400__CAM_BACK_RIGHT__1535639717628113.jpg │ │ ├── n008-2018-08-30-10-33-52-0400__CAM_FRONT_LEFT__1535639717604799.jpg │ │ └── n008-2018-08-30-10-33-52-0400__CAM_FRONT_RIGHT__1535639717620482.jpg │ ├── demo_video.py │ └── demo_image.py ├── zero2_offload.json ├── qwen.py ├── zero2.json ├── zero2_fused_adamw.json ├── zero3.json ├── zero3_offload.json ├── zero3pp.json └── summarize_data.py ├── .gitattributes ├── cog.yaml ├── .gitignore ├── pyproject.toml ├── README.md └── predict.py /llava/serve/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /images/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/images/method.png -------------------------------------------------------------------------------- /images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/images/overview.png -------------------------------------------------------------------------------- /images/benchmarks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/images/benchmarks.png -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train import train 2 | 3 | if __name__ == "__main__": 4 | train() 5 | -------------------------------------------------------------------------------- /trl/environment/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from .base_environment import TextEnvironment, TextHistory 4 | -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/codalm/codalm1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/codalm/codalm1.png -------------------------------------------------------------------------------- /scripts/inference_demo/codalm/codalm2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/codalm/codalm2.png -------------------------------------------------------------------------------- /scripts/inference_demo/codalm/codalm3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/codalm/codalm3.png -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_0.png -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_11.png -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_15.png -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_3.png -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/1/testing_1f0fff77-a50aae97_23664_7.png -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_0.png -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_11.png -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_15.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_15.png -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_3.png -------------------------------------------------------------------------------- /scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/bddx/2/testing_1f13b7b2-e98c7699_23665_7.png -------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/0.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/1.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/2.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/3.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/4.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/0.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/1.jpg 
-------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/2.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/3.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/lingoqa/ab4845470b41f0e123da50c996c35745/4.jpg -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_BACK__1535639717637558.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_BACK__1535639717637558.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_FRONT__1535639717612404.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_FRONT__1535639717612404.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_BACK_LEFT__1535639717647405.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_BACK_LEFT__1535639717647405.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_BACK_RIGHT__1535639717628113.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_BACK_RIGHT__1535639717628113.jpg -------------------------------------------------------------------------------- /scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_FRONT_LEFT__1535639717604799.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_FRONT_LEFT__1535639717604799.jpg
--------------------------------------------------------------------------------
/scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_FRONT_RIGHT__1535639717620482.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhijian11/RoboTron-Drive/HEAD/scripts/inference_demo/drivelm/n008-2018-08-30-10-33-52-0400__CAM_FRONT_RIGHT__1535639717620482.jpg
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "eva_model_name": "eva-clip-b-16",
9 | "ls_init_value": 0.1,
10 | "drop_path_rate": 0.0
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "eva_model_name": "eva-clip-b-16",
9 | "ls_init_value": 0.1,
10 | "drop_path_rate": 0.0
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0.4,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 768,
19 | "heads": 12,
20 | "layers": 12,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 |
"fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | AVAILABLE_MODELS = { 4 | "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig", 5 | "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig", 6 | "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig", 7 | "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig", 8 | # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig", 9 | # Add other models as needed 10 | } 11 | 12 | for model_name, model_classes in AVAILABLE_MODELS.items(): 13 | try: 14 | exec(f"from .language_model.{model_name} import {model_classes}") 15 | except Exception as e: 16 | print(f"Failed to import {model_name} from llava.language_model.{model_name}. 
Error: {e}") 17 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /trl/extras/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | from .best_of_n_sampler import BestOfNSampler 17 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 77, 6 | "width": 2304, 7 | "head_width": 144, 8 | "mlp_ratio": 10.9722, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-10b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": false, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | 
"patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1536, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 5120, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-18b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": true, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": 
false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": 32, 6 | "width": 4096, 7 | "head_width": 128, 8 | "mlp_ratio": 5, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-8b-14-plus-x", 11 | "drop_path_rate": 0, 12 | "qkv_bias": false, 13 | "xattn": true, 14 | "postnorm": false, 15 | "fusedLN": false, 16 | "use_rms_norm": true 17 | }, 18 | "text_cfg": { 19 | "context_length": 77, 20 | "vocab_size": 49408, 21 | "width": 1280, 22 | "heads": 20, 23 | "layers": 32, 24 | "xattn": false, 25 | "fusedLN": false 26 | } 27 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 6 | from .openai import load_openai_model, list_openai_models 7 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 8 | from .tokenizer import SimpleTokenizer, tokenize 9 | from .transform import image_transform 10 | -------------------------------------------------------------------------------- /llava/serve/register_worker.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # https://git-scm.com/docs/gitattributes 2 | 3 | # Set the default behavior, in case people don't have core.autocrlf set. 4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion 5 | * text=auto 6 | 7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes 8 | # Source files 9 | # ============ 10 | *.pxd text diff=python 11 | *.py text diff=python 12 | *.py3 text diff=python 13 | *.pyw text diff=python 14 | *.pyx text diff=python 15 | *.pyz text diff=python 16 | *.pyi text diff=python 17 | 18 | # Binary files 19 | # ============ 20 | *.db binary 21 | *.p binary 22 | *.pkl binary 23 | *.pickle binary 24 | *.pyc binary export-ignore 25 | *.pyo binary export-ignore 26 | *.pyd binary 27 | 28 | # Jupyter notebook 29 | *.ipynb text eol=lf 30 | -------------------------------------------------------------------------------- /scripts/zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "offload_optimizer": { 19 | "device": "cpu", 20 | "pin_memory": true 21 | }, 22 | "offload_param": { 23 | "device": "cpu", 24 | "pin_memory": true 25 | }, 26 | "overlap_comm": true, 27 | "contiguous_gradients": true, 28 | "sub_group_size": 1e9, 29 | "reduce_bucket_size": "auto" 30 | } 31 | } -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if "llava" in config and "llava" not in cfg.model_type: 7 | assert cfg.model_type == "llama" 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = "LlavaLlamaForCausalLM" 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /scripts/qwen.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | 4 | device = "cuda" # the device to load the model onto 5 | 6 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B-Chat", torch_dtype=torch.bfloat16, device_map="auto") 7 | tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B-Chat") 8 | 9 | prompt = "Give me a short introduction to large language model." 10 | messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}] 11 | text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 12 | model_inputs = tokenizer([text], return_tensors="pt").to(device) 13 | 14 | generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512) 15 | generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)] 16 | 17 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 18 | 19 | print(response) 20 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model import * 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def consolidate_ckpt(src_path, dst_path): 15 | print("Loading model") 16 | auto_upgrade(src_path) 17 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 18 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 19 | src_model.save_pretrained(dst_path) 20 | src_tokenizer.save_pretrained(dst_path) 21 | 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--src", type=str, required=True) 26 | parser.add_argument("--dst", type=str, required=True) 27 | 28 | args = parser.parse_args() 29 | 30 | consolidate_ckpt(args.src, args.dst) 31 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/pooler_projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import math 5 | 6 | from transformers.models.clip.modeling_clip import CLIPVisionModel 7 | 8 | 9 | class PoolerProjector(nn.Module): 10 | def __init__(self, config, vision_cfg): 11 | super().__init__() 12 | self._config = config 13 | self.hw = vision_cfg.image_size // vision_cfg.patch_size 14 | 15 | self.conv_pool = nn.Conv2d(config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2) 16 | 17 | self.proj = nn.Sequential( 18 | nn.GELU(), 19 | nn.Linear(config.hidden_size, 
config.hidden_size), 20 | ) 21 | 22 | def forward(self, x, *args, **kwargs): 23 | height = width = self.hw 24 | assert height * width == x.shape[1] 25 | x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) 26 | x = self.conv_pool(x) 27 | x = x.flatten(2).transpose(1, 2) 28 | x = self.proj(x) 29 | return x 30 | 31 | @property 32 | def config(self): 33 | return {"mm_projector_type": "pooler"} 34 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | 7 | python_version: "3.11" 8 | 9 | python_packages: 10 | - "torch==2.0.1" 11 | - "accelerate==0.21.0" 12 | - "bitsandbytes==0.41.0" 13 | - "deepspeed==0.9.5" 14 | - "einops-exts==0.0.4" 15 | - "einops==0.6.1" 16 | - "gradio==3.35.2" 17 | - "gradio_client==0.2.9" 18 | - "httpx==0.24.0" 19 | - "markdown2==2.4.10" 20 | - "numpy==1.26.0" 21 | - "peft==0.4.0" 22 | - "scikit-learn==1.2.2" 23 | - "sentencepiece==0.1.99" 24 | - "shortuuid==1.0.11" 25 | - "timm==0.6.13" 26 | - "tokenizers==0.13.3" 27 | - "torch==2.0.1" 28 | - "torchvision==0.15.2" 29 | - "transformers==4.31.0" 30 | - "wandb==0.15.12" 31 | - "wavedrom==2.0.3.post3" 32 | - "Pygments==2.16.1" 33 | run: 34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget 35 | 36 | # predict.py defines how predictions are run on your model 37 | predict: "predict.py:Predictor" 38 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .masked_drop import MaskedDrop 4 | from .spatial_pool import SpatialPool 5 | from .perceiver import PerceiverResampler 6 | from .qformer import Qformer 7 | 8 | 9 | class IdentityMap(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | @property 17 | def config(self): 18 | return {"mm_resampler_type": None} 19 | 20 | 21 | def build_vision_resampler(model_args, delay_load=False, **kwargs): 22 | resampler_type = getattr(model_args, "mm_resampler_type", None) 23 | if resampler_type == "masked_drop": 24 | return MaskedDrop(model_args) 25 | elif resampler_type == "spatial_pool": 26 | return SpatialPool(model_args, **kwargs) 27 | elif resampler_type == "perceiver": 28 | return PerceiverResampler(model_args, **kwargs) 29 | elif resampler_type == "qformer": 30 | return Qformer(model_args, **kwargs) 31 | elif resampler_type is None: 32 | return IdentityMap() 33 | 34 | raise ValueError(f"Unknown resampler type: {resampler_type}") 35 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": 
"none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero2_fused_adamw.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": true, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /trl/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | __version__ = "0.7.11.dev0" 4 | 5 | from .core import set_seed 6 | from .environment import TextEnvironment, TextHistory 7 | from .extras import BestOfNSampler 8 | from .import_utils import ( 9 | is_bitsandbytes_available, 10 | is_diffusers_available, 11 | is_npu_available, 12 | is_peft_available, 13 | is_wandb_available, 14 | is_xpu_available, 15 | ) 16 | from .models import ( 17 | AutoModelForCausalLMWithValueHead, 18 | AutoModelForSeq2SeqLMWithValueHead, 19 | PreTrainedModelWrapper, 20 | create_reference_model, 21 | setup_chat_format, 22 | ) 23 | from .trainer import ( 24 | DataCollatorForCompletionOnlyLM, 25 | DPOTrainer, 26 | IterativeSFTTrainer, 27 | ModelConfig, 28 | PPOConfig, 29 | PPOTrainer, 30 | RewardConfig, 31 | RewardTrainer, 32 | SFTTrainer, 33 | ) 34 | from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config 35 | 36 | 37 | if is_diffusers_available(): 38 | from .models import ( 39 | DDPOPipelineOutput, 40 | DDPOSchedulerOutput, 41 | DDPOStableDiffusionPipeline, 42 | DefaultDDPOStableDiffusionPipeline, 43 | ) 44 | from .trainer import DDPOConfig, DDPOTrainer 45 | -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 
23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /trl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from .modeling_base import PreTrainedModelWrapper, create_reference_model 17 | from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead 18 | from .utils import setup_chat_format 19 | 20 | 21 | SUPPORTED_ARCHITECTURES = ( 22 | AutoModelForCausalLMWithValueHead, 23 | AutoModelForSeq2SeqLMWithValueHead, 24 | ) 25 | 26 | from ..import_utils import is_diffusers_available 27 | 28 | 29 | if is_diffusers_available(): 30 | from .modeling_sd_base import ( 31 | DDPOPipelineOutput, 32 | DDPOSchedulerOutput, 33 | DDPOStableDiffusionPipeline, 34 | DefaultDDPOStableDiffusionPipeline, 35 | ) 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | # *.json 11 | # *.jsonl 12 | 13 | # Data 14 | !**/alpaca-data-conversation.json 15 | # Editor 16 | .idea 17 | *.swp 18 | .vscode 19 | 20 | # Other 21 | .DS_Store 22 | wandb 23 | output 24 | llavavid 25 | 26 | checkpoints 27 | project_checkpoints 28 | debug_checkpoints 29 | playground/data 30 | playground/cc3m_llava34b_cap 31 | ckpts* 32 | 33 | .ipynb_checkpoints 34 | chunyl_scripts 35 | *.ipynb 36 | 37 | # DevContainer 38 | !.devcontainer/* 39 | 40 | # Demo 41 | serve_images/ 42 | notebooks/ 43 | logs 44 | scripts/dist_* 45 | logs/ 46 | submissions/ 47 | cn_scripts/ 48 | internal_project_checkpoints/ 49 | work_dirs 50 | scripts/i18n/* 51 | playground/.nfs028b000000010add00000001 52 | HIP 53 | playground/.nfs028b0000017bff2c00000012 54 | scripts/qwen 55 | scripts/vicuna 56 | scripts/mistral 57 | scripts/baseline_rep 58 | scripts/cn_boli01_hl 59 | scripts/cn_boli01_lf 60 | scripts/cn_lf 61 | scripts/cn_lq 62 | scripts/cn_yg 63 | scripts/cn_yg_hao 64 | scripts/eva_encoder 65 | scripts/i18n 66 | scripts/i18n_higher_res 67 | scripts/multi-images 68 | scratchpad 69 | build/ 70 | playground/*.json 71 | mlx_configs/ 72 | data_processing/ 73 | data/ 74 
| onevisiondata/ 75 | # demo/ 76 | ckpt 77 | scripts/inference_demo/check* 78 | scripts/inference_demo/infer* 79 | markdown/ -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "steps_per_print": 1e5, 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /scripts/zero3pp.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "none", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "none", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "zero_quantized_weights": true, 36 | "zero_hpz_partition_size": 16, 37 | "zero_quantized_gradients": true, 38 | "sub_group_size": 1e9, 39 | "reduce_bucket_size": "auto", 40 | "stage3_prefetch_bucket_size": "auto", 41 | "stage3_param_persistence_threshold": "auto", 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | "stage3_gather_16bit_weights_on_model_save": true 45 | }, 46 | 47 | "gradient_accumulation_steps": "auto", 48 | "gradient_clipping": "auto", 49 | "steps_per_print": 100, 50 | "train_batch_size": "auto", 51 | "train_micro_batch_size_per_gpu": "auto", 52 | "wall_clock_breakdown": false 53 | } -------------------------------------------------------------------------------- /trl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # There is a circular import in the PPOTrainer if we let isort sort these 18 | # isort: off 19 | from .utils import ( 20 | AdaptiveKLController, 21 | FixedKLController, 22 | ConstantLengthDataset, 23 | DataCollatorForCompletionOnlyLM, 24 | RunningMoments, 25 | disable_dropout_in_model, 26 | peft_module_casting_to_bf16, 27 | ) 28 | 29 | # isort: on 30 | 31 | from ..import_utils import is_diffusers_available 32 | from .base import BaseTrainer 33 | from .ddpo_config import DDPOConfig 34 | 35 | 36 | if is_diffusers_available(): 37 | from .ddpo_trainer import DDPOTrainer 38 | 39 | from .dpo_trainer import DPOTrainer 40 | from .iterative_sft_trainer import IterativeSFTTrainer 41 | from .model_config import ModelConfig 42 | from .ppo_config import PPOConfig 43 | from .ppo_trainer import PPOTrainer 44 | from .reward_config import RewardConfig 45 | from .reward_trainer import RewardTrainer, compute_accuracy 46 | from .sft_trainer import SFTTrainer 47 | -------------------------------------------------------------------------------- /trl/trainer/reward_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class RewardConfig(TrainingArguments): 23 | """ 24 | RewardConfig collects all training arguments related to the [`RewardTrainer`] class. 25 | 26 | Using [`HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | max_length (`int`, *optional*, defaults to `None`): 32 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. 33 | gradient_checkpointing (`bool`, *optional*, defaults to `True`): 34 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 35 | """ 36 | 37 | max_length: Optional[int] = None 38 | """The maximum length of the sequences in the batch. 
This argument is required if you want to use the default data collator.""" 39 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/spatial_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class SpatialPool(nn.Module): 7 | def __init__(self, model_args, vision_tower): 8 | super().__init__() 9 | 10 | self.mode = model_args.mm_spatial_pool_mode 11 | self.stride = model_args.mm_spatial_pool_stride 12 | self.out_channels = getattr(model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size) 13 | 14 | if self.mode == "average": 15 | self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride) 16 | elif self.mode == "max": 17 | self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride) 18 | elif self.mode == "conv": 19 | self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride) 20 | else: 21 | raise ValueError(f"Unknown pooling mode: {self.mode}.") 22 | 23 | def forward(self, image_features, images, *args, **kwargs): 24 | ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])) 25 | ori_H = int(ori_W * images.shape[2] // images.shape[3]) 26 | 27 | B, _, F = image_features.shape 28 | 29 | image_features_spatial = image_features.view(B, ori_H, ori_W, F).permute(0, 3, 1, 2) 30 | image_features_spatial_pool = self.pool(image_features_spatial) 31 | 32 | return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() 33 | 34 | @property 35 | def config(self): 36 | return { 37 | "mm_resampler_type": "spatial_pool", 38 | "mm_spatial_pool_stride": self.stride, 39 | "mm_spatial_pool_mode": self.mode, 40 | "mm_spatial_pool_out_channels": self.out_channels, 41 | } 42 | 43 | @property 44 | def hidden_size(self): 45 | return self.out_channels 46 | -------------------------------------------------------------------------------- /trl/trainer/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from huggingface_hub import PyTorchModelHubMixin 16 | 17 | 18 | class BaseTrainer(PyTorchModelHubMixin): 19 | r""" 20 | Base class for all trainers - this base class implements the basic functions that we 21 | need for a trainer. 
22 | 23 | The trainer needs to have the following functions: 24 | - step: takes in a batch of data and performs a step of training 25 | - loss: takes in a batch of data and returns the loss 26 | - compute_rewards: takes in a batch of data and returns the rewards 27 | - _build_models_and_tokenizer: builds the models and tokenizer 28 | - _build_dataset: builds the dataset 29 | Each user is expected to implement their own trainer class that inherits from this base 30 | if they want to use a new training algorithm. 31 | """ 32 | 33 | def __init__(self, config): 34 | self.config = config 35 | 36 | def step(self, *args): 37 | raise NotImplementedError("Not implemented") 38 | 39 | def loss(self, *args): 40 | raise NotImplementedError("Not implemented") 41 | 42 | def compute_rewards(self, *args): 43 | raise NotImplementedError("Not implemented") 44 | 45 | def _save_pretrained(self, save_directory): 46 | raise NotImplementedError("Not implemented") 47 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .imagebind import ImageBindWrapper 4 | from .open_clip_encoder import OpenCLIPVisionTower 5 | from .hf_vision import HFVisionTower 6 | from .siglip_encoder import SigLipVisionTower 7 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 8 | 9 | # from .eva_clip.eva_clip_encoder import EvaClipVisionTower 10 | # from .dev_eva_clip.eva_vit import EvaViTWrapper 11 | 12 | 13 | def build_vision_tower(vision_tower_cfg, **kwargs): 14 | vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)) 15 | is_absolute_path_exists = os.path.exists(vision_tower) 16 | use_s2 = getattr(vision_tower_cfg, "s2", False) 17 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 18 | if use_s2: 19 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 20 | else: 21 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 22 | elif "siglip" in vision_tower: 23 | return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) 24 | elif vision_tower.startswith("hf:"): 25 | return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 26 | elif vision_tower in ["imagebind_huge"]: 27 | return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 28 | elif vision_tower.startswith("open_clip_hub"): 29 | return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 30 | # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): 31 | # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 32 | # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: 33 | # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) 34 | 35 | raise ValueError(f"Unknown vision tower: {vision_tower}") 36 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Optional, Tuple, Union, Dict, Any 9 | import torch 10 | 11 | _MODEL_CONFIG_PATHS = 
[Path(__file__).parent / f"model_configs/"] 12 | _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs 13 | 14 | 15 | def _natural_key(string_): 16 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 17 | 18 | 19 | def _rescan_model_configs(): 20 | global _MODEL_CONFIGS 21 | 22 | config_ext = (".json",) 23 | config_files = [] 24 | for config_path in _MODEL_CONFIG_PATHS: 25 | if config_path.is_file() and config_path.suffix in config_ext: 26 | config_files.append(config_path) 27 | elif config_path.is_dir(): 28 | for ext in config_ext: 29 | config_files.extend(config_path.glob(f"*{ext}")) 30 | 31 | for cf in config_files: 32 | with open(cf, "r", encoding="utf8") as f: 33 | model_cfg = json.load(f) 34 | if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): 35 | _MODEL_CONFIGS[cf.stem] = model_cfg 36 | 37 | _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) 38 | 39 | 40 | _rescan_model_configs() # initial populate of model config registry 41 | 42 | 43 | def list_models(): 44 | """enumerate available model architectures based on config files""" 45 | return list(_MODEL_CONFIGS.keys()) 46 | 47 | 48 | def add_model_config(path): 49 | """add model config path or file and update registry""" 50 | if not isinstance(path, Path): 51 | path = Path(path) 52 | _MODEL_CONFIG_PATHS.append(path) 53 | _rescan_model_configs() 54 | 55 | 56 | def get_model_config(model_name): 57 | if model_name in _MODEL_CONFIGS: 58 | return deepcopy(_MODEL_CONFIGS[model_name]) 59 | else: 60 | return None 61 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava import LlavaLlamaForCausalLM 12 | 13 | 14 | def apply_delta(base_model_path, target_model_path, delta_path): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] += bparam 33 | 34 | print("Saving target model") 35 | delta.save_pretrained(target_model_path) 36 | delta_tokenizer.save_pretrained(target_model_path) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--base-model-path", type=str, required=True) 42 | 
parser.add_argument("--target-model-path", type=str, required=True) 43 | parser.add_argument("--delta-path", type=str, required=True) 44 | 45 | args = parser.parse_args() 46 | 47 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 48 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name}) 21 | worker_addr = ret.json()["address"] 22 | print(f"worker_addr: {worker_addr}") 23 | 24 | if worker_addr == "": 25 | return 26 | 27 | conv = default_conversation.copy() 28 | conv.append_message(conv.roles[0], args.message) 29 | prompt = conv.get_prompt() 30 | 31 | headers = {"User-Agent": "LLaVA Client"} 32 | pload = { 33 | "model": args.model_name, 34 | "prompt": prompt, 35 | "max_new_tokens": args.max_new_tokens, 36 | "temperature": 0.7, 37 | "stop": conv.sep, 38 | } 39 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True) 40 | 41 | print(prompt.replace(conv.sep, "\n"), end="") 42 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 43 | if chunk: 44 | data = json.loads(chunk.decode("utf-8")) 45 | output = data["text"].split(conv.sep)[-1] 46 | print(output, end="\r") 47 | print("") 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 53 | parser.add_argument("--worker-address", type=str) 54 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 55 | parser.add_argument("--max-new-tokens", type=int, default=32) 56 | parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.") 57 | args = parser.parse_args() 58 | 59 | main() 60 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings", 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings", 
26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens", 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings", 54 | }, 55 | "pooler": "mean_pooler", 56 | }, 57 | } 58 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | from .pooler_projector import PoolerProjector 6 | 7 | 8 | class IdentityMap(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x, *args, **kwargs): 13 | return x 14 | 15 | @property 16 | def config(self): 17 | return {"mm_projector_type": "identity"} 18 | 19 | 20 | class SimpleResBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.pre_norm = nn.LayerNorm(channels) 24 | 25 | self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)) 26 | 27 | def forward(self, x): 28 | x = self.pre_norm(x) 29 | return x + self.proj(x) 30 | 31 | 32 | def build_vision_projector(config, delay_load=False, **kwargs): 33 | projector_type = getattr(config, "mm_projector_type", "linear") 34 | 35 | if projector_type == "linear": 36 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | 38 | if projector_type == "pooler": 39 | return PoolerProjector(config, kwargs["vision_cfg"]) 40 | 41 | mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) 42 | if mlp_gelu_match: 43 | mlp_depth = int(mlp_gelu_match.group(1)) 44 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 45 | for _ in range(1, mlp_depth): 46 | modules.append(nn.GELU()) 47 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 48 | return nn.Sequential(*modules) 49 | 50 | mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type) 51 | if mlp_gelu_resnet_match: 52 | mlp_depth = int(mlp_gelu_resnet_match.group(1)) 53 | res_depth = int(mlp_gelu_resnet_match.group(2)) 54 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 55 | for _ in range(1, mlp_depth): 56 | modules.append(nn.GELU()) 57 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 58 | for _ in range(res_depth): 59 | modules.append(SimpleResBlock(config.hidden_size)) 60 | return nn.Sequential(*modules) 61 | 62 | if projector_type == "identity": 63 | return IdentityMap() 64 | 65 | raise ValueError(f"Unknown projector type: {projector_type}") 66 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | 
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | from llava.model.utils import auto_upgrade 12 | 13 | 14 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 15 | print("Loading base model") 16 | base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ["model.mm_projector.weight", "model.mm_projector.bias"], f"{name} not in base model" 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ["model.embed_tokens.weight", "lm_head.weight"], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" 31 | bparam = base.state_dict()[name] 32 | param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP 3 | """ 4 | 5 | from torchvision import transforms 6 | from torchvision.transforms.functional import InterpolationMode 7 | from transformers.image_processing_utils import BatchFeature 8 | from PIL import Image 9 | from transformers.image_transforms import convert_to_rgb 10 | 11 | 12 | class BaseProcessor: 13 | def __init__(self): 14 | self.transform = lambda x: x 15 | return 16 | 17 | def __call__(self, item): 18 | return self.transform(item) 19 | 20 | 21 | class EvaClipImageBaseProcessor(BaseProcessor): 22 | def __init__(self, mean=None, std=None): 23 | self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean 24 | self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std 25 | 26 | self.normalize = transforms.Normalize(self.mean, self.std) 27 | 28 | @property 29 | def image_mean(self): 30 | return self.mean 31 | 32 | 33 | class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor): 34 | def __init__(self, 
image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 35 | super().__init__(mean=mean, std=std) 36 | 37 | self.transform = transforms.Compose( 38 | [ 39 | convert_to_rgb, 40 | transforms.Resize( 41 | image_size, 42 | interpolation=InterpolationMode.BICUBIC, 43 | ), 44 | transforms.CenterCrop(image_size), 45 | transforms.ToTensor(), 46 | self.normalize, 47 | ] 48 | ) 49 | 50 | self.image_size = image_size 51 | 52 | def preprocess(self, images, return_tensors): 53 | if isinstance(images, Image.Image): 54 | images = [images] 55 | else: 56 | assert isinstance(images, list) 57 | 58 | transformed_images = [self.transform(image).numpy() for image in images] 59 | data = {"pixel_values": transformed_images} 60 | 61 | return BatchFeature(data=data, tensor_type=return_tensors) 62 | 63 | def __call__(self, item): 64 | return self.transform(item) 65 | 66 | @property 67 | def crop_size(self): 68 | return {"height": self.image_size, "width": self.image_size} 69 | 70 | @property 71 | def size(self): 72 | return {"shortest_edge": self.image_size} 73 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 240 3 | 4 | [build-system] 5 | requires = ["setuptools>=61.0"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "drivemm" 10 | version = "1.7.0.dev0" 11 | description = "DriveMM: All-in-One Large Multimodal Model for Autonomous Driving" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | ] 18 | 19 | [project.optional-dependencies] 20 | standalone = [ 21 | "shortuuid", 22 | "httpx==0.24.0", 23 | "einops", 24 | "ftfy", 25 | ] 26 | 27 | 28 | train = [ 29 | "llava[standalone]", 30 | "numpy==1.26.1", 31 | "open_clip_torch", 32 | "fastapi", 33 | "markdown2[all]", 34 | "numpy", 35 | "requests", 36 | "sentencepiece", 37 | "torch==2.1.2", 38 | "torchvision==0.16.2", 39 | "uvicorn", 40 | "wandb", 41 | "deepspeed==0.14.4", 42 | "peft==0.4.0", 43 | "accelerate>=0.29.1", 44 | "tokenizers~=0.19.1", 45 | "transformers==4.43.1", 46 | "bitsandbytes==0.41.0", 47 | "scikit-learn==1.2.2", 48 | "sentencepiece~=0.1.99", 49 | "einops==0.6.1", 50 | "einops-exts==0.0.4", 51 | "gradio_client==0.2.9", 52 | "urllib3<=2.0.0", 53 | "datasets==2.16.1", 54 | "pydantic==1.10.8", 55 | "timm", 56 | "hf_transfer", 57 | "opencv-python", 58 | "av", 59 | "decord", 60 | "tyro", 61 | "scipy", 62 | "flash-attn==2.6.3", 63 | ] 64 | 65 | [project.urls] 66 | "Homepage" = "https://zhijian11.github.io/DriveMM/" 67 | "Bug Tracker" = "https://github.com/zhijian11/DriveMM/issues" 68 | 69 | [tool.setuptools.packages.find] 70 | include = ["llava*", "trl*"] 71 | exclude = [ 72 | "assets*", 73 | "benchmark*", 74 | "docs", 75 | "dist*", 76 | "playground*", 77 | "scripts*", 78 | "tests*", 79 | "checkpoints*", 80 | "project_checkpoints*", 81 | "debug_checkpoints*", 82 | "mlx_configs*", 83 | "wandb*", 84 | "notebooks*", 85 | ] 86 | 87 | [tool.wheel] 88 | exclude = [ 89 | "assets*", 90 | "benchmark*", 91 | "docs", 92 | "dist*", 93 | "playground*", 94 | "scripts*", 95 | "tests*", 96 | "checkpoints*", 97 | "project_checkpoints*", 98 | "debug_checkpoints*", 99 | "mlx_configs*", 100 | "wandb*", 101 | "notebooks*", 102 | ] 103 | -------------------------------------------------------------------------------- 
/llava/model/multimodal_encoder/imagebind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPImageProcessor 5 | 6 | try: 7 | from imagebind.models import imagebind_model 8 | from imagebind.models.imagebind_model import ModalityType 9 | from imagebind.data import load_and_transform_audio_data 10 | except ImportError: 11 | pass 12 | 13 | 14 | class ImageBindWrapper(nn.Module): 15 | def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): 16 | super().__init__() 17 | 18 | self.is_loaded = False 19 | 20 | self.vision_tower_name = vision_tower 21 | self.select_layer = select_layer 22 | self.select_feature = select_feature 23 | 24 | if not delay_load: 25 | self.load_model() 26 | 27 | def load_model(self): 28 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 29 | self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) 30 | for p in self.vision_tower.parameters(): 31 | p.requires_grad = False 32 | self.vision_tower.eval() 33 | self.is_loaded = True 34 | 35 | def train(self, mode=True): 36 | self.training = mode 37 | 38 | if self.is_loaded: 39 | self.vision_tower.eval() 40 | 41 | @torch.no_grad() 42 | def forward(self, x): 43 | if type(x) == dict: 44 | if x["audios"] is not None: 45 | inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} 46 | embeddings = self.vision_tower(inputs) 47 | audio_embedding = embeddings[ModalityType.AUDIO] 48 | return audio_embedding.unsqueeze(1) 49 | else: 50 | inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} 51 | embeddings = self.vision_tower(inputs) 52 | vision_embedding = embeddings[ModalityType.VISION] 53 | if vision_embedding.ndim == 2: 54 | return vision_embedding.unsqueeze(1) 55 | if vision_embedding.shape[1] == 257: 56 | return vision_embedding[:, 1:] 57 | raise ValueError(f"Unexpected shape: {vision_embedding.shape}") 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.modality_preprocessors.vision.cls_token.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.modality_preprocessors.vision.cls_token.device 70 | 71 | @property 72 | def hidden_size(self): 73 | return 1024 74 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .eva_clip_processors import EvaClipImageTrainProcessor 5 | from .eva_vit import EVAEncoderWrapper 6 | from .factory import list_models, add_model_config, get_model_config 7 | 8 | from llava.utils import rank0_print 9 | 10 | 11 | class EvaClipVisionTower(nn.Module): 12 | def __init__(self, vision_tower, args, delay_load=False): 13 | super().__init__() 14 | 15 | self.is_loaded = False 16 | self.vision_tower_name = vision_tower 17 | self.vision_tower_pretrained = args.vision_tower_pretrained 18 | self.config = get_model_config(vision_tower) 19 | 20 | if not delay_load: 21 | rank0_print(f"Loading EVA ViT: {self.vision_tower_name}") 22 | self.load_model() 23 | elif getattr(args, "unfreeze_mm_vision_tower", False): 24 | # TODO: better detector is needed. 
25 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 26 | self.load_model() 27 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 28 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 29 | self.load_model() 30 | else: 31 | self.cfg_only = self.config 32 | 33 | def load_model(self, device_map=None): 34 | rank0_print(f"Pretrained: {self.vision_tower_pretrained}") 35 | self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"]) 36 | self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config) 37 | rank0_print(f"Loaded image processor: {self.image_processor}") 38 | self.vision_tower.requires_grad_(False) 39 | self.is_loaded = True 40 | 41 | def forward(self, images): 42 | if type(images) is list: 43 | image_features = [] 44 | for image in images: 45 | image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype) 49 | 50 | return image_features 51 | 52 | @property 53 | def dtype(self): 54 | return self.vision_tower.dtype 55 | 56 | @property 57 | def device(self): 58 | return self.vision_tower.device 59 | 60 | @property 61 | def hidden_size(self): 62 | return self.config["vision_cfg"]["width"] 63 | 64 | @property 65 | def num_patches(self): 66 | return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2 67 | 68 | @property 69 | def num_patches_per_side(self): 70 | return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"] 71 | 72 | @property 73 | def image_size(self): 74 | return self.config["vision_cfg"]["image_size"] 75 | -------------------------------------------------------------------------------- /scripts/summarize_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | 5 | with open("/mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA/playground/data/llava_v1_5_mix665k.json") as f: 6 | llava_v1_5_mix665k = json.load(f) # 665298 7 | 8 | with open("/mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA/playground/data/llava_instruct_150k.json") as f: 9 | llava_instruct_150k = json.load(f) # 157712 10 | 11 | # Create sets of "id" fields 12 | mix665k_ids = set() 13 | for item in llava_v1_5_mix665k: 14 | all_conv = "" 15 | for cur_conversation in item["conversations"]: 16 | all_conv += cur_conversation["value"] 17 | mix665k_ids.add(f'{item["id"]}_{all_conv}') 18 | 19 | instruct_150k_ids = set() 20 | for item in llava_instruct_150k: 21 | all_conv = "" 22 | for cur_conversation in item["conversations"]: 23 | all_conv += cur_conversation["value"] 24 | instruct_150k_ids.add(f'{item["id"]}_{all_conv}') 25 | 26 | share_gpt_ids = set() 27 | for item in llava_v1_5_mix665k: 28 | if "image" not in item: 29 | all_conv = "" 30 | for cur_conversation in item["conversations"]: 31 | all_conv += cur_conversation["value"] 32 | share_gpt_ids.add(f'{item["id"]}_{all_conv}') # 40688 33 | 34 | # Get "id" fields that are in mix665k but not in instruct_150k and share_gpt 35 | new_ids = mix665k_ids - instruct_150k_ids - share_gpt_ids # 466898 36 | 37 | # Get "id" fields that are in mix665k but not in share_gpt 38 | # new_ids = 
mix665k_ids - share_gpt_ids #624610 39 | 40 | # import pdb; pdb.set_trace() 41 | 42 | # Filter mix665k data based on new_ids 43 | new_data = [] 44 | for item in llava_v1_5_mix665k: 45 | all_conv = "" 46 | for cur_conversation in item["conversations"]: 47 | all_conv += cur_conversation["value"] 48 | if f'{item["id"]}_{all_conv}' in new_ids: 49 | new_data.append(item) 50 | 51 | import pdb 52 | 53 | pdb.set_trace() 54 | 55 | with open("/mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA/playground/data/mixtral_instruct_135K_of_158K_V1.5.json") as f: 56 | new_mixtral_instruct = json.load(f) 57 | 58 | # mixtral_instruct_50K_of_80K_V1.json@ 59 | 60 | # print(len(new_data)) 61 | # for _ in new_mixtral_instruct: 62 | # # import pdb; pdb.set_trace() 63 | # if "coco" not in _["image"]: 64 | # _["image"] = f"coco/train2017/{_['image']}" 65 | # new_data.append(_) 66 | 67 | # print(len(instruct_150k_ids)) 68 | print(len(new_data)) 69 | 70 | # for _ in tqdm(new_data): 71 | # if "image" in _: 72 | # if "000000442654" in _["image"]: 73 | # all_conv = "" 74 | # for cur_conversation in _["conversations"]: 75 | # all_conv += cur_conversation["value"] 76 | # # if not os.path.exists(f'/mnt/bn/vl-research/workspace/boli01/data/playground/data/{_["image"]}'): 77 | # import pdb; pdb.set_trace() 78 | 79 | # Write new_data to a new JSON file 80 | with open("/mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA/playground/data/llava_v1_5_mix665k_minus_llava_instruct_150k_minus_sharegpt_plus_mixtral_instruct_135K_of_158K_V1.5.json", "w") as f: 81 | json.dump(new_data, f) 82 | -------------------------------------------------------------------------------- /trl/trainer/model_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Optional 3 | 4 | from ..core import flatten_dict 5 | 6 | 7 | @dataclass 8 | class ModelConfig: 9 | """ 10 | Arguments which define the model and tokenizer to load. 11 | """ 12 | 13 | model_name_or_path: Optional[str] = field( 14 | default=None, 15 | metadata={"help": ("The model checkpoint for weights initialization.")}, 16 | ) 17 | model_revision: str = field( 18 | default="main", 19 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 20 | ) 21 | torch_dtype: Optional[str] = field( 22 | default=None, 23 | metadata={ 24 | "help": ("Override the default `torch.dtype` and load the model under this dtype. 
If `auto` is passed, the " "dtype will be automatically derived from the model's weights."), 25 | "choices": ["auto", "bfloat16", "float16", "float32"], 26 | }, 27 | ) 28 | trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) 29 | attn_implementation: Optional[str] = field( 30 | default=None, 31 | metadata={"help": ("Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`")}, 32 | ) 33 | use_peft: bool = field( 34 | default=False, 35 | metadata={"help": ("Whether to use PEFT or not for training.")}, 36 | ) 37 | lora_r: Optional[int] = field( 38 | default=16, 39 | metadata={"help": ("LoRA R value.")}, 40 | ) 41 | lora_alpha: Optional[int] = field( 42 | default=32, 43 | metadata={"help": ("LoRA alpha.")}, 44 | ) 45 | lora_dropout: Optional[float] = field( 46 | default=0.05, 47 | metadata={"help": ("LoRA dropout.")}, 48 | ) 49 | lora_target_modules: Optional[List[str]] = field( 50 | default=None, 51 | metadata={"help": ("LoRA target modules.")}, 52 | ) 53 | lora_modules_to_save: Optional[List[str]] = field( 54 | default=None, 55 | metadata={"help": ("Model layers to unfreeze & train")}, 56 | ) 57 | load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}) 58 | load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}) 59 | 60 | bnb_4bit_quant_type: Optional[str] = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}) 61 | use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) 62 | 63 | def to_dict(self): 64 | output_dict = {} 65 | for key, value in self.__dict__.items(): 66 | output_dict[key] = value 67 | return flatten_dict(output_dict) 68 | 69 | def __post_init__(self): 70 | if self.load_in_8bit and self.load_in_4bit: 71 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time") 72 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/masked_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import random 5 | 6 | 7 | class MaskedDrop(nn.Module): 8 | def __init__(self, model_args): 9 | super().__init__() 10 | 11 | self.mode = model_args.mm_mask_drop_mode 12 | self.skip_percentage = model_args.mm_mask_drop_skip_percentage 13 | self.ratio = model_args.mm_mask_drop_ratio 14 | self.ratio_upper = model_args.mm_mask_drop_ratio_upper 15 | self.ratio_lower = model_args.mm_mask_drop_ratio_lower 16 | 17 | def forward(self, image_features, *args, **kwargs): 18 | 19 | if not self.training: 20 | return image_features 21 | 22 | if self.skip_percentage > random.random(): 23 | return image_features 24 | 25 | masked_features = [] 26 | 27 | for image_feature in image_features: 28 | num_tokens = image_feature.shape[0] 29 | if self.mode == "fixed": 30 | num_keep = int(num_tokens * self.ratio) 31 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]) 32 | elif self.mode == "range": 33 | num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)) 34 | masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0]) 35 | elif self.mode == "cls_only": 36 | 
masked_features.append(image_feature[0:1]) 37 | else: 38 | raise ValueError(f"Unexpected masked drop mode: {self.mode}") 39 | 40 | if self.mode not in ["range"] and (type(image_features) is not list or self.mode in ["cls_only"]): 41 | masked_features = torch.stack(masked_features, dim=0) 42 | 43 | return masked_features 44 | 45 | @property 46 | def config(self): 47 | return { 48 | "mm_resampler_type": "masked_drop", 49 | "mm_mask_drop_mode": self.mode, 50 | "mm_mask_drop_skip_percentage": self.skip_percentage, 51 | "mm_mask_drop_ratio": self.ratio, 52 | "mm_mask_drop_ratio_upper": self.ratio_upper, 53 | "mm_mask_drop_ratio_lower": self.ratio_lower, 54 | } 55 | 56 | def random_masking(self, x, len_keep): 57 | """ 58 | Perform per-sample random masking by per-sample shuffling. 59 | Per-sample shuffling is done by argsort random noise. 60 | x: [N, L, D], sequence 61 | """ 62 | N, L, D = x.shape # batch, length, dim 63 | 64 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 65 | 66 | # sort noise for each sample 67 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 68 | ids_restore = torch.argsort(ids_shuffle, dim=1) 69 | 70 | # keep the first subset 71 | ids_keep = ids_shuffle[:, :len_keep] 72 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 73 | 74 | # generate the binary mask: 0 is keep, 1 is remove 75 | mask = torch.ones([N, L], device=x.device) 76 | mask[:, :len_keep] = 0 77 | # unshuffle to get the binary mask 78 | mask = torch.gather(mask, dim=1, index=ids_restore) 79 | 80 | return x_masked, mask, ids_restore 81 | -------------------------------------------------------------------------------- /trl/models/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, Optional, Tuple 3 | 4 | from transformers import PreTrainedModel, PreTrainedTokenizer 5 | 6 | 7 | # TODO: Add Abstract Base Class if more formats are added 8 | @dataclass 9 | class ChatMlSpecialTokens: 10 | """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens.""" 11 | 12 | bos_token: str = "<|im_start|>" 13 | eos_token: str = "<|im_end|>" 14 | pad_token: str = "<|im_end|>" 15 | 16 | @property 17 | def system(self): 18 | return f"{self.bos_token}system" 19 | 20 | @property 21 | def user(self): 22 | return f"{self.bos_token}user" 23 | 24 | @property 25 | def assistant(self): 26 | return f"{self.bos_token}assistant" 27 | 28 | @property 29 | def chat_template(self): 30 | return ( 31 | "{% for message in messages %}" 32 | f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}" 33 | "{% endfor %}" 34 | "{% if add_generation_prompt %}" 35 | f"{{{{ '{self.assistant}\n' }}}}" 36 | "{% endif %}" 37 | ) 38 | 39 | 40 | FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens} 41 | 42 | 43 | def setup_chat_format( 44 | model: PreTrainedModel, 45 | tokenizer: PreTrainedTokenizer, 46 | format: Optional[Literal["chatml"]] = "chatml", 47 | resize_to_multiple_of: Optional[int] = None, 48 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: 49 | """ 50 | Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. 51 | 52 | Args: 53 | model (`~transformers.PreTrainedModel`): The model to be modified. 
54 | tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified. 55 | format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml". 56 | resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None. 57 | Returns: 58 | model (`~transformers.PreTrainedModel`): The modified model. 59 | tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer. 60 | """ 61 | # check if format available and retrieve 62 | if format not in FORMAT_MAPPING: 63 | raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}") 64 | 65 | chat_format = FORMAT_MAPPING[format]() 66 | 67 | # set special tokens and them 68 | tokenizer.eos_token = chat_format.eos_token 69 | tokenizer.pad_token = chat_format.pad_token 70 | tokenizer.bos_token = chat_format.bos_token 71 | tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]}) 72 | # set chat format for tokenizer 73 | tokenizer.chat_template = chat_format.chat_template 74 | 75 | # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377 76 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None) 77 | # Make sure to update the generation config to use the new eos & bos token 78 | if getattr(model, "generation_config", None) is not None: 79 | model.generation_config.bos_token_id = tokenizer.bos_token_id 80 | model.generation_config.eos_token_id = tokenizer.eos_token_id 81 | model.generation_config.pad_token_id = tokenizer.pad_token_id 82 | 83 | return model, tokenizer 84 | -------------------------------------------------------------------------------- /trl/import_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import importlib 15 | import sys 16 | 17 | 18 | if sys.version_info < (3, 8): 19 | _is_python_greater_3_8 = False 20 | else: 21 | _is_python_greater_3_8 = True 22 | 23 | 24 | def is_peft_available() -> bool: 25 | return importlib.util.find_spec("peft") is not None 26 | 27 | 28 | def is_unsloth_available() -> bool: 29 | return importlib.util.find_spec("unsloth") is not None 30 | 31 | 32 | def is_accelerate_greater_20_0() -> bool: 33 | if _is_python_greater_3_8: 34 | from importlib.metadata import version 35 | 36 | accelerate_version = version("accelerate") 37 | else: 38 | import pkg_resources 39 | 40 | accelerate_version = pkg_resources.get_distribution("accelerate").version 41 | return accelerate_version >= "0.20.0" 42 | 43 | 44 | def is_transformers_greater_than(version: str) -> bool: 45 | _transformers_version = importlib.metadata.version("transformers") 46 | return _transformers_version > version 47 | 48 | 49 | def is_torch_greater_2_0() -> bool: 50 | if _is_python_greater_3_8: 51 | from importlib.metadata import version 52 | 53 | torch_version = version("torch") 54 | else: 55 | import pkg_resources 56 | 57 | torch_version = pkg_resources.get_distribution("torch").version 58 | return torch_version >= "2.0" 59 | 60 | 61 | def is_diffusers_available() -> bool: 62 | return importlib.util.find_spec("diffusers") is not None 63 | 64 | 65 | def is_bitsandbytes_available() -> bool: 66 | import torch 67 | 68 | # bnb can be imported without GPU but is not usable. 69 | return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available() 70 | 71 | 72 | def is_torchvision_available() -> bool: 73 | return importlib.util.find_spec("torchvision") is not None 74 | 75 | 76 | def is_rich_available() -> bool: 77 | return importlib.util.find_spec("rich") is not None 78 | 79 | 80 | def is_wandb_available() -> bool: 81 | return importlib.util.find_spec("wandb") is not None 82 | 83 | 84 | def is_xpu_available() -> bool: 85 | if is_accelerate_greater_20_0(): 86 | import accelerate 87 | 88 | return accelerate.utils.is_xpu_available() 89 | else: 90 | if importlib.util.find_spec("intel_extension_for_pytorch") is None: 91 | return False 92 | try: 93 | import torch 94 | 95 | return hasattr(torch, "xpu") and torch.xpu.is_available() 96 | except RuntimeError: 97 | return False 98 | 99 | 100 | def is_npu_available() -> bool: 101 | """Checks if `torch_npu` is installed and potentially if a NPU is in the environment""" 102 | if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: 103 | return False 104 | 105 | import torch 106 | import torch_npu # noqa: F401 107 | 108 | return hasattr(torch, "npu") and torch.npu.is_available() 109 | -------------------------------------------------------------------------------- /llava/train/llava_trainer_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | 4 | from llava.train.llava_trainer import LLaVATrainer 5 | 6 | 7 | class LLaVAEvalTrainer(LLaVATrainer): 8 | def evaluate(self, evaluate_args): 9 | cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \ 10 | --model {evaluate_args.model} \ 11 | --model_args {evaluate_args.model_args} \ 12 | --tasks {evaluate_args.task_names} \ 13 | --batch_size {evaluate_args.batch_size} \ 14 | --log_samples_suffix {evaluate_args.log_samples_suffix} \ 15 | --output_path {evaluate_args.output_path}" 16 | if evaluate_args.limit: 17 | cmd += f" --limit {evaluate_args.limit}" 
18 | if evaluate_args.num_fewshot: 19 | cmd += f" --num_fewshot {evaluate_args.num_fewshot}" 20 | if evaluate_args.gen_kwargs != "": 21 | cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}" 22 | if evaluate_args.log_samples: 23 | cmd += f" --log_samples" 24 | else: 25 | assert False, "Please log samples so that the result can be parsed" 26 | results = subprocess.run([cmd], shell=True, capture_output=True, text=True) 27 | try: 28 | result_file_index_start = results.stdout.index("Saved samples to ") 29 | result_file_index_end = results.stdout.index(f".json") 30 | result_file_index_start += len("Saved samples to ") 31 | file = results.stdout[result_file_index_start:result_file_index_end] 32 | except: 33 | result_file_index_start = results.stderr.index("Saved samples to ") 34 | result_file_index_end = results.stderr.index(f".json") 35 | result_file_index_start += len("Saved samples to ") 36 | file = results.stderr[result_file_index_start:result_file_index_end] 37 | file = file.split("/")[:-1] 38 | file = "/".join(file) + "/results.json" 39 | with open(file, "r") as f: 40 | lmms_eval_results = json.load(f) 41 | result_dict = {} 42 | tasks_list = evaluate_args.task_names.split(",") 43 | for task in tasks_list: 44 | task_results = lmms_eval_results["results"][task] 45 | for k, v in task_results.items(): 46 | if k != "alias" and "stderr" not in k: 47 | metric = k.split(",")[0] 48 | result_dict[f"{task}_{metric}"] = v 49 | return result_dict 50 | 51 | """def evaluate(self, evaluate_args): 52 | initialize_tasks() 53 | tasks_list = evaluate_args.task_names.split(",") 54 | result_dict = {} 55 | results = evaluator.simple_evaluate( 56 | model=evaluate_args.model, 57 | model_args=evaluate_args.model_args, 58 | tasks=tasks_list, 59 | num_fewshot=evaluate_args.num_fewshot, 60 | batch_size=evaluate_args.batch_size, 61 | device=evaluate_args.device, 62 | limit=evaluate_args.limit, 63 | check_integrity=evaluate_args.check_integrity, 64 | show_task_to_terminal=evaluate_args.show_task_to_terminal, 65 | log_samples=evaluate_args.log_samples, 66 | gen_kwargs=evaluate_args.gen_kwargs, 67 | cli_args=evaluate_args, 68 | ) 69 | for task in tasks_list: 70 | task_results = results["results"][task] 71 | for k,v in task_results.items(): 72 | if k != "alias" and "stderr" not in k: 73 | metric = k.split(",")[0] 74 | result_dict[f"{task}_{metric}"] = v 75 | 76 | return result_dict""" 77 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop 8 | 9 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 10 | 11 | 12 | class ResizeMaxSize(nn.Module): 13 | 14 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): 15 | super().__init__() 16 | if not isinstance(max_size, int): 17 | raise TypeError(f"Size should be int. 
Got {type(max_size)}") 18 | self.max_size = max_size 19 | self.interpolation = interpolation 20 | self.fn = min if fn == "min" else min 21 | self.fill = fill 22 | 23 | def forward(self, img): 24 | if isinstance(img, torch.Tensor): 25 | height, width = img.shape[:2] 26 | else: 27 | width, height = img.size 28 | scale = self.max_size / float(max(height, width)) 29 | if scale != 1.0: 30 | new_size = tuple(round(dim * scale) for dim in (height, width)) 31 | img = F.resize(img, new_size, self.interpolation) 32 | pad_h = self.max_size - new_size[0] 33 | pad_w = self.max_size - new_size[1] 34 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) 35 | return img 36 | 37 | 38 | def _convert_to_rgb(image): 39 | return image.convert("RGB") 40 | 41 | 42 | # class CatGen(nn.Module): 43 | # def __init__(self, num=4): 44 | # self.num = num 45 | # def mixgen_batch(image, text): 46 | # batch_size = image.shape[0] 47 | # index = np.random.permutation(batch_size) 48 | 49 | # cat_images = [] 50 | # for i in range(batch_size): 51 | # # image mixup 52 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 53 | # # text concat 54 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 55 | # text = torch.stack(text) 56 | # return image, text 57 | 58 | 59 | def image_transform( 60 | image_size: int, 61 | is_train: bool, 62 | mean: Optional[Tuple[float, ...]] = None, 63 | std: Optional[Tuple[float, ...]] = None, 64 | resize_longest_max: bool = False, 65 | fill_color: int = 0, 66 | ): 67 | mean = mean or OPENAI_DATASET_MEAN 68 | if not isinstance(mean, (list, tuple)): 69 | mean = (mean,) * 3 70 | 71 | std = std or OPENAI_DATASET_STD 72 | if not isinstance(std, (list, tuple)): 73 | std = (std,) * 3 74 | 75 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 76 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 77 | image_size = image_size[0] 78 | 79 | normalize = Normalize(mean=mean, std=std) 80 | if is_train: 81 | return Compose( 82 | [ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ] 88 | ) 89 | else: 90 | if resize_longest_max: 91 | transforms = [ResizeMaxSize(image_size, fill=fill_color)] 92 | else: 93 | transforms = [ 94 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 95 | CenterCrop(image_size), 96 | ] 97 | transforms.extend( 98 | [ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ] 103 | ) 104 | return Compose(transforms) 105 | -------------------------------------------------------------------------------- /trl/extras/dataset_formatting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable, Literal, Optional, Union 3 | 4 | from datasets import Dataset, Value 5 | from transformers import AutoTokenizer 6 | 7 | from ..trainer.utils import ConstantLengthDataset 8 | 9 | 10 | FORMAT_MAPPING = { 11 | "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}], 12 | "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, 13 | } 14 | 15 | 16 | def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]): 17 | r""" 18 | return a callable function that takes in a "messages" dataset and returns a formatted dataset, 
based on the tokenizer 19 | apply chat template to the dataset 20 | """ 21 | 22 | def format_dataset(examples): 23 | if isinstance(examples[messages_field][0], list): 24 | output_texts = [] 25 | for i in range(len(examples[messages_field])): 26 | output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False)) 27 | return output_texts 28 | else: 29 | return tokenizer.apply_chat_template(examples[messages_field], tokenize=False) 30 | 31 | return format_dataset 32 | 33 | 34 | def instructions_formatting_function(tokenizer: AutoTokenizer): 35 | r""" 36 | return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer 37 | apply chat template to the dataset 38 | """ 39 | 40 | def format_dataset(examples): 41 | if isinstance(examples["prompt"], list): 42 | output_texts = [] 43 | for i in range(len(examples["prompt"])): 44 | converted_sample = [ 45 | {"role": "user", "content": examples["prompt"][i]}, 46 | {"role": "assistant", "content": examples["completion"][i]}, 47 | ] 48 | output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False)) 49 | return output_texts 50 | else: 51 | converted_sample = [ 52 | {"role": "user", "content": examples["prompt"]}, 53 | {"role": "assistant", "content": examples["completion"]}, 54 | ] 55 | return tokenizer.apply_chat_template(converted_sample, tokenize=False) 56 | 57 | return format_dataset 58 | 59 | 60 | def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]: 61 | r""" 62 | Finds the correct formatting function based on the dataset structure. Currently supported datasets are: 63 | - `ChatML` with [{"role": str, "content": str}] 64 | - `instruction` with [{"prompt": str, "completion": str}] 65 | 66 | Args: 67 | dataset (Dataset): User dataset 68 | tokenizer (AutoTokenizer): Tokenizer used for formatting 69 | 70 | Returns: 71 | Callable: Formatting function if the dataset format is supported else None 72 | """ 73 | if isinstance(dataset, Dataset): 74 | if "messages" in dataset.features: 75 | if dataset.features["messages"] == FORMAT_MAPPING["chatml"]: 76 | logging.info("Formatting dataset with chatml format") 77 | return conversations_formatting_function(tokenizer, "messages") 78 | if "conversations" in dataset.features: 79 | if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]: 80 | logging.info("Formatting dataset with chatml format") 81 | return conversations_formatting_function(tokenizer, "conversations") 82 | elif dataset.features == FORMAT_MAPPING["instruction"]: 83 | logging.info("Formatting dataset with instruction format") 84 | return instructions_formatting_function(tokenizer) 85 | 86 | return None 87 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, MptConfig, MptForCausalLM, MptModel, GenerationConfig 21 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 22 | 23 | 24 | class LlavaMptConfig(MptConfig): 25 | model_type = "llava_mpt" 26 | 27 | 28 | class LlavaMptModel(LlavaMetaModel, MptModel): 29 | config_class = LlavaMptConfig 30 | 31 | def __init__(self, config: MptConfig): 32 | config.hidden_size = config.d_model 33 | super(LlavaMptModel, self).__init__(config) 34 | 35 | def embed_tokens(self, x): 36 | return self.wte(x) 37 | 38 | 39 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 40 | config_class = LlavaMptConfig 41 | supports_gradient_checkpointing = True 42 | 43 | def __init__(self, config): 44 | super(MptForCausalLM, self).__init__(config) 45 | 46 | config.model_type = "llava_mpt" 47 | config.rope_scaling = None 48 | self.generation_config = GenerationConfig( 49 | temperature=0.0, 50 | max_new_tokens=1024, 51 | do_sample=False, 52 | top_p=None, 53 | ) 54 | 55 | self.transformer = LlavaMptModel(config) 56 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 57 | 58 | # Initialize weights and apply final processing 59 | self.post_init() 60 | 61 | def get_model(self): 62 | return self.transformer 63 | 64 | def _set_gradient_checkpointing(self, module, value=False): 65 | if isinstance(module, LlavaMptModel): 66 | module.gradient_checkpointing = value 67 | 68 | def forward( 69 | self, 70 | input_ids: Optional[torch.LongTensor] = None, 71 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 72 | attention_mask: Optional[torch.Tensor] = None, 73 | inputs_embeds: Optional[torch.Tensor] = None, 74 | labels: Optional[torch.Tensor] = None, 75 | use_cache: Optional[bool] = None, 76 | output_attentions: Optional[bool] = None, 77 | output_hidden_states: Optional[bool] = None, 78 | return_dict: Optional[bool] = None, 79 | cache_position=None, 80 | images=None, 81 | ): 82 | 83 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 84 | 85 | return super().forward( 86 | input_ids, 87 | past_key_values=past_key_values, 88 | attention_mask=attention_mask, 89 | inputs_embeds=inputs_embeds, 90 | labels=labels, 91 | use_cache=use_cache, 92 | output_attentions=output_attentions, 93 | output_hidden_states=output_hidden_states, 94 | return_dict=return_dict, 95 | ) 96 | 97 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 98 | images = kwargs.pop("images", None) 99 | _inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 100 | _inputs["images"] = images 101 | return _inputs 102 | 103 | 104 | AutoConfig.register("llava_mpt", LlavaMptConfig) 105 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 106 | -------------------------------------------------------------------------------- /scripts/inference_demo/demo_video.py: -------------------------------------------------------------------------------- 1 | from llava.model.builder import load_pretrained_model 2 | from llava.mm_utils import get_model_name_from_path, process_images, 
tokenizer_image_token 3 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 4 | from llava.conversation import conv_templates, SeparatorStyle 5 | 6 | from PIL import Image 7 | import requests 8 | import copy 9 | import torch 10 | 11 | import sys 12 | import warnings 13 | 14 | warnings.filterwarnings("ignore") 15 | pretrained = "../../ckpt/DriveMM" 16 | model_name = 'llama' #get_model_name_from_path(pretrained) 17 | device = torch.device('cuda:0') 18 | llava_model_args = { 19 | "multimodal": True, 20 | } 21 | tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device, **llava_model_args) 22 | 23 | model.eval() 24 | 25 | '''lingoqa''' 26 | urls = [['lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/0.jpg', 'lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/1.jpg', 'lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/2.jpg', 'lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/3.jpg', 'lingoqa/2a469a9042a47e4c68cadfaa7bdb4519/4.jpg']] 27 | question = '.\nThere is a video of traffic captured from the front view of the ego vehicle. What is the current action and its justification? Answer in the form \"action, justification\".' 28 | urls = [['lingoqa/ab4845470b41f0e123da50c996c35745/0.jpg', 'lingoqa/ab4845470b41f0e123da50c996c35745/1.jpg', 'lingoqa/ab4845470b41f0e123da50c996c35745/2.jpg', 'lingoqa/ab4845470b41f0e123da50c996c35745/3.jpg', 'lingoqa/ab4845470b41f0e123da50c996c35745/4.jpg']] 29 | question = '.\nThere is a video of traffic captured from the front view of the ego vehicle. Is there a traffic light in the vicinity? If so, what color is it displaying?' 30 | modalities=['video'] 31 | 32 | '''bdd-x''' 33 | urls = [['bddx/1/testing_1f0fff77-a50aae97_23664_0.png', 'bddx/1/testing_1f0fff77-a50aae97_23664_3.png', 'bddx/1/testing_1f0fff77-a50aae97_23664_7.png', 'bddx/1/testing_1f0fff77-a50aae97_23664_11.png', 'bddx/1/testing_1f0fff77-a50aae97_23664_15.png']] 34 | question = ".\nThere is a video of traffic captured from the front view of the ego vehicle. Describe the current action of the ego car, and explain the cause of this car's action." 35 | urls = [['bddx/2/testing_1f13b7b2-e98c7699_23665_0.png', 'bddx/2/testing_1f13b7b2-e98c7699_23665_3.png', 'bddx/2/testing_1f13b7b2-e98c7699_23665_7.png', 'bddx/2/testing_1f13b7b2-e98c7699_23665_11.png', 'bddx/2/testing_1f13b7b2-e98c7699_23665_15.png']] 36 | question = ".\nThere is a video of traffic captured from the front view of the ego vehicle. Describe the current action of the ego car, and explain the cause of this car's action." 37 | modalities=['video'] 38 | 39 | 40 | image_tensors = [] 41 | 42 | images = [] 43 | for img_idx, cur_crls in enumerate(urls): 44 | cur_images = [] 45 | for url in cur_crls: 46 | img_pil = Image.open(str(url)).convert("RGB") 47 | cur_images.append(img_pil) 48 | images.append(cur_images) 49 | image_tensors = [process_images(cur_images, image_processor, model.config) for cur_images in images] 50 | image_tensors = [_image.to(dtype=torch.float16, device=device) for _image in image_tensors] 51 | 52 | conv_template = "llava_llama_3" 53 | conv = copy.deepcopy(conv_templates[conv_template]) 54 | conv.append_message(conv.roles[0], question) 55 | conv.append_message(conv.roles[1], None) 56 | prompt_question = conv.get_prompt() 57 | 58 | if True: 59 | print("using train_prompt! 
note: now only support preprocess_llama3!!!") 60 | from llava.train.train import preprocess_llama3 61 | sources = [[{"from": 'human',"value": question},{"from": 'gpt', "value": ''}]] 62 | input_ids = preprocess_llama3(sources, tokenizer, has_image=True)['input_ids'][:, :-1].to(device) 63 | else: 64 | input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) 65 | 66 | image_sizes = [[frame.size for frame in video] for video in image_tensors] 67 | 68 | # Generate response 69 | cont = model.generate( 70 | input_ids, 71 | images=image_tensors, 72 | image_sizes=image_sizes, 73 | do_sample=False, 74 | temperature=0, 75 | max_new_tokens=4096, 76 | modalities=modalities 77 | ) 78 | text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True) 79 | print(text_outputs[0]) 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DriveMM: All-in-One Large Multimodal Model for Autonomous Driving 2 | [![arXiv](https://img.shields.io/badge/arXiv-2412.07689-b31b1b.svg?style=plastic)](https://arxiv.org/abs/2412.07689) 3 | [![Web](https://img.shields.io/badge/Web-DriveMM-blue.svg?style=plastic)](https://zhijian11.github.io/DriveMM/) 4 | [![HF](https://img.shields.io/badge/%F0%9F%A4%97-HuggingFace-yellow?style=plastic)](https://huggingface.co/DriveMM) 5 | 6 | This repository contains the implementation of the paper: 7 | 8 | > DriveMM: All-in-One Large Multimodal Model for Autonomous Driving
9 | > [Zhijian Huang](https://zhijian11.github.io/)\*, [Chengjian Feng](https://fcjian.github.io/)\*, [Feng Yan](https://scholar.google.com.hk/citations?user=gO4divAAAAAJ&hl=zh-CN&oi=sra), [Baihui Xiao](mailto:hxbh23@mails.tsinghua.edu.cn), [Zequn Jie](https://scholar.google.com/citations?user=4sKGNB0AAAAJ&hl=zh-CN&oi=ao), [Yujie Zhong](https://y-zhong.info/), [Xiaodan Liang](https://lemondan.github.io/)†, [Lin Ma](http://forestlinma.com/)†
10 | > *Equal Contribution †Corresponding Authors 11 | 12 |

13 | 14 |

15 | 16 | ## :fire: Updates 17 | - **2024.12**: We release the DriveMM paper on [arXiv](https://arxiv.org/abs/2412.07689)! We also release the [models](https://huggingface.co/DriveMM/) and the inference code! 18 | 19 | ## :sparkles: Highlights 20 | 🔥 We propose a novel all-in-one large multimodal model, **DriveMM**, robustly equipped with the general capabilities to execute a wide range of AD tasks and the generalization ability to transfer effectively to new datasets.

22 | 23 |

23 | 24 | 🔥 We introduce comprehensive benchmarks for evaluating autonomous driving LMMs, which include six public datasets, four input types, and thirteen challenging tasks. To the best of our knowledge, this is the first work to use multiple benchmarks to evaluate autonomous driving LMMs. 25 |

26 | 27 |

28 | 🔥 We present a curriculum principle for pre-training and fine-tuning on both diverse multimodal data and AD 29 | data. DriveMM demonstrates state-of-the-art performance and consistently outperforms models trained on individual datasets across all evaluated benchmarks. 30 | 31 | 32 | ## :checkered_flag: Getting Started 33 | 34 | ### Installation 35 | 36 | #### 1. **Clone this repository and navigate to the DriveMM folder:** 37 | ```bash 38 | git clone https://github.com/zhijian11/DriveMM 39 | cd DriveMM 40 | ``` 41 | #### 2. **Install the inference package:** 42 | ```bash 43 | conda create -n drivemm python=3.10 -y 44 | conda activate drivemm 45 | pip install --upgrade pip # Enable PEP 660 support. 46 | pip install -e ".[train]" 47 | ``` 48 | #### 3. **Run the DriveMM inference demo:** 49 | - Download the [checkpoint](https://huggingface.co/DriveMM/DriveMM/tree/main) and put it in the ckpt/ folder. 50 | ```bash 51 | cd scripts/inference_demo 52 | python demo_image.py # for image input 53 | python demo_video.py # for video input 54 | ``` 55 | ## :white_check_mark: TODO 56 | - [x] DriveMM models 57 | - [x] DriveMM inference code 58 | - [ ] DriveMM evaluation code 59 | - [ ] DriveMM training data 60 | - [ ] DriveMM training code 61 | 62 | 63 | ## :blush: Acknowledgements 64 | This project references some excellent open-source repos ([LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT/tree/main)). Thanks for their wonderful work and contributions to the community. 65 | 66 | 67 | 68 | ## :pushpin: Citation 69 | If you find DriveMM helpful for your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry. 70 | 71 | ```bibtex 72 | @article{huang2024drivemm, 73 | title={DriveMM: All-in-One Large Multimodal Model for Autonomous Driving}, 74 | author={Huang, Zhijian and Feng, Chengjian and Yan, Feng and Xiao, Baihui and Jie, Zequn and Zhong, Yujie and Liang, Xiaodan and Ma, Lin}, 75 | journal={arXiv preprint arXiv:2412.07689}, 76 | year={2024} 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | padding_mask: Optional[torch.Tensor] = None, 25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 26 | if output_attentions: 27 | warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") 28 | 29 | bsz, q_len, _ = hidden_states.size() 30 | 31 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 32 | key_states = self.k_proj(hidden_states).view(bsz, q_len, 
self.num_key_value_heads, self.head_dim).transpose(1, 2) 33 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim) 34 | 35 | kv_seq_len = key_states.shape[-2] 36 | if past_key_value is not None: 37 | kv_seq_len += past_key_value[0].shape[-2] 38 | 39 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 40 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 41 | 42 | if past_key_value is not None: 43 | # reuse k, v 44 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 45 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 46 | 47 | past_key_value = (key_states, value_states) if use_cache else None 48 | 49 | # repeat k/v heads if n_kv_heads < n_heads 50 | key_states = repeat_kv(key_states, self.num_key_value_groups) 51 | value_states = repeat_kv(value_states, self.num_key_value_groups) 52 | 53 | # Transform the data into the format required by flash attention 54 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 55 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 56 | key_padding_mask = attention_mask 57 | 58 | if key_padding_mask is None: 59 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 60 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) 61 | max_s = q_len 62 | output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 63 | output = output.view(bsz, q_len, -1) 64 | else: 65 | qkv = qkv.reshape(bsz, q_len, -1) 66 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 67 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 68 | output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) 69 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 70 | output = pad_input(output_unpad, indices, bsz, q_len) 71 | 72 | return self.o_proj(output), None, past_key_value 73 | 74 | 75 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 76 | # requires the attention mask to be the same as the key_padding_mask 77 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 78 | # [bsz, seq_len] 79 | return attention_mask 80 | 81 | 82 | def replace_llama_attn_with_flash_attn(): 83 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 84 | if cuda_major < 8: 85 | warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593") 86 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 87 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 88 | -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith("http") or image_file.startswith("https"): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert("RGB") 22 | else: 23 | image = Image.open(image_file).convert("RGB") 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit) 33 | 34 | if "llama-2" in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ("user", "assistant") 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().cuda() 56 | 57 | while True: 58 | try: 59 | inp = input(f"{roles[0]}: ") 60 | except EOFError: 61 | inp = "" 62 | if not inp: 63 | print("exit...") 64 | break 65 | 66 | print(f"{roles[1]}: ", end="") 67 | 68 | if image is not None: 69 | # first message 70 | if model.config.mm_use_im_start_end: 71 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp 72 | else: 73 | inp = DEFAULT_IMAGE_TOKEN + "\n" + inp 74 | conv.append_message(conv.roles[0], inp) 75 | image = None 76 | else: 77 | # later messages 78 | conv.append_message(conv.roles[0], inp) 79 | conv.append_message(conv.roles[1], None) 80 | prompt = conv.get_prompt() 81 | 82 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() 83 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 84 | keywords = [stop_str] 85 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 86 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 87 | 88 | with 
torch.inference_mode(): 89 | output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]) 90 | 91 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip() 92 | conv.messages[-1][-1] = outputs 93 | 94 | if args.debug: 95 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 101 | parser.add_argument("--model-base", type=str, default=None) 102 | parser.add_argument("--image-file", type=str, required=True) 103 | parser.add_argument("--num-gpus", type=int, default=1) 104 | parser.add_argument("--conv-mode", type=str, default=None) 105 | parser.add_argument("--temperature", type=float, default=0.2) 106 | parser.add_argument("--max-new-tokens", type=int, default=512) 107 | parser.add_argument("--load-8bit", action="store_true") 108 | parser.add_argument("--load-4bit", action="store_true") 109 | parser.add_argument("--debug", action="store_true") 110 | args = parser.parse_args() 111 | main(args) 112 | -------------------------------------------------------------------------------- /scripts/inference_demo/demo_image.py: -------------------------------------------------------------------------------- 1 | from llava.model.builder import load_pretrained_model 2 | from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token 3 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 4 | from llava.conversation import conv_templates, SeparatorStyle 5 | 6 | from PIL import Image 7 | import requests 8 | import copy 9 | import torch 10 | 11 | import sys 12 | import warnings 13 | 14 | warnings.filterwarnings("ignore") 15 | pretrained = "../../ckpt/DriveMM" 16 | model_name = 'llama' 17 | device = torch.device('cuda:0') 18 | llava_model_args = { 19 | "multimodal": True, 20 | } 21 | tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device, **llava_model_args) 22 | 23 | model.eval() 24 | 25 | '''codalm''' 26 | urls = ['codalm/codalm1.png'] 27 | question = "\nThere is an image of traffic captured from the front view of the ego vehicle. Focus on objects influencing the ego car's driving behavior: vehicles (cars, trucks, buses, etc.), vulnerable road users (pedestrians, cyclists, motorcyclists), traffic signs (no parking, warning, directional, etc.), traffic lights (red, green, yellow), traffic cones, barriers, miscellaneous(debris, dustbin, animals, etc.). You must not discuss any objects beyond the seven categories above. Please describe each object's appearance, position, direction, and explain why it affects the ego car's behavior." 28 | urls = ['codalm/codalm2.png'] 29 | question = "\nThere is an image of traffic captured from the front view of the ego vehicle. Please describe the object inside the red rectangle in the image and explain why it affect ego car driving." 30 | urls = ['codalm/codalm3.png'] 31 | question = "\nThere is an image of traffic captured from the front view of the ego vehicle. 
Focus on objects influencing the ego car's driving behavior: vehicles (cars, trucks, buses, etc.), vulnerable road users (pedestrians, cyclists, motorcyclists), traffic signs (no parking, warning, directional, etc.), traffic lights (red, green, yellow), traffic cones, barriers, miscellaneous(debris, dustbin, animals, etc.). You must not discuss any objects beyond the seven categories above. Please provide driving suggestions for the ego car based on the current scene." 32 | modalities=['image'] 33 | 34 | 35 | '''drivelm''' 36 | urls = ['drivelm/n008-2018-08-30-10-33-52-0400__CAM_FRONT__1535639717612404.jpg', 'drivelm/n008-2018-08-30-10-33-52-0400__CAM_FRONT_LEFT__1535639717604799.jpg', 'drivelm/n008-2018-08-30-10-33-52-0400__CAM_FRONT_RIGHT__1535639717620482.jpg', 'drivelm/n008-2018-08-30-10-33-52-0400__CAM_BACK__1535639717637558.jpg', 'drivelm/n008-2018-08-30-10-33-52-0400__CAM_BACK_LEFT__1535639717647405.jpg', 'drivelm/n008-2018-08-30-10-33-52-0400__CAM_BACK_RIGHT__1535639717628113.jpg'] 37 | question = '1: 2: 3: 4: 5: 6: . These six images are the front view, front left view, front right view, back view, back left view and back right view of the ego vehicle. What are the important objects in the current scene? Those objects will be considered for the future reasoning and driving decision.' 38 | question = '1: 2: 3: 4: 5: 6: . These six images are the front view, front left view, front right view, back view, back left view and back right view of the ego vehicle. Would be in the moving direction of the ego vehicle?' 39 | modalities=['image', 'image', 'image', 'image', 'image', 'image'] 40 | 41 | images = [Image.open(str(url)).convert("RGB") for url in urls] 42 | image_tensors = process_images(images, image_processor, model.config) 43 | image_tensors = [_image.to(dtype=torch.float16, device=device) for _image in image_tensors] 44 | 45 | conv_template = "llava_llama_3" 46 | conv = copy.deepcopy(conv_templates[conv_template]) 47 | conv.append_message(conv.roles[0], question) 48 | conv.append_message(conv.roles[1], None) 49 | prompt_question = conv.get_prompt() 50 | 51 | if True: 52 | print("using train_prompt! 
note: now only support preprocess_llama3!!!") 53 | from llava.train.train import preprocess_llama3 54 | sources = [[{"from": 'human',"value": question},{"from": 'gpt', "value": ''}]] 55 | input_ids = preprocess_llama3(sources, tokenizer, has_image=True)['input_ids'][:, :-1].to(device) 56 | else: 57 | input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) 58 | 59 | image_sizes = [image.size for image in images] 60 | 61 | 62 | 63 | # Generate response 64 | cont = model.generate( 65 | input_ids, 66 | images=image_tensors, 67 | image_sizes=image_sizes, 68 | do_sample=False, 69 | temperature=0, 70 | max_new_tokens=4096, 71 | modalities=modalities, 72 | ) 73 | text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True) 74 | print(text_outputs[0]) -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/hf_vision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor 5 | from llava.utils import rank0_print 6 | 7 | 8 | class HFVisionTower(nn.Module): 9 | def __init__(self, vision_tower, args, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | 14 | self.vision_tower_name = vision_tower.replace("hf:", "", 1) 15 | self.select_layer = args.mm_vision_select_layer 16 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 17 | 18 | if not delay_load: 19 | self.load_model() 20 | else: 21 | self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) 22 | 23 | def load_model(self): 24 | try: 25 | self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name) 26 | except Exception as e: 27 | if "448" in self.vision_tower_name: 28 | image_size = 448 29 | # use image processor with conig 30 | self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size) 31 | else: 32 | self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") 33 | rank0_print(f"Loaded image processor: {self.image_processor}") 34 | self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda") 35 | self.device = self.vision_tower.device 36 | self.dtype = self.vision_tower.dtype 37 | self.config = self.vision_tower.config 38 | 39 | if hasattr(self.vision_tower, "vision_model"): 40 | self.vision_tower = self.vision_tower.vision_model 41 | self.vision_tower.requires_grad_(False) 42 | # self.vision_tower.eval() 43 | self.is_loaded = True 44 | 45 | def feature_select(self, image_forward_outs): 46 | select_feature_type = self.select_feature 47 | 48 | if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 49 | select_every_k_layer = len(image_forward_outs.hidden_states) // 4 50 | image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) 51 | select_feature_type = select_feature_type.replace("slicefour_", "") 52 | else: 53 | image_features = image_forward_outs.hidden_states[self.select_layer] 54 | 55 | if select_feature_type == "patch": 56 | image_features = image_features[:, 1:] 57 | elif select_feature_type == "cls_patch": 58 | image_features = image_features 59 | else: 
60 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 61 | return image_features 62 | 63 | def forward(self, images): 64 | if type(images) is list: 65 | image_features = [] 66 | for image in images: 67 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 68 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 69 | image_features.append(image_feature) 70 | else: 71 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 72 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 73 | 74 | return image_features 75 | 76 | @property 77 | def dummy_feature(self): 78 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 79 | 80 | # @property 81 | # def dtype(self): 82 | # return self.vision_tower.dtype 83 | 84 | # @property 85 | # def device(self): 86 | # return self.vision_tower.device 87 | 88 | @property 89 | def hidden_size(self): 90 | try: 91 | _hidden_size = self.config.hidden_size 92 | except: 93 | _hidden_size = self.config.vision_config.hidden_size 94 | if "slicefour" in self.select_feature: 95 | _hidden_size *= 4 96 | return _hidden_size 97 | 98 | @property 99 | def num_patches(self): 100 | _num_patches = (self.config.image_size // self.config.patch_size) ** 2 101 | if "cls_patch" in self.select_feature: 102 | _num_patches += 1 103 | return _num_patches 104 | 105 | @property 106 | def num_patches_per_side(self): 107 | return self.config.image_size // self.config.patch_size 108 | 109 | @property 110 | def image_size(self): 111 | return self.config.image_size 112 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | 6 | import logging 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | try: 13 | import timm 14 | from timm.models.layers import Mlp, to_2tuple 15 | 16 | try: 17 | # old timm imports < 0.8.1 18 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 19 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 20 | except ImportError: 21 | # new timm imports >= 0.8.1 22 | from timm.layers import RotAttentionPool2d 23 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 24 | except ImportError: 25 | timm = None 26 | 27 | from .utils import freeze_batch_norm_2d 28 | 29 | 30 | class TimmModel(nn.Module): 31 | """timm model adapter 32 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 33 | """ 34 | 35 | def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False): 36 | super().__init__() 37 | if timm is None: 38 | raise RuntimeError("Please `pip install timm` to use timm models.") 39 | 40 | self.image_size = to_2tuple(image_size) 41 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 42 | feat_size = self.trunk.default_cfg.get("pool_size", None) 43 | feature_ndim = 1 if not feat_size else 2 44 | if pool in ("abs_attn", "rot_attn"): 45 | assert feature_ndim == 2 46 | # if attn pooling used, remove both classifier and default pool 47 | self.trunk.reset_classifier(0, global_pool="") 48 | else: 49 | # reset global pool if pool config set, otherwise leave as network default 50 | reset_kwargs = dict(global_pool=pool) if pool else {} 51 | self.trunk.reset_classifier(0, **reset_kwargs) 52 | prev_chs = self.trunk.num_features 53 | 54 | head_layers = OrderedDict() 55 | if pool == "abs_attn": 56 | head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 57 | prev_chs = embed_dim 58 | elif pool == "rot_attn": 59 | head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 60 | prev_chs = embed_dim 61 | else: 62 | assert proj, "projection layer needed if non-attention pooling is used." 
63 | 64 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 65 | if proj == "linear": 66 | head_layers["drop"] = nn.Dropout(drop) 67 | head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 68 | elif proj == "mlp": 69 | head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 70 | 71 | self.head = nn.Sequential(head_layers) 72 | 73 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 74 | """lock modules 75 | Args: 76 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 77 | """ 78 | if not unlocked_groups: 79 | # lock full model 80 | for param in self.trunk.parameters(): 81 | param.requires_grad = False 82 | if freeze_bn_stats: 83 | freeze_batch_norm_2d(self.trunk) 84 | else: 85 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 86 | try: 87 | # FIXME import here until API stable and in an official release 88 | from timm.models.helpers import group_parameters, group_modules 89 | except ImportError: 90 | raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`") 91 | matcher = self.trunk.group_matcher() 92 | gparams = group_parameters(self.trunk, matcher) 93 | max_layer_id = max(gparams.keys()) 94 | max_layer_id = max_layer_id - unlocked_groups 95 | for group_idx in range(max_layer_id + 1): 96 | group = gparams[group_idx] 97 | for param in group: 98 | self.trunk.get_parameter(param).requires_grad = False 99 | if freeze_bn_stats: 100 | gmodules = group_modules(self.trunk, matcher, reverse=True) 101 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 102 | freeze_batch_norm_2d(self.trunk, gmodules) 103 | 104 | @torch.jit.ignore 105 | def set_grad_checkpointing(self, enable=True): 106 | try: 107 | self.trunk.set_grad_checkpointing(enable) 108 | except Exception as e: 109 | logging.warning("grad checkpointing not supported for this timm image tower, continuing without...") 110 | 111 | def forward(self, x): 112 | x = self.trunk(x) 113 | x = self.head(x) 114 | return x 115 | -------------------------------------------------------------------------------- /trl/trainer/ddpo_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | from dataclasses import dataclass, field 5 | from typing import Literal, Optional 6 | 7 | from ..core import flatten_dict 8 | from ..import_utils import is_bitsandbytes_available, is_torchvision_available 9 | 10 | 11 | @dataclass 12 | class DDPOConfig: 13 | """ 14 | Configuration class for DDPOTrainer 15 | """ 16 | 17 | # common parameters 18 | exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] 19 | """the name of this experiment (by default is the file name without the extension name)""" 20 | run_name: Optional[str] = "" 21 | """Run name for wandb logging and checkpoint saving.""" 22 | seed: int = 0 23 | """Seed value for random generations""" 24 | log_with: Optional[Literal["wandb", "tensorboard"]] = None 25 | """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" 26 | tracker_kwargs: dict = field(default_factory=dict) 27 | """Keyword arguments for the tracker (e.g. 
wandb_project)""" 28 | accelerator_kwargs: dict = field(default_factory=dict) 29 | """Keyword arguments for the accelerator""" 30 | project_kwargs: dict = field(default_factory=dict) 31 | """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" 32 | tracker_project_name: str = "trl" 33 | """Name of project to use for tracking""" 34 | logdir: str = "logs" 35 | """Top-level logging directory for checkpoint saving.""" 36 | 37 | # hyperparameters 38 | num_epochs: int = 100 39 | """Number of epochs to train.""" 40 | save_freq: int = 1 41 | """Number of epochs between saving model checkpoints.""" 42 | num_checkpoint_limit: int = 5 43 | """Number of checkpoints to keep before overwriting old ones.""" 44 | mixed_precision: str = "fp16" 45 | """Mixed precision training.""" 46 | allow_tf32: bool = True 47 | """Allow tf32 on Ampere GPUs.""" 48 | resume_from: Optional[str] = "" 49 | """Resume training from a checkpoint.""" 50 | sample_num_steps: int = 50 51 | """Number of sampler inference steps.""" 52 | sample_eta: float = 1.0 53 | """Eta parameter for the DDIM sampler.""" 54 | sample_guidance_scale: float = 5.0 55 | """Classifier-free guidance weight.""" 56 | sample_batch_size: int = 1 57 | """Batch size (per GPU!) to use for sampling.""" 58 | sample_num_batches_per_epoch: int = 2 59 | """Number of batches to sample per epoch.""" 60 | train_batch_size: int = 1 61 | """Batch size (per GPU!) to use for training.""" 62 | train_use_8bit_adam: bool = False 63 | """Whether to use the 8bit Adam optimizer from bitsandbytes.""" 64 | train_learning_rate: float = 3e-4 65 | """Learning rate.""" 66 | train_adam_beta1: float = 0.9 67 | """Adam beta1.""" 68 | train_adam_beta2: float = 0.999 69 | """Adam beta2.""" 70 | train_adam_weight_decay: float = 1e-4 71 | """Adam weight decay.""" 72 | train_adam_epsilon: float = 1e-8 73 | """Adam epsilon.""" 74 | train_gradient_accumulation_steps: int = 1 75 | """Number of gradient accumulation steps.""" 76 | train_max_grad_norm: float = 1.0 77 | """Maximum gradient norm for gradient clipping.""" 78 | train_num_inner_epochs: int = 1 79 | """Number of inner epochs per outer epoch.""" 80 | train_cfg: bool = True 81 | """Whether or not to use classifier-free guidance during training.""" 82 | train_adv_clip_max: float = 5 83 | """Clip advantages to the range.""" 84 | train_clip_range: float = 1e-4 85 | """The PPO clip range.""" 86 | train_timestep_fraction: float = 1.0 87 | """The fraction of timesteps to train on.""" 88 | per_prompt_stat_tracking: bool = False 89 | """Whether to track statistics for each prompt separately.""" 90 | per_prompt_stat_tracking_buffer_size: int = 16 91 | """Number of reward values to store in the buffer for each prompt.""" 92 | per_prompt_stat_tracking_min_count: int = 16 93 | """The minimum number of reward values to store in the buffer.""" 94 | async_reward_computation: bool = False 95 | """Whether to compute rewards asynchronously.""" 96 | max_workers: int = 2 97 | """The maximum number of workers to use for async reward computation.""" 98 | negative_prompts: Optional[str] = "" 99 | """Comma-separated list of prompts to use as negative examples.""" 100 | 101 | def to_dict(self): 102 | output_dict = {} 103 | for key, value in self.__dict__.items(): 104 | output_dict[key] = value 105 | return flatten_dict(output_dict) 106 | 107 | def __post_init__(self): 108 | if self.log_with not in ["wandb", "tensorboard"]: 109 | warnings.warn(("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 
'tensorboard'.")) 110 | 111 | if self.log_with == "wandb" and not is_torchvision_available(): 112 | warnings.warn("Wandb image logging requires torchvision to be installed") 113 | 114 | if self.train_use_8bit_adam and not is_bitsandbytes_available(): 115 | raise ImportError("You need to install bitsandbytes to use 8bit Adam. " "You can install it with `pip install bitsandbytes`.") 116 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_gemma.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Duc Q. Nguyen, Haotian Liu and Bo Li 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaGemmaConfig(GemmaConfig): 31 | model_type = "llava_gemma" 32 | 33 | 34 | class LlavaGemmaModel(LlavaMetaModel, GemmaModel): 35 | config_class = LlavaGemmaConfig 36 | 37 | def __init__(self, config: GemmaConfig): 38 | super(LlavaGemmaModel, self).__init__(config) 39 | 40 | 41 | class LlavaGemmaForCausalLM(GemmaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaGemmaConfig 43 | 44 | def __init__(self, config): 45 | super(GemmaForCausalLM, self).__init__(config) 46 | self.model = LlavaGemmaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | past_key_values: Optional[List[torch.FloatTensor]] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | labels: Optional[torch.LongTensor] = None, 64 | use_cache: Optional[bool] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | images: Optional[torch.FloatTensor] = None, 68 | image_sizes: Optional[List[List[int]]] = None, 69 | return_dict: Optional[bool] = None, 70 | cache_position: Optional[torch.LongTensor] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) 75 | 76 | return super().forward( 77 | input_ids=input_ids, 78 | 
attention_mask=attention_mask, 79 | position_ids=position_ids, 80 | past_key_values=past_key_values, 81 | inputs_embeds=inputs_embeds, 82 | labels=labels, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict, 87 | cache_position=cache_position, 88 | ) 89 | 90 | @torch.no_grad() 91 | def generate( 92 | self, 93 | inputs: Optional[torch.Tensor] = None, 94 | images: Optional[torch.Tensor] = None, 95 | image_sizes: Optional[torch.Tensor] = None, 96 | **kwargs, 97 | ) -> Union[GenerateOutput, torch.LongTensor]: 98 | position_ids = kwargs.pop("position_ids", None) 99 | attention_mask = kwargs.pop("attention_mask", None) 100 | if "inputs_embeds" in kwargs: 101 | raise NotImplementedError("`inputs_embeds` is not supported") 102 | 103 | if images is not None: 104 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) 105 | else: 106 | inputs_embeds = self.get_model().embed_tokens(inputs) 107 | 108 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 109 | 110 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 111 | images = kwargs.pop("images", None) 112 | image_sizes = kwargs.pop("image_sizes", None) 113 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 114 | if images is not None: 115 | inputs["images"] = images 116 | if image_sizes is not None: 117 | inputs["image_sizes"] = image_sizes 118 | return inputs 119 | 120 | 121 | AutoConfig.register("llava_gemma", LlavaGemmaConfig) 122 | AutoModelForCausalLM.register(LlavaGemmaConfig, LlavaGemmaForCausalLM) 123 | -------------------------------------------------------------------------------- /llava/model/multimodal_resampler/perceiver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | try: 9 | from einops_exts import rearrange_many 10 | except: 11 | pass 12 | 13 | from torch import einsum, nn 14 | 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | 20 | def FeedForward(dim, mult=4): 21 | inner_dim = int(dim * mult) 22 | return nn.Sequential( 23 | nn.LayerNorm(dim), 24 | nn.Linear(dim, inner_dim, bias=False), 25 | nn.GELU(), 26 | nn.Linear(inner_dim, dim, bias=False), 27 | ) 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.heads = heads 35 | inner_dim = dim_head * heads 36 | 37 | self.norm_media = nn.LayerNorm(dim) 38 | self.norm_latents = nn.LayerNorm(dim) 39 | 40 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 41 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 42 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 43 | 44 | def forward(self, x, latents): 45 | """ 46 | Args: 47 | x (torch.Tensor): image features 48 | shape (b, T, n1, D) 49 | latent (torch.Tensor): latent features 50 | shape (b, T, n2, D) 51 | """ 52 | x = self.norm_media(x) 53 | latents = self.norm_latents(latents) 54 | 55 | h = self.heads 56 | 57 | q = self.to_q(latents) 58 | kv_input = torch.cat((x, latents), dim=-2) 59 | k, v = 
self.to_kv(kv_input).chunk(2, dim=-1) 60 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 61 | q = q * self.scale 62 | 63 | # attention 64 | sim = einsum("... i d, ... j d -> ... i j", q, k) 65 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 66 | attn = sim.softmax(dim=-1) 67 | 68 | out = einsum("... i j, ... j d -> ... i d", attn, v) 69 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 70 | return self.to_out(out) 71 | 72 | 73 | class PerceiverResamplerModule(nn.Module): 74 | def __init__( 75 | self, 76 | *, 77 | dim, 78 | depth=6, 79 | dim_head=64, 80 | heads=8, 81 | num_latents=64, 82 | max_num_media=None, 83 | max_num_frames=None, 84 | ff_mult=4, 85 | ): 86 | super().__init__() 87 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 88 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None 89 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None 90 | 91 | self.layers = nn.ModuleList([]) 92 | for _ in range(depth): 93 | self.layers.append( 94 | nn.ModuleList( 95 | [ 96 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 97 | FeedForward(dim=dim, mult=ff_mult) if ff_mult > 0 else nn.Identity(), 98 | ] 99 | ) 100 | ) 101 | 102 | self.norm = nn.LayerNorm(dim) 103 | 104 | def forward(self, x): 105 | """ 106 | Args: 107 | x (torch.Tensor): image features 108 | shape (b, T, F, v, D) 109 | Returns: 110 | shape (b, T, n, D) where n is self.num_latents 111 | """ 112 | b, T, F, v = x.shape[:4] 113 | 114 | # frame and media time embeddings 115 | if exists(self.frame_embs): 116 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 117 | x = x + frame_embs 118 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions 119 | if exists(self.media_time_embs): 120 | x = x + self.media_time_embs[:T] 121 | 122 | # blocks 123 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 124 | for attn, ff in self.layers: 125 | latents = attn(x, latents) + latents 126 | latents = ff(latents) + latents 127 | return self.norm(latents) 128 | 129 | 130 | class PerceiverResampler(nn.Module): 131 | def __init__(self, model_args, vision_tower): 132 | super().__init__() 133 | 134 | self.depth = model_args.mm_perceiver_depth 135 | self.num_latents = model_args.mm_perceiver_latents 136 | self.ff_mult = model_args.mm_perceiver_ff_mult 137 | self.pretrained = model_args.mm_perceiver_pretrained 138 | 139 | self.perceiver = PerceiverResamplerModule(dim=vision_tower.hidden_size, depth=self.depth, num_latents=self.num_latents, ff_mult=self.ff_mult) 140 | 141 | if self.pretrained is not None: 142 | self.load_state_dict(torch.load(self.pretrained)) 143 | 144 | def forward(self, image_features, *args, **kwargs): 145 | return self.perceiver(image_features[:, None, None]).squeeze(1) 146 | 147 | @property 148 | def config(self): 149 | return { 150 | "mm_resampler_type": "perceiver", 151 | "mm_perceiver_depth": self.depth, 152 | "mm_perceiver_latents": self.num_latents, 153 | "mm_perceiver_ff_mult": self.ff_mult, 154 | "mm_perceiver_pretrained": self.pretrained, 155 | } 156 | -------------------------------------------------------------------------------- /trl/extras/best_of_n_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional, Union 2 | 3 | import torch 4 | from transformers import GenerationConfig, PreTrainedTokenizer, 
PreTrainedTokenizerFast 5 | 6 | from ..core import set_seed 7 | from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper 8 | 9 | 10 | class BestOfNSampler(object): 11 | def __init__( 12 | self, 13 | model: PreTrainedModelWrapper, 14 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 15 | queries_to_scores: Callable[[List[str]], List[float]], 16 | length_sampler: Any, 17 | sample_size: int = 4, 18 | seed: Optional[int] = None, 19 | n_candidates: int = 1, 20 | generation_config: Optional[GenerationConfig] = None, 21 | ) -> None: 22 | r""" 23 | Initialize the sampler for best-of-n generation 24 | 25 | Args: 26 | model (`PreTrainedModelWrapper`): 27 | The pretrained model to use for generation 28 | tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`): 29 | Tokenizer associated with the pretrained model 30 | queries_to_scores (`Callable[[List[str]], List[float]]`): 31 | Callable that takes a list of generated texts and returns the associated reward scores 32 | length_sampler (`Any`): 33 | Sampler used to sample the length of the generated text 34 | sample_size (`int`): 35 | Number of samples to generate for each query 36 | seed (`int`, *optional*): 37 | Random seed used to control generation 38 | n_candidates (`int`): 39 | Number of candidates to return for each query 40 | generation_config (`GenerationConfig`, *optional*): 41 | Generation config passed to the underlying model's `generate` method. 42 | See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details 43 | """ 44 | if seed is not None: 45 | set_seed(seed) 46 | 47 | if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): 48 | raise ValueError(f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}") 49 | if not isinstance(model, (SUPPORTED_ARCHITECTURES)): 50 | raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}") 51 | 52 | self.model = model 53 | self.tokenizer = tokenizer 54 | 55 | self.queries_to_scores = queries_to_scores 56 | self.length_sampler = length_sampler 57 | self.gen_config = generation_config 58 | self.sample_size = sample_size 59 | self.n_candidates = n_candidates 60 | 61 | def generate( 62 | self, 63 | tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]], 64 | skip_special_tokens: bool = True, 65 | device: Optional[Union[str, torch.device]] = None, 66 | **generation_kwargs, 67 | ) -> List[List[str]]: 68 | r""" 69 | Generate the best of n samples for input queries 70 | 71 | Args: 72 | tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[int]`): 73 | represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers) 74 | skip_special_tokens (`bool`): 75 | Whether to remove the special tokens from the output 76 | device (`str` or `torch.device`, *optional*): 77 | The device on which the model will be loaded 78 | **generation_kwargs (`dict`, *optional*): 79 | Additional keyword arguments passed along to the underlying model's `generate` method. 
80 | This is used to override generation config 81 | 82 | Returns: 83 | List[List[str]]: A list of lists of generated texts 84 | """ 85 | queries = None 86 | 87 | if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1: 88 | queries = tokenized_query.unsqueeze(0) 89 | elif isinstance(tokenized_query, List): 90 | element_type = type(tokenized_query[0]) 91 | if element_type == int: 92 | queries = torch.tensor(tokenized_query).unsqueeze(0) 93 | elif element_type == torch.Tensor: 94 | queries = [tensor.reshape((1, -1)) for tensor in tokenized_query] 95 | else: 96 | queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query] 97 | 98 | result = [] 99 | 100 | for query in queries: 101 | queries = query.repeat((self.sample_size, 1)) 102 | output = self.model.generate( 103 | queries.to(device), 104 | max_new_tokens=self.length_sampler(), 105 | generation_config=self.gen_config, 106 | **generation_kwargs, 107 | ).squeeze() 108 | output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens) 109 | scores = torch.tensor(self.queries_to_scores(output)) 110 | output = [output[i] for i in scores.topk(self.n_candidates).indices] 111 | result.append(output) 112 | 113 | return result 114 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, MistralConfig, MistralModel, MistralForCausalLM, GenerationConfig 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMistralConfig(MistralConfig): 31 | model_type = "llava_mistral" 32 | temperature: float = 0.0 # reset to 0.0, previously 0.9 for Vicuna 33 | max_new_tokens: int = 1024 34 | do_sample: bool = False 35 | top_p: Optional[float] = None 36 | 37 | 38 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 39 | config_class = LlavaMistralConfig 40 | 41 | def __init__(self, config: MistralConfig): 42 | super(LlavaMistralModel, self).__init__(config) 43 | 44 | 45 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMistralConfig 47 | 48 | def __init__(self, config): 49 | super(MistralForCausalLM, self).__init__(config) 50 | 51 | config.model_type = "llava_mistral" 52 | config.rope_scaling = None 53 | 54 | self.model = LlavaMistralModel(config) 55 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 56 | # Initialize weights and apply final processing 57 | self.post_init() 58 | 59 | def get_model(self): 60 | return self.model 61 | 62 | def forward( 63 | self, 64 | input_ids: torch.LongTensor = None, 65 | attention_mask: Optional[torch.Tensor] = None, 66 | position_ids: Optional[torch.LongTensor] = None, 67 | past_key_values: Optional[List[torch.FloatTensor]] = None, 68 | inputs_embeds: Optional[torch.FloatTensor] = None, 69 | labels: Optional[torch.LongTensor] = None, 70 | use_cache: Optional[bool] = None, 71 | output_attentions: Optional[bool] = None, 72 | output_hidden_states: Optional[bool] = None, 73 | images: Optional[torch.FloatTensor] = None, 74 | image_sizes: Optional[List[List[int]]] = None, 75 | return_dict: Optional[bool] = None, 76 | cache_position=None, 77 | ) -> Union[Tuple, CausalLMOutputWithPast]: 78 | 79 | if inputs_embeds is None: 80 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) 81 | 82 | return super().forward( 83 | input_ids=input_ids, 84 | attention_mask=attention_mask, 85 | position_ids=position_ids, 86 | past_key_values=past_key_values, 87 | inputs_embeds=inputs_embeds, 88 | labels=labels, 89 | use_cache=use_cache, 90 | output_attentions=output_attentions, 91 | output_hidden_states=output_hidden_states, 92 | return_dict=return_dict, 93 | ) 94 | 95 | @torch.no_grad() 96 | def generate( 97 | self, 98 | inputs: Optional[torch.Tensor] = None, 99 | images: Optional[torch.Tensor] = None, 100 | image_sizes: Optional[torch.Tensor] = None, 101 | **kwargs, 102 | ) -> Union[GenerateOutput, torch.LongTensor]: 103 | position_ids = kwargs.pop("position_ids", None) 104 | attention_mask = kwargs.pop("attention_mask", None) 105 | if "inputs_embeds" in kwargs: 106 | raise NotImplementedError("`inputs_embeds` is not supported") 107 | 108 | if images is not None: 109 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, 
image_sizes=image_sizes) 110 | else: 111 | inputs_embeds = self.get_model().embed_tokens(inputs) 112 | 113 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 114 | 115 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 116 | images = kwargs.pop("images", None) 117 | image_sizes = kwargs.pop("image_sizes", None) 118 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 119 | if images is not None: 120 | inputs["images"] = images 121 | if image_sizes is not None: 122 | inputs["image_sizes"] = image_sizes 123 | return inputs 124 | 125 | 126 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 127 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 128 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | 8 | def broadcat(tensors, dim=-1): 9 | num_tensors = len(tensors) 10 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 11 | assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" 12 | shape_len = list(shape_lens)[0] 13 | dim = (dim + shape_len) if dim < 0 else dim 14 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 15 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 16 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation" 17 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 18 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 19 | expanded_dims.insert(dim, (dim, dims[dim])) 20 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 21 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 22 | return torch.cat(tensors, dim=dim) 23 | 24 | 25 | def rotate_half(x): 26 | x = rearrange(x, "... (d r) -> ... d r", r=2) 27 | x1, x2 = x.unbind(dim=-1) 28 | x = torch.stack((-x2, x1), dim=-1) 29 | return rearrange(x, "... d r -> ... (d r)") 30 | 31 | 32 | class VisionRotaryEmbedding(nn.Module): 33 | def __init__( 34 | self, 35 | dim, 36 | pt_seq_len, 37 | ft_seq_len=None, 38 | custom_freqs=None, 39 | freqs_for="lang", 40 | theta=10000, 41 | max_freq=10, 42 | num_freqs=1, 43 | ): 44 | super().__init__() 45 | if custom_freqs: 46 | freqs = custom_freqs 47 | elif freqs_for == "lang": 48 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 49 | elif freqs_for == "pixel": 50 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 51 | elif freqs_for == "constant": 52 | freqs = torch.ones(num_freqs).float() 53 | else: 54 | raise ValueError(f"unknown modality {freqs_for}") 55 | 56 | if ft_seq_len is None: 57 | ft_seq_len = pt_seq_len 58 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 59 | 60 | freqs_h = torch.einsum("..., f -> ... f", t, freqs) 61 | freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) 62 | 63 | freqs_w = torch.einsum("..., f -> ... f", t, freqs) 64 | freqs_w = repeat(freqs_w, "... n -> ... 
(n r)", r=2) 65 | 66 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) 67 | 68 | self.register_buffer("freqs_cos", freqs.cos()) 69 | self.register_buffer("freqs_sin", freqs.sin()) 70 | 71 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 72 | 73 | def forward(self, t, start_index=0): 74 | rot_dim = self.freqs_cos.shape[-1] 75 | end_index = start_index + rot_dim 76 | assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" 77 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 78 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 79 | 80 | return torch.cat((t_left, t, t_right), dim=-1) 81 | 82 | 83 | class VisionRotaryEmbeddingFast(nn.Module): 84 | def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0): 85 | super().__init__() 86 | if custom_freqs: 87 | freqs = custom_freqs 88 | elif freqs_for == "lang": 89 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 90 | elif freqs_for == "pixel": 91 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 92 | elif freqs_for == "constant": 93 | freqs = torch.ones(num_freqs).float() 94 | else: 95 | raise ValueError(f"unknown modality {freqs_for}") 96 | 97 | if ft_seq_len is None: 98 | ft_seq_len = pt_seq_len 99 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 100 | 101 | freqs = torch.einsum("..., f -> ... f", t, freqs) 102 | freqs = repeat(freqs, "... n -> ... (n r)", r=2) 103 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) 104 | 105 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 106 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 107 | 108 | self.patch_dropout = patch_dropout 109 | 110 | self.register_buffer("freqs_cos", freqs_cos) 111 | self.register_buffer("freqs_sin", freqs_sin) 112 | 113 | logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") 114 | 115 | def forward(self, t, patch_indices_keep=None): 116 | if patch_indices_keep is not None: 117 | batch = t.size()[0] 118 | batch_indices = torch.arange(batch) 119 | batch_indices = batch_indices[..., None] 120 | 121 | freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 122 | freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) 123 | 124 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 125 | freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j") 126 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 127 | freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j") 128 | 129 | return t * freqs_cos + rotate_half(t) * freqs_sin 130 | 131 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin 132 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag("openai") 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = "fp32" if device == "cpu" else "fp16" 56 | 57 | if get_pretrained_url(name, "openai"): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith("amp") or precision == "fp32": 87 | model.float() 88 | elif precision == "bf16": 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == "fp32": 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import torch.distributed.nn 8 | from torch import distributed as dist 9 | 10 | has_distributed = True 11 | except ImportError: 12 | has_distributed = False 13 | 14 | try: 15 | import horovod.torch as hvd 16 | except ImportError: 17 | hvd = None 18 | 19 | from timm.loss import LabelSmoothingCrossEntropy 20 | 21 | 22 | def gather_features(image_features, text_features, local_loss=False, 
gather_with_grad=False, rank=0, world_size=1, use_horovod=False): 23 | assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support." 24 | if use_horovod: 25 | assert hvd is not None, "Please install horovod" 26 | if gather_with_grad: 27 | all_image_features = hvd.allgather(image_features) 28 | all_text_features = hvd.allgather(text_features) 29 | else: 30 | with torch.no_grad(): 31 | all_image_features = hvd.allgather(image_features) 32 | all_text_features = hvd.allgather(text_features) 33 | if not local_loss: 34 | # ensure grads for local rank when all_* features don't have a gradient 35 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 36 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 37 | gathered_image_features[rank] = image_features 38 | gathered_text_features[rank] = text_features 39 | all_image_features = torch.cat(gathered_image_features, dim=0) 40 | all_text_features = torch.cat(gathered_text_features, dim=0) 41 | else: 42 | # We gather tensors from all gpus 43 | if gather_with_grad: 44 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 45 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 46 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 47 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 48 | else: 49 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 50 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 51 | dist.all_gather(gathered_image_features, image_features) 52 | dist.all_gather(gathered_text_features, text_features) 53 | if not local_loss: 54 | # ensure grads for local rank when all_* features don't have a gradient 55 | gathered_image_features[rank] = image_features 56 | gathered_text_features[rank] = text_features 57 | all_image_features = torch.cat(gathered_image_features, dim=0) 58 | all_text_features = torch.cat(gathered_text_features, dim=0) 59 | 60 | return all_image_features, all_text_features 61 | 62 | 63 | class ClipLoss(nn.Module): 64 | 65 | def __init__( 66 | self, 67 | local_loss=False, 68 | gather_with_grad=False, 69 | cache_labels=False, 70 | rank=0, 71 | world_size=1, 72 | use_horovod=False, 73 | smoothing=0.0, 74 | ): 75 | super().__init__() 76 | self.local_loss = local_loss 77 | self.gather_with_grad = gather_with_grad 78 | self.cache_labels = cache_labels 79 | self.rank = rank 80 | self.world_size = world_size 81 | self.use_horovod = use_horovod 82 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 83 | 84 | # cache state 85 | self.prev_num_logits = 0 86 | self.labels = {} 87 | 88 | def forward(self, image_features, text_features, logit_scale=1.0): 89 | device = image_features.device 90 | if self.world_size > 1: 91 | all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 92 | 93 | if self.local_loss: 94 | logits_per_image = logit_scale * image_features @ all_text_features.T 95 | logits_per_text = logit_scale * text_features @ all_image_features.T 96 | else: 97 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 98 | logits_per_text = logits_per_image.T 99 | else: 100 | logits_per_image = 
logit_scale * image_features @ text_features.T 101 | logits_per_text = logit_scale * text_features @ image_features.T 102 | # calculated ground-truth and cache if enabled 103 | num_logits = logits_per_image.shape[0] 104 | if self.prev_num_logits != num_logits or device not in self.labels: 105 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 106 | if self.world_size > 1 and self.local_loss: 107 | labels = labels + num_logits * self.rank 108 | if self.cache_labels: 109 | self.labels[device] = labels 110 | self.prev_num_logits = num_logits 111 | else: 112 | labels = self.labels[device] 113 | 114 | if self.label_smoothing_cross_entropy: 115 | total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2 116 | else: 117 | total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 118 | 119 | acc = None 120 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 121 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 122 | acc = {"i2t": i2t_acc, "t2i": t2i_acc} 123 | return total_loss, acc 124 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mixtral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, MixtralConfig, MixtralModel, MixtralForCausalLM, GenerationConfig 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMixtralConfig(MixtralConfig): 31 | model_type = "llava_mixtral" 32 | 33 | 34 | class LlavaMixtralModel(LlavaMetaModel, MixtralModel): 35 | config_class = LlavaMixtralConfig 36 | 37 | def __init__(self, config: MixtralConfig): 38 | super(LlavaMixtralModel, self).__init__(config) 39 | 40 | 41 | class LlavaMixtralForCausalLM(MixtralForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaMixtralConfig 43 | 44 | def __init__(self, config): 45 | super(MixtralForCausalLM, self).__init__(config) 46 | 47 | config.model_type = "llava_mixtral" 48 | config.rope_scaling = None 49 | self.model = LlavaMixtralModel(config) 50 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | modalities: Optional[List[str]] = ["image"], 72 | dpo_forward: Optional[bool] = None, 73 | cache_position=None, 74 | ) -> Union[Tuple, CausalLMOutputWithPast]: 75 | 76 | if inputs_embeds is None: 77 | (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) 78 | 79 | if dpo_forward: 80 | outputs = self.model( 81 | input_ids=input_ids, 82 | attention_mask=attention_mask, 83 | position_ids=position_ids, 84 | past_key_values=past_key_values, 85 | inputs_embeds=inputs_embeds, 86 | use_cache=use_cache, 87 | output_attentions=output_attentions, 88 | output_hidden_states=output_hidden_states, 89 | return_dict=return_dict, 90 | ) 91 | 92 | hidden_states = outputs[0] 93 | logits = self.lm_head(hidden_states) 94 | return logits, labels 95 | 96 | else: 97 | return super().forward( 98 | input_ids=input_ids, 99 | attention_mask=attention_mask, 100 | position_ids=position_ids, 101 | past_key_values=past_key_values, 102 | inputs_embeds=inputs_embeds, 103 | labels=labels, 104 | use_cache=use_cache, 105 | output_attentions=output_attentions, 106 | output_hidden_states=output_hidden_states, 107 | return_dict=return_dict, 108 | ) 109 | 110 | @torch.no_grad() 111 | def generate( 112 | self, 113 | inputs: Optional[torch.Tensor] = None, 114 | images: Optional[torch.Tensor] = None, 115 | image_sizes: Optional[torch.Tensor] = None, 116 | modalities: Optional[List[str]] = ["image"], 117 | **kwargs, 
118 | ) -> Union[GenerateOutput, torch.LongTensor]: 119 | position_ids = kwargs.pop("position_ids", None) 120 | attention_mask = kwargs.pop("attention_mask", None) 121 | if "inputs_embeds" in kwargs: 122 | raise NotImplementedError("`inputs_embeds` is not supported") 123 | 124 | if images is not None: 125 | (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) 126 | else: 127 | inputs_embeds = self.get_model().embed_tokens(inputs) 128 | 129 | return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) 130 | 131 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 132 | images = kwargs.pop("images", None) 133 | image_sizes = kwargs.pop("image_sizes", None) 134 | inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs) 135 | if images is not None: 136 | inputs["images"] = images 137 | if image_sizes is not None: 138 | inputs["image_sizes"] = image_sizes 139 | return inputs 140 | 141 | 142 | AutoConfig.register("llava_mixtral", LlavaMixtralConfig) 143 | AutoModelForCausalLM.register(LlavaMixtralConfig, LlavaMixtralForCausalLM) 144 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 4 | from llava.conversation import conv_templates, SeparatorStyle 5 | from llava.model.builder import load_pretrained_model 6 | from llava.utils import disable_torch_init 7 | from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria 8 | from transformers.generation.streamers import TextIteratorStreamer 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from io import BytesIO 14 | 15 | from cog import BasePredictor, Input, Path, ConcatenateIterator 16 | import time 17 | import subprocess 18 | from threading import Thread 19 | 20 | import os 21 | os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights" 22 | 23 | # url for the weights mirror 24 | REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default" 25 | # files to download from the weights mirrors 26 | weights = [ 27 | { 28 | "dest": "liuhaotian/llava-v1.5-13b", 29 | # git commit hash from huggingface 30 | "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8", 31 | "files": [ 32 | "config.json", 33 | "generation_config.json", 34 | "pytorch_model-00001-of-00003.bin", 35 | "pytorch_model-00002-of-00003.bin", 36 | "pytorch_model-00003-of-00003.bin", 37 | "pytorch_model.bin.index.json", 38 | "special_tokens_map.json", 39 | "tokenizer.model", 40 | "tokenizer_config.json", 41 | ], 42 | }, 43 | { 44 | "dest": "openai/clip-vit-large-patch14-336", 45 | "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1", 46 | "files": ["config.json", "preprocessor_config.json", "pytorch_model.bin"], 47 | }, 48 | ] 49 | 50 | 51 | def download_json(url: str, dest: Path): 52 | res = requests.get(url, allow_redirects=True) 53 | if res.status_code == 200 and res.content: 54 | with dest.open("wb") as f: 55 | f.write(res.content) 56 | else: 57 | print(f"Failed to download {url}. 
Status code: {res.status_code}") 58 | 59 | def download_weights(baseurl: str, basedest: str, files: list[str]): 60 | basedest = Path(basedest) 61 | start = time.time() 62 | print("downloading to: ", basedest) 63 | basedest.mkdir(parents=True, exist_ok=True) 64 | for f in files: 65 | dest = basedest / f 66 | url = os.path.join(REPLICATE_WEIGHTS_URL, baseurl, f) 67 | if not dest.exists(): 68 | print("downloading url: ", url) 69 | if dest.suffix == ".json": 70 | download_json(url, dest) 71 | else: 72 | subprocess.check_call(["pget", url, str(dest)], close_fds=False) 73 | print("downloading took: ", time.time() - start) 74 | 75 | class Predictor(BasePredictor): 76 | def setup(self) -> None: 77 | """Load the model into memory to make running multiple predictions efficient""" 78 | for weight in weights: 79 | download_weights(weight["src"], weight["dest"], weight["files"]) 80 | disable_torch_init() 81 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model("liuhaotian/llava-v1.5-13b", model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False) 82 | 83 | def predict( 84 | self, 85 | image: Path = Input(description="Input image"), 86 | prompt: str = Input(description="Prompt to use for text generation"), 87 | top_p: float = Input(description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", ge=0.0, le=1.0, default=1.0), 88 | temperature: float = Input(description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", default=0.2, ge=0.0), 89 | max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0), 90 | ) -> ConcatenateIterator[str]: 91 | """Run a single prediction on the model""" 92 | 93 | conv_mode = "llava_v1" 94 | conv = conv_templates[conv_mode].copy() 95 | 96 | image_data = load_image(str(image)) 97 | image_tensor = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"].half().cuda() 98 | 99 | # loop start 100 | 101 | # just one turn, always prepend image token 102 | inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt 103 | conv.append_message(conv.roles[0], inp) 104 | 105 | conv.append_message(conv.roles[1], None) 106 | prompt = conv.get_prompt() 107 | 108 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() 109 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 110 | keywords = [stop_str] 111 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) 112 | streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0) 113 | 114 | with torch.inference_mode(): 115 | thread = Thread( 116 | target=self.model.generate, 117 | kwargs=dict(inputs=input_ids, images=image_tensor, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=max_tokens, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]), 118 | ) 119 | thread.start() 120 | # workaround: second-to-last token is always " " 121 | # but we want to keep it if it's not the second-to-last token 122 | prepend_space = False 123 | for new_text in streamer: 124 | if new_text == " ": 125 | prepend_space = True 126 | continue 127 | if new_text.endswith(stop_str): 128 | new_text = new_text[: -len(stop_str)].strip() 129 | prepend_space = False 130 | elif prepend_space: 131 | new_text = " " + new_text 132 | prepend_space = False 133 | if len(new_text): 134 
| yield new_text 135 | if prepend_space: 136 | yield " " 137 | thread.join() 138 | 139 | 140 | def load_image(image_file): 141 | if image_file.startswith("http") or image_file.startswith("https"): 142 | response = requests.get(image_file) 143 | image = Image.open(BytesIO(response.content)).convert("RGB") 144 | else: 145 | image = Image.open(image_file).convert("RGB") 146 | return image 147 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py: -------------------------------------------------------------------------------- 1 | # Based on EVA, BEIT, timm and DeiT code bases 2 | # https://github.com/baaivision/EVA 3 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 4 | # https://github.com/microsoft/unilm/tree/master/beit 5 | # https://github.com/facebookresearch/deit/ 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | # not tested yet 9 | import math 10 | from transformers import CLIPImageProcessor 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint as checkpoint 16 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 17 | from .eva_clip import create_model_and_transforms, get_model_config 18 | import torch 19 | import torchvision 20 | import time 21 | 22 | from llava.utils import rank0_print 23 | 24 | 25 | class EvaViTWrapper(nn.Module): 26 | def __init__(self, vision_tower, args, delay_load=False): 27 | super().__init__() 28 | 29 | self.is_loaded = False 30 | self.vision_tower_name = vision_tower 31 | self.pretrained = args.vision_tower_pretrained 32 | self.args = args 33 | 34 | self.select_layer = args.mm_vision_select_layer 35 | if self.select_layer < -1: 36 | self.select_layer += 1 37 | self.select_feature = getattr(args, "mm_vision_select_feature", "patch") 38 | 39 | self.model_config = get_model_config(self.vision_tower_name) 40 | 41 | if not delay_load: 42 | rank0_print(f"Loading vision tower: {vision_tower}") 43 | self.load_model() 44 | elif getattr(args, "unfreeze_mm_vision_tower", False): 45 | # TODO: better detector is needed. 
46 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") 47 | self.load_model() 48 | elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: 49 | rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") 50 | self.load_model() 51 | 52 | def load_model(self): 53 | rank0_print(f"Loading: {self.vision_tower_name}") 54 | rank0_print(f"Pretrained: {self.pretrained}") 55 | time_start = time.time() 56 | model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16") 57 | time_end = time.time() 58 | rank0_print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s") 59 | self.device = next(model.parameters()).device 60 | self.dtype = next(model.parameters()).dtype 61 | if self.device.type != "meta": 62 | model = model.to("cuda") 63 | self.vision_tower = model.visual 64 | resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] 65 | normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] 66 | self.resize_transform_size = resize_transform.size 67 | self.image_processor = CLIPImageProcessor.from_pretrained( 68 | "openai/clip-vit-large-patch14", 69 | crop_size=resize_transform.size, 70 | size={"shortest_edge": resize_transform.size}, 71 | image_mean=list(normalize_transform.mean), 72 | image_std=list(normalize_transform.std), 73 | ) 74 | rank0_print(f"Loaded image processor: {self.image_processor}") 75 | self.vision_tower.requires_grad_(False) 76 | self.is_loaded = True 77 | 78 | def feature_select(self, image_features): 79 | select_feature_type = self.select_feature 80 | 81 | # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: 82 | # select_every_k_layer = len(image_features) // 4 83 | # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1) 84 | # select_feature_type = select_feature_type.replace("slicefour_", "") 85 | # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: 86 | # select_layers = [-1, -4, -7, -10, 6] 87 | # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1) 88 | # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") 89 | # else: 90 | # image_features = image_features[self.select_layer] 91 | 92 | if select_feature_type == "patch": 93 | image_features = image_features[:, 1:] 94 | elif select_feature_type == "cls_patch": 95 | image_features = image_features 96 | else: 97 | raise ValueError(f"Unexpected select feature: {select_feature_type}") 98 | return image_features 99 | 100 | def train(self, mode=True): 101 | self.training = mode 102 | 103 | if self.is_loaded: 104 | self.vision_tower.eval() 105 | 106 | def forward(self, images): 107 | if type(images) is list: 108 | image_features = [] 109 | for image in images: 110 | image_features = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True) 111 | image_features = self.feature_select(image_features).to(self.dtype) 112 | image_features.append(image_features) 113 | else: 114 | image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True) 115 | image_features = self.feature_select(image_features).to(self.dtype) 116 | 117 | return 
image_features 118 | 119 | @property 120 | def dummy_feature(self): 121 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 122 | 123 | @property 124 | def hidden_size(self): 125 | return self.model_config["vision_cfg"]["width"] 126 | 127 | @property 128 | def num_patches(self): 129 | return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2 130 | 131 | @property 132 | def num_patches_per_side(self): 133 | return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] 134 | 135 | @property 136 | def config(self): 137 | return self.model_config 138 | 139 | @property 140 | def image_size(self): 141 | return self.model_config["vision_cfg"]["image_size"] 142 | --------------------------------------------------------------------------------
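
A minimal usage sketch of the `EvaViTWrapper` defined above, shown for illustration only; it is not a file in this repository. The config name, the option values, and the `SimpleNamespace` argument bundle are assumptions — the wrapper itself only reads the attributes listed below, and with `delay_load=True` it resolves the model config without loading any weights:

from types import SimpleNamespace
from llava.model.multimodal_encoder.dev_eva_clip.eva_vit import EvaViTWrapper

# Hypothetical argument bundle; the attribute names mirror what EvaViTWrapper reads in __init__.
vision_args = SimpleNamespace(
    vision_tower_pretrained=None,      # path to EVA-CLIP weights would go here (assumed)
    mm_vision_select_layer=-2,         # commonly used value (assumption)
    mm_vision_select_feature="patch",  # keep patch tokens, drop the CLS token
)

# "EVA02-CLIP-L-14-336" matches one of the JSONs shipped under eva_clip/model_configs/.
tower = EvaViTWrapper("EVA02-CLIP-L-14-336", vision_args, delay_load=True)

# Only the config is resolved, so the shape properties are available immediately;
# e.g. a 336px, patch-14 config gives 336 // 14 = 24 patches per side (576 total).
print(tower.hidden_size, tower.num_patches_per_side, tower.num_patches)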