├── results └── inference_results.json ├── checkpoints └── README.md ├── llava ├── __init__.py ├── train │ ├── train_mem.py │ ├── train_mem_block.py │ ├── train_xformers.py │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ └── llava_trainer.py ├── model │ ├── __init__.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── utils.py │ ├── consolidate.py │ ├── multimodal_projector │ │ └── builder.py │ ├── apply_delta.py │ ├── make_delta.py │ ├── language_model │ │ ├── llava_mpt.py │ │ ├── llava_mistral.py │ │ └── llava_llama.py │ └── builder.py ├── constants.py ├── utils.py ├── serve │ ├── batch_inference_block.py │ └── cli_final.py └── mm_utils.py ├── image └── README │ └── 1713757322703.png ├── .gitignore ├── scripts ├── zero2.json ├── zero2_offload.json ├── inference.sh ├── zero3.json ├── finetune_task_sft.sh ├── pretrain.sh ├── finetune.sh ├── finetune_block_bigsmall.sh ├── finetune_lora.sh ├── zero3_offload.json └── finetune_task_lora.sh ├── data_preprocess ├── check_image.py ├── prepare_data_test.sh ├── prepare_data_train.sh ├── prepare_data.sh ├── extract_bdd_frame_bbox_anno.py ├── extract_bdd_test_frame_bbox_anno.py ├── add_stage_prompt.py ├── chose_best_view_test.py ├── best_view_selection.py ├── filiter_data_by_area.py ├── extract_wts_test_frame_bbox_anno.py ├── extract_wts_frame_bbox_anno.py ├── generate_test_frames.py ├── shortQA_split.py ├── draw_bbox_on_frame.py ├── transform_llava_format.py └── shortQA_merge.py ├── pyproject.toml ├── block_expansion_llava_1_6.py ├── README.md └── LICENSE /results/inference_results.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | please put the finetuned model here -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /image/README/1713757322703.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/AICITY2024_Track2_AliOpenTrek_CityLLaVA/HEAD/image/README/1713757322703.png -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | from llava.train.train import train 2 | 3 | if __name__ == "__main__": 4 | train(attn_implementation="flash_attention_2") 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.so 3 | build 4 | .coverage_* 5 | *.egg-info 6 | *~ 7 | .vscode/ 8 | .idea/ 9 | 10 | image/.DS_Store 11 | llava/model/.DS_Store 12 | .DS_Store 13 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import 
LlavaMistralForCausalLM, LlavaMistralConfig 5 | except: 6 | pass 7 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" 14 | -------------------------------------------------------------------------------- /llava/train/train_mem_block.py: -------------------------------------------------------------------------------- 1 | from llava.train.train_block import train 2 | 3 | if __name__ == "__main__": 4 | train(attn_implementation="flash_attention_2") 5 | 6 | # from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | # replace_llama_attn_with_flash_attn() 9 | 10 | # from llava.train.train_block import train 11 | 12 | # if __name__ == "__main__": 13 | # train() 14 | 15 | -------------------------------------------------------------------------------- /llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 4 | from llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | 
"loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "offload_optimizer": { 19 | "device": "cpu" 20 | }, 21 | "overlap_comm": true, 22 | "contiguous_gradients": true, 23 | "sub_group_size": 1e9, 24 | "reduce_bucket_size": "auto" 25 | } 26 | } -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | work_dir="./llava/serve" 4 | 5 | cd $work_dir 6 | 7 | # batch_inference_block.py & cli_final.py 8 | 9 | DATA_PATH="../../data_preprocess/data/generate_test_frames/bbox_global" 10 | LOCAL_IMAGE_DATA_PATH="../../data_preprocess/data/generate_test_frames/bbox_local" 11 | FINETUNE_MODEL="../../checkpoints/llava1_6-34b-aicity-block-single-round-bigsmall-0325" # Download the model and put it into 'checkpoints' dir. 12 | SAVE_PATH="../../results/inference_result.json" # You can change the other directory. 13 | NUM_POOL=1 # equal to your the number of GPU 14 | BEST_VIEW_MAP="../../data_preprocess/processed_anno/perspective_test_images.json" 15 | 16 | python batch_inference_block.py \ 17 | --data-path $DATA_PATH \ 18 | --local-image-data-path $LOCAL_IMAGE_DATA_PATH \ 19 | --finetune-model $FINETUNE_MODEL \ 20 | --save-path $SAVE_PATH \ 21 | --num-pool $NUM_POOL \ 22 | --best-view-map $BEST_VIEW_MAP -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_clipping": "auto", 16 | "gradient_accumulation_steps": "auto", 17 | "zero_optimization": { 18 | "stage": 3, 19 | "overlap_comm": true, 20 | "contiguous_gradients": true, 21 | "sub_group_size": 1e9, 22 | "reduce_bucket_size": "auto", 23 | "stage3_prefetch_bucket_size": "auto", 24 | "stage3_param_persistence_threshold": "auto", 25 | "stage3_max_live_parameters": 1e9, 26 | "stage3_max_reuse_distance": 1e9, 27 | "stage3_gather_16bit_weights_on_model_save": true 28 | } 29 | } -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /data_preprocess/check_image.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | from tqdm import tqdm 3 | 4 | image_path = './data' 5 | data_path = './processed_anno/llava_format/wts_bdd_llava_qa_train_stage_filted.json' 6 | save_path = './processed_anno/llava_format/wts_bdd_llava_qa_train_stage_filted_checked.json' 7 | save_miss_path = './processed_anno/llava_format/wts_bdd_llava_qa_train_stage_filted_miss.json' 8 | with open(data_path, 'r') as f: 9 | data_json = json.load(f) 10 | print(f'num:{len(data_json)}') 11 | miss_data = [] 12 | new_data = [] 13 | for data in tqdm(data_json): 14 | if 'image' in data.keys(): 15 | sample_path = os.path.join(image_path, data['image']) 16 | if not os.path.exists(sample_path): 17 | print(sample_path) 18 | miss_data.append(sample_path) 19 | else: 20 | new_data.append(data) 21 | 22 | print(f'{len(data_json)} vs {len(new_data)}') 23 | 24 | with open(save_path, 'w') as f: 25 | f.write(json.dumps(new_data, indent=2, ensure_ascii=False)) 26 | 27 | with open(save_miss_path, 'w') as f1: 28 | f1.write(json.dumps(miss_data, indent=2, ensure_ascii=False)) 29 | 30 | 31 | -------------------------------------------------------------------------------- /data_preprocess/prepare_data_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | num_worker=32 4 | test_root="./data/test_part" # Store original test data 5 | generate_test_frames_path="./data/generate_test_frames" # extract frames for evaluation 6 | save_folder="./processed_anno" # Store json files 7 | scale=1.5 8 | 9 | python extract_wts_test_frame_bbox_anno.py --root $test_root --save-folder $save_folder/frame_bbox_anno 10 | python extract_bdd_test_frame_bbox_anno.py --root $test_root --save-folder $save_folder/frame_bbox_anno 11 | 12 | for file in 
"$save_folder/frame_bbox_anno"/*test*; do 13 | python draw_bbox_on_frame.py --worker $num_worker --anno $file --scale $scale 14 | done 15 | 16 | python best_view_selection.py \ 17 | --test-root $test_root \ 18 | --save-path $save_folder/best_view_for_test.json \ 19 | 20 | python generate_test_frames.py \ 21 | --root $generate_test_frames_path \ 22 | --best-view-anno $save_folder/best_view_for_test.json \ 23 | --bdd-test-folder $test_root/WTS_DATASET_PUBLIC_TEST/external/BDD_PC_5K/bbox_global/test/public \ 24 | --wts-test-folder $test_root/WTS_DATASET_PUBLIC_TEST/bbox_global/test/public \ 25 | --save-folder $save_folder 26 | 27 | echo " Testsets prepared." -------------------------------------------------------------------------------- /scripts/finetune_task_sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --deepspeed ./scripts/zero3.json \ 5 | --model_name_or_path ./models/llava-v1.6-34b \ 6 | --version chatml_direct \ 7 | --data_path ./wts_bdd_llava_new_stage_noval_filter_check_0325.json \ 8 | --image_folder ./dataset \ 9 | --vision_tower ./models/clip-vit-large-patch14-336 \ 10 | --mm_projector_type mlp2x_gelu \ 11 | --mm_vision_select_layer -2 \ 12 | --mm_use_im_start_end False \ 13 | --mm_use_im_patch_token False \ 14 | --image_aspect_ratio pad \ 15 | --group_by_modality_length True \ 16 | --bf16 True \ 17 | --output_dir ./work_dirs/llava1_6-34b-aicity-sft \ 18 | --num_train_epochs 2 \ 19 | --per_device_train_batch_size 2 \ 20 | --per_device_eval_batch_size 4 \ 21 | --gradient_accumulation_steps 8 \ 22 | --evaluation_strategy "no" \ 23 | --save_strategy "epoch" \ 24 | --save_total_limit 6 \ 25 | --learning_rate 2e-5 \ 26 | --weight_decay 0. \ 27 | --warmup_ratio 0.03 \ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 8 \ 34 | --lazy_preprocess True \ 35 | --report_to tensorboard 36 | -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --deepspeed ./scripts/zero2.json \ 5 | --model_name_or_path lmsys/vicuna-13b-v1.5 \ 6 | --version plain \ 7 | --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ 8 | --image_folder ./playground/data/LLaVA-Pretrain/images \ 9 | --vision_tower openai/clip-vit-large-patch14-336 \ 10 | --mm_projector_type mlp2x_gelu \ 11 | --tune_mm_mlp_adapter True \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --bf16 True \ 16 | --output_dir ./checkpoints/llava-v1.5-13b-pretrain \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 32 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --evaluation_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 24000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 0. 
\ 27 | --warmup_ratio 0.03 \ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 4 \ 34 | --lazy_preprocess True \ 35 | --report_to wandb 36 | -------------------------------------------------------------------------------- /scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --deepspeed ./scripts/zero3.json \ 5 | --model_name_or_path lmsys/vicuna-13b-v1.5 \ 6 | --version v1 \ 7 | --data_path ./playground/data/llava_v1_5_mix665k.json \ 8 | --image_folder ./playground/data \ 9 | --vision_tower openai/clip-vit-large-patch14-336 \ 10 | --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --image_aspect_ratio pad \ 16 | --group_by_modality_length True \ 17 | --bf16 True \ 18 | --output_dir ./checkpoints/llava-v1.5-13b \ 19 | --num_train_epochs 1 \ 20 | --per_device_train_batch_size 16 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 1 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 50000 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0. \ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True \ 37 | --report_to wandb 38 | -------------------------------------------------------------------------------- /scripts/finetune_block_bigsmall.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed --include localhost:0,1,2,3,4,5,6,7 llava/train/train_mem_block.py \ 4 | --deepspeed ./scripts/zero3.json \ 5 | --zero_stage 3 \ 6 | --model_name_or_path ./models/llava-v1.6-34b-12block \ 7 | --version chatml_direct \ 8 | --data_path './data_preprocess/processed_anno/llava_format/wts_bdd_llava_qa_train_stage_filted_checked.json' \ 9 | --image_folder ./data_preprocess/data \ 10 | --mm_vision_select_layer -2 \ 11 | --mm_use_im_start_end False \ 12 | --mm_use_im_patch_token False \ 13 | --image_aspect_ratio bigsmall \ 14 | --mm_patch_merge_type flat \ 15 | --group_by_modality_length True \ 16 | --bf16 True \ 17 | --output_dir ./checkpoints/llava1_6-34b-aicity-block-bigsmall \ 18 | --num_train_epochs 1 \ 19 | --per_device_train_batch_size 4 \ 20 | --per_device_eval_batch_size 4 \ 21 | --gradient_accumulation_steps 2 \ 22 | --evaluation_strategy "no" \ 23 | --save_strategy "epoch" \ 24 | --save_total_limit 2 \ 25 | --learning_rate 2e-4 \ 26 | --weight_decay 0. 
\ 27 | --warmup_ratio 0.03 \ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 8 \ 34 | --lazy_preprocess True \ 35 | --report_to tensorboard 36 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava" 7 | version = "1.2.2.post1" 8 | description = "Towards GPT-4 like large language and visual assistant." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.0.1", "torchvision==0.15.2", 17 | "transformers==4.37.0", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.21.0", "peft", "bitsandbytes", 19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 20 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 21 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 22 | "deepspeed==0.12.6", "ninja", "tensorboardX", 23 | "tqdm", "dashscope", "openai", "opencv-python", "decord" 24 | ] 25 | 26 | [project.urls] 27 | "Homepage" = "https://github.com/qingchunlizhi/AICITY2024_Track2_AliOpenTrek_CityLLaVA" 28 | "Bug Tracker" = "https://github.com/qingchunlizhi/AICITY2024_Track2_AliOpenTrek_CityLLaVA/issues" 29 | 30 | [tool.setuptools.packages.find] 31 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 32 | 33 | [tool.wheel] 34 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 35 | -------------------------------------------------------------------------------- /data_preprocess/prepare_data_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | num_worker=32 4 | root="./data" 5 | save_folder="./processed_anno" # Store json files 6 | splits=("train" "val") 7 | scale=1.5 8 | 9 | for split in "${splits[@]}"; do 10 | python extract_wts_frame_bbox_anno.py --root $root --save-folder $save_folder/frame_bbox_anno --split $split 11 | python extract_bdd_frame_bbox_anno.py --root $root --save-folder $save_folder/frame_bbox_anno --split $split 12 | done 13 | 14 | for file in "$save_folder/frame_bbox_anno"/*train*; do 15 | python draw_bbox_on_frame.py --worker $num_worker --anno $file --scale $scale 16 | done 17 | 18 | for file in "$save_folder/frame_bbox_anno"/*val*; do 19 | python draw_bbox_on_frame.py --worker $num_worker --anno $file --scale $scale 20 | done 21 | 22 | for split in "${splits[@]}"; do 23 | python transform_llava_format.py \ 24 | --root $root \ 25 | --save-folder $save_folder/llava_format \ 26 | --split $split \ 27 | --wts-global-image-path $root/WTS/bbox_global \ 28 | --bdd-global-image-path $root/BDD_PC_5k/bbox_global 29 | done 30 | 31 | # generate shortQA 32 | API_KEY="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx" 33 | MODEL="Qwen" 34 | 35 | python shortQA_split.py --model $MODEL --api-key $API_KEY 36 | python shortQA_merge.py 37 | 38 | # data filter 39 | python add_stage_prompt.py 40 | python filiter_data_by_area.py 41 | python check_image.py 42 | 43 | echo " Trainsets prepared." 
-------------------------------------------------------------------------------- /scripts/finetune_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 5 | --deepspeed ./scripts/zero3.json \ 6 | --model_name_or_path lmsys/vicuna-13b-v1.5 \ 7 | --version v1 \ 8 | --data_path ./playground/data/llava_v1_5_mix665k.json \ 9 | --image_folder ./playground/data \ 10 | --vision_tower openai/clip-vit-large-patch14-336 \ 11 | --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \ 12 | --mm_projector_type mlp2x_gelu \ 13 | --mm_vision_select_layer -2 \ 14 | --mm_use_im_start_end False \ 15 | --mm_use_im_patch_token False \ 16 | --image_aspect_ratio pad \ 17 | --group_by_modality_length True \ 18 | --bf16 True \ 19 | --output_dir ./checkpoints/llava-v1.5-13b-lora \ 20 | --num_train_epochs 1 \ 21 | --per_device_train_batch_size 16 \ 22 | --per_device_eval_batch_size 4 \ 23 | --gradient_accumulation_steps 1 \ 24 | --evaluation_strategy "no" \ 25 | --save_strategy "steps" \ 26 | --save_steps 50000 \ 27 | --save_total_limit 1 \ 28 | --learning_rate 2e-4 \ 29 | --weight_decay 0. \ 30 | --warmup_ratio 0.03 \ 31 | --lr_scheduler_type "cosine" \ 32 | --logging_steps 1 \ 33 | --tf32 True \ 34 | --model_max_length 2048 \ 35 | --gradient_checkpointing True \ 36 | --dataloader_num_workers 4 \ 37 | --lazy_preprocess True \ 38 | --report_to wandb 39 | -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- /scripts/finetune_task_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --lora_enable True --lora_r 256 --lora_alpha 512 \ 5 | --mm_projector_lr 2e-5 \ 6 | --deepspeed ./scripts/zero3.json \ 7 | --model_name_or_path /mnt/workspace/workgroup/chengxiang/models/llava-v1.6-34b \ 8 | 
--version chatml_direct \ 9 | --data_path /mnt/workspace/workgroup/chengxiang/datasets/aicity2024/process/mixed_wts_bdd_trainval.json \ 10 | --image_folder /mnt/workspace/workgroup/chenghao/video_analysis/dataset \ 11 | --vision_tower /mnt/workspace/workgroup/chengxiang/models/clip-vit-large-patch14-336 \ 12 | --mm_projector_type mlp2x_gelu \ 13 | --mm_vision_select_layer -2 \ 14 | --mm_use_im_start_end False \ 15 | --mm_use_im_patch_token False \ 16 | --image_aspect_ratio pad \ 17 | --group_by_modality_length True \ 18 | --bf16 True \ 19 | --output_dir /mnt/workspace/workgroup/chengxiang/work_dirs/llava1_6-34b-aicity-lora256-pad-multiround-0321 \ 20 | --num_train_epochs 2 \ 21 | --per_device_train_batch_size 8 \ 22 | --per_device_eval_batch_size 4 \ 23 | --gradient_accumulation_steps 2 \ 24 | --evaluation_strategy "no" \ 25 | --save_strategy "epoch" \ 26 | --save_total_limit 6 \ 27 | --learning_rate 2e-4 \ 28 | --weight_decay 0. \ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 4096 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True \ 37 | --report_to tensorboard 38 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /data_preprocess/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | num_worker=32 4 | root="./data" 5 | test_root="./data/test_part" # Store original test data 6 | generate_test_frames_path="./data/generate_test_frames" 7 | save_folder="./processed_anno" # Store json files 8 | splits=("train" "val") 9 | scale=1.5 10 | 11 | for split in "${splits[@]}"; do 12 | python extract_wts_frame_bbox_anno.py --root $root --save-folder $save_folder/frame_bbox_anno --split $split 13 | python extract_bdd_frame_bbox_anno.py --root $root 
--save-folder $save_folder/frame_bbox_anno --split $split 14 | done 15 | 16 | python extract_wts_test_frame_bbox_anno.py --root $test_root --save-folder $save_folder/frame_bbox_anno 17 | python extract_bdd_test_frame_bbox_anno.py --root $test_root --save-folder $save_folder/frame_bbox_anno 18 | 19 | for file in "$save_folder/frame_bbox_anno"/*; do 20 | python draw_bbox_on_frame.py --worker $num_worker --anno $file --scale $scale 21 | done 22 | 23 | for split in "${splits[@]}"; do 24 | python transform_llava_format.py \ 25 | --root $root \ 26 | --save-folder $save_folder/llava_format \ 27 | --split $split \ 28 | --wts-global-image-path ./data/WTS/bbox_global \ 29 | --bdd-global-image-path ./data/BDD_PC_5k/bbox_global 30 | done 31 | 32 | python best_view_selection.py \ 33 | --test-root $test_root \ 34 | --save-path $save_folder/best_view_for_test.json \ 35 | 36 | python generate_test_frames.py \ 37 | --root $generate_test_frames_path \ 38 | --best-view-anno $save_folder/best_view_for_test.json \ 39 | --bdd-test-folder ./data/test_part/WTS_DATASET_PUBLIC_TEST/external/BDD_PC_5K/bbox_global/test/public/ \ 40 | --wts-test-folder ./data/test_part/WTS_DATASET_PUBLIC_TEST/bbox_global/test/public/ \ 41 | --save-folder $save_folder 42 | 43 | # generate shortQA 44 | API_KEY="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxx" 45 | MODEL="Qwen" 46 | 47 | python shortQA_split.py --model $MODEL --api-key $API_KEY 48 | python shortQA_merge.py 49 | 50 | # data filter 51 | python add_stage_prompt.py 52 | python filiter_data_by_area.py 53 | python check_image.py 54 | 55 | echo "Done." -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", 
type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /data_preprocess/extract_bdd_frame_bbox_anno.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--root', type=str, default='/mnt/data/AICITY2024/', help='data root path') 8 | parser.add_argument('--split', type=str, default='train') 9 | parser.add_argument('--save-folder', type=str, default='processed_anno', help='dirname for saving json file') 10 | 11 | args = parser.parse_args() 12 | 13 | video_root_path = os.path.join(args.root, 'BDD_PC_5k/videos', args.split) 14 | annotation_path = os.path.join(args.root, 'BDD_PC_5k/annotations/caption', args.split) 15 | bbox_path = 
os.path.join(args.root, 'BDD_PC_5k/annotations/bbox_annotated', args.split) 16 | 17 | video_with_bbox_results = dict() 18 | 19 | for item in tqdm(os.listdir(video_root_path)): 20 | video_path = os.path.join(video_root_path, item) 21 | camera_base = item.replace('.mp4', '') 22 | overhead_caption_anno_path = os.path.join(annotation_path, f'{camera_base}_caption.json') 23 | 24 | # vehicle bbox extraction 25 | assert os.path.exists(overhead_caption_anno_path) 26 | vehicle_annotation = json.load(open(overhead_caption_anno_path))['event_phase'] 27 | fps = json.load(open(overhead_caption_anno_path))['fps'] 28 | start_time, end_time = None, None 29 | for phase in vehicle_annotation: 30 | if not start_time: 31 | start_time = float(phase['start_time']) 32 | else: 33 | start_time = min(float(phase['start_time']), start_time) 34 | if not end_time: 35 | end_time = float(phase['end_time']) 36 | else: 37 | end_time = max(float(phase['end_time']), end_time) 38 | 39 | video_with_bbox_results[video_path] = dict(fps=fps, start_time=start_time, end_time=end_time, ped_bboxes=dict(), veh_bboxes=dict(), phase_number=dict()) 40 | pedestrian_bbox_anno_path = os.path.join(bbox_path, f'{camera_base}_bbox.json') 41 | 42 | if os.path.exists(pedestrian_bbox_anno_path): 43 | pedestrian_bbox = json.load(open(pedestrian_bbox_anno_path))['annotations'] 44 | for bbox in pedestrian_bbox: 45 | video_with_bbox_results[video_path]['ped_bboxes'][bbox['image_id']] = bbox['bbox'] 46 | video_with_bbox_results[video_path]['phase_number'][bbox['image_id']] = bbox['phase_number'] 47 | 48 | os.makedirs(args.save_folder, exist_ok=True) 49 | with open(os.path.join(args.save_folder, f'bdd_{args.split}_all_video_with_bbox_anno_first_frame.json'), 'w') as f: 50 | f.write(json.dumps(video_with_bbox_results, indent=4)) -------------------------------------------------------------------------------- /block_expansion_llava_1_6.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from llava.model.builder import load_pretrained_model 8 | from llava.mm_utils import get_model_name_from_path 9 | 10 | 11 | def main(): 12 | # Set up the argument parser 13 | parser = argparse.ArgumentParser(description="Receive deepen model's args") 14 | parser.add_argument("--model_path", default='./models/llava-v1.6-34b', type=str, help="original model path") 15 | parser.add_argument("--output_path", default='./models/llava-v1.6-34b-12block', type=str, help="deepened model ckpt save path") 16 | parser.add_argument("--original_layers", default=60, type=int, help="original model num layers") 17 | parser.add_argument("--layers", default=72, type=int, help="deepen model num layers") 18 | 19 | # Parse the arguments 20 | args = parser.parse_args() 21 | cache_dir = 'cache_dir' 22 | device = 'cuda' 23 | load_4bit, load_8bit = False, False 24 | model_name = get_model_name_from_path(args.model_path) 25 | tokenizer, model, processor, _ = load_pretrained_model(args.model_path, None, model_name, load_8bit, load_4bit, device=device, cache_dir=cache_dir) 26 | ckpt = model.state_dict() 27 | 28 | split = int(args.original_layers / (args.layers - args.original_layers)) 29 | layer_cnt = 0 30 | 31 | output = {} 32 | for i in tqdm(range(args.original_layers)): 33 | for k in ckpt: 34 | if ('layers.' + str(i) + '.') in k: 35 | output[k.replace(('layers.' + str(i) + '.'), ('layers.' 
+ str(layer_cnt) + '.'))] = ckpt[k] 36 | layer_cnt += 1 37 | if (i+1) % split == 0: 38 | for k in ckpt: 39 | if ('layers.' + str(i) + '.') in k: 40 | if 'down_proj' in k or 'o_proj' in k: 41 | output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = torch.zeros_like(ckpt[k]) 42 | else: 43 | output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k] 44 | 45 | 46 | layer_cnt += 1 47 | 48 | assert layer_cnt==args.layers 49 | add_layer = [(split+1)*i+split for i in range(0, args.layers-args.original_layers)] 50 | print(add_layer) 51 | for k in ckpt: 52 | if not 'layers' in k: 53 | output[k] = ckpt[k] 54 | if not os.path.exists(args.output_path): 55 | os.makedirs(args.output_path) 56 | torch.save(output, args.output_path + '/pytorch_model.bin') 57 | 58 | if __name__ == "__main__": 59 | main() -------------------------------------------------------------------------------- /data_preprocess/extract_bdd_test_frame_bbox_anno.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--root', type=str, default='/mnt/data/AICITY2024/', help='data root path') 8 | parser.add_argument('--save-folder', type=str, default='processed_anno', help='dirname for saving json file') 9 | 10 | args = parser.parse_args() 11 | 12 | video_root_path = os.path.join(args.root, 'WTS_DATASET_PUBLIC_TEST/external/BDD_PC_5K/videos/test/public') 13 | annotation_path = os.path.join(args.root, 'WTS_DATASET_PUBLIC_TEST/external/BDD_PC_5K/annotations/caption/test/public_challenge') 14 | bbox_path = os.path.join(args.root, 'WTS_DATASET_PUBLIC_TEST_BBOX/external/BDD_PC_5K/annotations/bbox_annotated/test/public') 15 | 16 | video_with_bbox_results = dict() 17 | 18 | for item in tqdm(os.listdir(video_root_path)): 19 | video_path = os.path.join(video_root_path, item) 20 | camera_base = item.replace('.mp4', '') 21 | overhead_caption_anno_path = os.path.join(annotation_path, f'{camera_base}_caption.json') 22 | 23 | # vehicle bbox extraction 24 | assert os.path.exists(overhead_caption_anno_path) 25 | vehicle_annotation = json.load(open(overhead_caption_anno_path))['event_phase'] 26 | fps = json.load(open(overhead_caption_anno_path))['fps'] 27 | start_time, end_time = None, None 28 | for phase in vehicle_annotation: 29 | if not start_time: 30 | start_time = float(phase['start_time']) 31 | else: 32 | start_time = min(float(phase['start_time']), start_time) 33 | if not end_time: 34 | end_time = float(phase['end_time']) 35 | else: 36 | end_time = max(float(phase['end_time']), end_time) 37 | 38 | 39 | video_with_bbox_results[video_path] = dict(fps=fps, start_time=start_time, end_time=end_time, ped_bboxes=dict(), veh_bboxes=dict(), phase_number=dict()) 40 | pedestrian_bbox_anno_path = os.path.join(bbox_path, f'{camera_base}_bbox.json') 41 | 42 | if os.path.exists(pedestrian_bbox_anno_path): 43 | pedestrian_bbox = json.load(open(pedestrian_bbox_anno_path))['annotations'] 44 | for bbox in pedestrian_bbox: 45 | video_with_bbox_results[video_path]['ped_bboxes'][bbox['image_id']] = bbox['bbox'] 46 | video_with_bbox_results[video_path]['phase_number'][bbox['image_id']] = bbox['phase_number'] 47 | 48 | os.makedirs(args.save_folder, exist_ok=True) 49 | with open(os.path.join(args.save_folder, 'bdd_test_all_video_with_bbox_anno_first_frame.json'), 'w') as f: 50 | f.write(json.dumps(video_with_bbox_results, indent=4)) 
-------------------------------------------------------------------------------- /data_preprocess/add_stage_prompt.py: -------------------------------------------------------------------------------- 1 | import os, json, glob 2 | from tqdm import tqdm 3 | import random 4 | import numpy as np 5 | 6 | def random_shuffle_conversations(conversations): 7 | question_nums = len(conversations) 8 | assert question_nums % 2 == 0, 'Pairs incomplete' 9 | 10 | indices = np.arange(question_nums).reshape(-1, 2).tolist() 11 | random.shuffle(indices) 12 | indices = np.asarray(indices).reshape(-1).tolist() 13 | shuffled_conversations = list() 14 | for ind in indices: 15 | shuffled_conversations.append(conversations[ind]) 16 | return shuffled_conversations 17 | 18 | 19 | if __name__ == '__main__': 20 | labels_path = './processed_anno/llava_format/wts_bdd_llava_qa_train.json' 21 | save_path = './processed_anno/llava_format/wts_bdd_llava_qa_train_stage.json' 22 | 23 | 24 | # wts_prompt = "<image>\nThis is an image in '{viewpoint}' stage. Pay attention to the pedestrian in the green bounding box and the vehicle in the blue bounding box. Note that the bounding box may not exist, then answer the following questions:\n{question}" 25 | # bdd_prompt = "<image>\nThis is an image in '{viewpoint}' stage. Pay attention to the pedestrian in the green bounding box and the ego-vehicle. Note that the bounding box may not exist, then answer the following questions:\n{question}" 26 | 27 | wts_prompt = "<image>\nThis is an image in '{viewpoint}' stage. {question}" 28 | bdd_prompt = "<image>\nThis is an image in '{viewpoint}' stage. {question}" 29 | 30 | 31 | data_json = json.load(open(labels_path)) 32 | for data in tqdm(data_json): 33 | data['conversations'] = random_shuffle_conversations(data['conversations']) 34 | viewpoint = data['image'].split('/')[-1].split('_')[-1].split('.')[0] 35 | if 'BDD_PC_5k' in data['image'] or 'vehicle_view' in data['image']: 36 | prompt = bdd_prompt 37 | else: 38 | prompt = wts_prompt 39 | for i, cc in enumerate(data['conversations']): 40 | if cc['from'] == 'human': 41 | cc['value'] = cc['value'].replace('<image>\n', '').replace('the green box', 'the green bounding box').replace('the blue box', 'the blue bounding box') 42 | if 'BDD_PC_5k' in data['image'] or 'vehicle_view' in data['image']: 43 | cc['value'] = cc['value'].replace('the vehicle with the blue bounding box', 'the ego-vehicle').replace('the vehicle with the blue box', 'the ego-vehicle') 44 | if i == 0: 45 | cc['value'] = prompt.format(viewpoint = viewpoint, question = cc['value']) 46 | 47 | print(len(data_json)) 48 | with open(save_path, 'w') as f: 49 | f.write(json.dumps(data_json, indent=2, ensure_ascii=False)) -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 20 | self.load_model() 21 | else: 22 | self.cfg_only = 
CLIPVisionConfig.from_pretrained(self.vision_tower_name) 23 | 24 | def load_model(self, device_map=None): 25 | if self.is_loaded: 26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 27 | return 28 | 29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 30 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 31 | self.vision_tower.requires_grad_(False) 32 | 33 | self.is_loaded = True 34 | 35 | def feature_select(self, image_forward_outs): 36 | image_features = image_forward_outs.hidden_states[self.select_layer] 37 | if self.select_feature == 'patch': 38 | image_features = image_features[:, 1:] 39 | elif self.select_feature == 'cls_patch': 40 | image_features = image_features 41 | else: 42 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 43 | return image_features 44 | 45 | @torch.no_grad() 46 | def forward(self, images): 47 | if type(images) is list: 48 | image_features = [] 49 | for image in images: 50 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 51 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 52 | image_features.append(image_feature) 53 | else: 54 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 55 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 56 | 57 | return image_features 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.device 70 | 71 | @property 72 | def config(self): 73 | if self.is_loaded: 74 | return self.vision_tower.config 75 | else: 76 | return self.cfg_only 77 | 78 | @property 79 | def hidden_size(self): 80 | return self.config.hidden_size 81 | 82 | @property 83 | def num_patches_per_side(self): 84 | return self.config.image_size // self.config.patch_size 85 | 86 | @property 87 | def num_patches(self): 88 | return (self.config.image_size // self.config.patch_size) ** 2 89 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, \ 21 | MptConfig, MptForCausalLM, MptModel 22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 23 | 24 | 25 | class LlavaMptConfig(MptConfig): 26 | model_type = "llava_mpt" 27 | 28 | 29 | class LlavaMptModel(LlavaMetaModel, MptModel): 30 | config_class = LlavaMptConfig 31 | 32 | def __init__(self, config: MptConfig): 33 | config.hidden_size = config.d_model 34 | super(LlavaMptModel, self).__init__(config) 35 | 36 | def embed_tokens(self, x): 37 | return self.wte(x) 38 | 39 | 40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaMptConfig 42 | supports_gradient_checkpointing = True 43 | 44 | def __init__(self, config): 45 | super(MptForCausalLM, self).__init__(config) 46 | 47 | self.transformer = LlavaMptModel(config) 48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.transformer 55 | 56 | def _set_gradient_checkpointing(self, module, value=False): 57 | if isinstance(module, LlavaMptModel): 58 | module.gradient_checkpointing = value 59 | 60 | def forward( 61 | self, 62 | input_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | inputs_embeds: Optional[torch.Tensor] = None, 66 | labels: Optional[torch.Tensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | images=None): 72 | 73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 74 | 75 | return super().forward( 76 | input_ids, 77 | past_key_values=past_key_values, 78 | attention_mask=attention_mask, 79 | inputs_embeds=inputs_embeds, 80 | labels=labels, 81 | use_cache=use_cache, 82 | output_attentions=output_attentions, 83 | output_hidden_states=output_hidden_states, 84 | return_dict=return_dict, 85 | ) 86 | 87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 88 | images = kwargs.pop("images", None) 89 | _inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 91 | ) 92 | _inputs['images'] = images 93 | return _inputs 94 | 95 | 96 | AutoConfig.register("llava_mpt", LlavaMptConfig) 97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 98 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /data_preprocess/chose_best_view_test.py: -------------------------------------------------------------------------------- 1 | import os, json, csv, glob 2 | from tqdm import tqdm 3 | from collections import defaultdict 4 | 5 | 6 | def get_best_view_wts(ann_path, bbox_path, scnearios, reference_views): 7 | best_view_video = {} 8 | for scneario in tqdm(scnearios): 9 | if '.DS_Store' in scneario: 10 | continue 11 | if '_normal_' in scneario: 12 | if os.path.exists(os.path.join(bbox_path, f'normal_trimmed/{scneario}/overhead_view')) or not os.path.exists(os.path.join(bbox_path, f'normal_trimmed/{scneario}/vehicle_view')): 13 | best_view_video[scneario] = scneario + '.mp4' 14 | else: 15 | best_view_video[scneario] = scneario +'_vehicle_view.mp4' 16 | else: 17 | if scneario == '20231006_18_CN29_T1': 18 | print('f') 19 | try: 20 | overhead_view_json = json.load(open(glob.glob(os.path.join(ann_path, f'{scneario}/overhead_view/*.json'))[0])) 21 | except: 22 | overhead_view_json = None 23 | 24 | views = [] 25 | for overhand in overhead_view_json['overhead_videos']: 26 | if scneario in reference_views.keys(): 27 | if overhand in reference_views[scneario]: 28 | views.append(overhand) 29 | else: 30 | print(f'no reference view: {scneario}') 31 | views.append(overhand) 32 | best_view_score = 0 33 | best_view = None 34 | for view in views: 35 | if os.path.exists(os.path.join(bbox_path, f"{scneario}/overhead_view/{view.replace('.mp4', '')}_bbox.json")): 36 | bbox = json.load(open(os.path.join(bbox_path, f"{scneario}/overhead_view/{view.replace('.mp4', '')}_bbox.json"))) 37 | elif os.path.exists(os.path.join(bbox_path, f"{scneario}/vehicle_view/{scneario}_vehicle_view_bbox.json")): 38 | bbox = json.load(open(os.path.join(bbox_path, f"{scneario}/vehicle_view/{scneario}_vehicle_view_bbox.json"))) 39 | else: 40 | print(f'no bbox: {scneario}') 41 | continue 42 | 43 | if len(bbox["annotations"]) == 5: 44 | avg_human_area = sum([box['bbox'][2]*box['bbox'][3] for box in bbox["annotations"]])/5. 
45 | if avg_human_area > best_view_score: 46 | best_view_score = avg_human_area 47 | best_view = view 48 | 49 | if best_view == None and os.path.exists(os.path.join(bbox_path, f"{scneario}/vehicle_view/{scneario}_vehicle_view_bbox.json")): 50 | best_view = scneario +'_vehicle_view.mp4' 51 | else: 52 | avg_human_area = sum([box['bbox'][2]*box['bbox'][3] for box in bbox["annotations"]])/len(bbox["annotations"]) 53 | if avg_human_area > best_view_score: 54 | best_view_score = avg_human_area 55 | best_view = view 56 | 57 | # We found that the bounding boxes of 20230728_13_CN21_T1_Camera2_5.mp4 and 20230728_13_CN21_T2_Camera2_5 is incorrect 58 | if scneario == '20230728_13_CN21_T1' or scneario == '20230728_13_CN21_T2': 59 | best_view=scneario +'_vehicle_view.mp4' 60 | 61 | best_view_video[scneario] = best_view 62 | 63 | return best_view_video 64 | 65 | 66 | 67 | if __name__ == '__main__': 68 | wts_ann_path = './data/WTS/annotations/caption/test/public_challenge' 69 | wts_bbox_path = './data/WTS/annotations/bbox_annotated/pedestrian/test/public' 70 | bdd_video_path = './data/BDD_PC_5K/videos/test' 71 | reference_view_path = './data/test_part/view_used_as_main_reference_for_multiview_scenario.csv' 72 | save_path = './data/test_part/best_view_for_test.json' 73 | 74 | # get the official recommended perspectives 75 | with open(reference_view_path, 'r') as file: 76 | reference_views = {} 77 | reader = csv.reader(file) 78 | next(reader) 79 | for row in reader: 80 | reference_views[row[0]] = row[1:] 81 | 82 | rest_videos = defaultdict(list) 83 | 84 | # get the best bdd views 85 | scnearios1 = os.listdir(wts_ann_path) 86 | scnearios1.remove('normal_trimmed') 87 | best_view_wts1 = get_best_view_wts(wts_ann_path, wts_bbox_path, scnearios1, reference_views) 88 | rest_videos.update(best_view_wts1) 89 | 90 | scnearios2 = os.listdir(os.path.join(wts_ann_path, 'normal_trimmed')) 91 | best_view_wts2 = get_best_view_wts(wts_ann_path, wts_bbox_path, scnearios2, reference_views) 92 | rest_videos.update(best_view_wts2) 93 | 94 | # get the best bdd views 95 | for bdd_video in os.listdir(bdd_video_path): 96 | rest_videos[bdd_video.split('.')[0]] = bdd_video 97 | 98 | with open(save_path, 'w') as f: 99 | f.write(json.dumps(rest_videos, indent=2, ensure_ascii=False)) 100 | -------------------------------------------------------------------------------- /data_preprocess/best_view_selection.py: -------------------------------------------------------------------------------- 1 | import os, json, csv, glob 2 | from tqdm import tqdm 3 | from collections import defaultdict 4 | import argparse 5 | 6 | def get_best_view_wts(ann_path, bbox_path, scnearios, reference_views): 7 | best_view_video = {} 8 | for scneario in tqdm(scnearios): 9 | if '.DS_Store' in scneario: 10 | continue 11 | if '_normal_' in scneario: 12 | if os.path.exists(os.path.join(bbox_path, f'normal_trimmed/{scneario}/overhead_view')) or not os.path.exists(os.path.join(bbox_path, f'normal_trimmed/{scneario}/vehicle_view')): 13 | best_view_video[scneario] = scneario + '.mp4' 14 | else: 15 | best_view_video[scneario] = scneario +'_vehicle_view.mp4' 16 | else: 17 | if scneario == '20231006_18_CN29_T1': 18 | print('f') 19 | try: 20 | overhead_view_json = json.load(open(glob.glob(os.path.join(ann_path, f'{scneario}/overhead_view/*.json'))[0])) 21 | except: 22 | overhead_view_json = None 23 | 24 | views = [] 25 | for overhand in overhead_view_json['overhead_videos']: 26 | if scneario in reference_views.keys(): 27 | if overhand in reference_views[scneario]: 28 | 
views.append(overhand) 29 | else: 30 | print(f'no reference view: {scneario}') 31 | views.append(overhand) 32 | best_view_score = 0 33 | best_view = None 34 | for view in views: 35 | if os.path.exists(os.path.join(bbox_path, f"{scneario}/overhead_view/{view.replace('.mp4', '')}_bbox.json")): 36 | bbox = json.load(open(os.path.join(bbox_path, f"{scneario}/overhead_view/{view.replace('.mp4', '')}_bbox.json"))) 37 | elif os.path.exists(os.path.join(bbox_path, f"{scneario}/vehicle_view/{scneario}_vehicle_view_bbox.json")): 38 | bbox = json.load(open(os.path.join(bbox_path, f"{scneario}/vehicle_view/{scneario}_vehicle_view_bbox.json"))) 39 | else: 40 | print(f'no bbox: {scneario}') 41 | continue 42 | 43 | if len(bbox["annotations"]) == 5: 44 | avg_human_area = sum([box['bbox'][2]*box['bbox'][3] for box in bbox["annotations"]])/5. 45 | if avg_human_area > best_view_score: 46 | best_view_score = avg_human_area 47 | best_view = view 48 | 49 | if best_view == None and os.path.exists(os.path.join(bbox_path, f"{scneario}/vehicle_view/{scneario}_vehicle_view_bbox.json")): 50 | best_view = scneario +'_vehicle_view.mp4' 51 | else: 52 | avg_human_area = sum([box['bbox'][2]*box['bbox'][3] for box in bbox["annotations"]])/len(bbox["annotations"]) 53 | if avg_human_area > best_view_score: 54 | best_view_score = avg_human_area 55 | best_view = view 56 | 57 | # We found that the bounding boxes of 20230728_13_CN21_T1_Camera2_5.mp4 and 20230728_13_CN21_T2_Camera2_5 is incorrect 58 | if scneario == '20230728_13_CN21_T1' or scneario == '20230728_13_CN21_T2': 59 | best_view=scneario +'_vehicle_view.mp4' 60 | 61 | best_view_video[scneario] = best_view 62 | 63 | return best_view_video 64 | 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--test-root', type=str, default='./data/test_part') 70 | parser.add_argument('--save-path', type=str, default='./processed_anno/best_view_for_test.json') 71 | args = parser.parse_args() 72 | wts_ann_path = os.path.join(args.test_root, 'WTS_DATASET_PUBLIC_TEST/annotations/caption/test/public_challenge') 73 | wts_bbox_path = os.path.join(args.test_root, 'WTS_DATASET_PUBLIC_TEST_BBOX/annotations/bbox_annotated/pedestrian/test/public') 74 | bdd_video_path = os.path.join(args.test_root, 'WTS_DATASET_PUBLIC_TEST/external/BDD_PC_5K/videos/test/public') 75 | reference_view_path = os.path.join(args.test_root, 'view_used_as_main_reference_for_multiview_scenario.csv') 76 | save_path = args.save_path 77 | 78 | # get the official recommended perspectives 79 | with open(reference_view_path, 'r') as file: 80 | reference_views = {} 81 | reader = csv.reader(file) 82 | next(reader) 83 | for row in reader: 84 | reference_views[row[0]] = row[1:] 85 | 86 | rest_videos = defaultdict(list) 87 | 88 | # get the best bdd views 89 | scnearios1 = os.listdir(wts_ann_path) 90 | scnearios1.remove('normal_trimmed') 91 | best_view_wts1 = get_best_view_wts(wts_ann_path, wts_bbox_path, scnearios1, reference_views) 92 | rest_videos.update(best_view_wts1) 93 | 94 | scnearios2 = os.listdir(os.path.join(wts_ann_path, 'normal_trimmed')) 95 | best_view_wts2 = get_best_view_wts(wts_ann_path, wts_bbox_path, scnearios2, reference_views) 96 | rest_videos.update(best_view_wts2) 97 | 98 | # get the best bdd views 99 | for bdd_video in os.listdir(bdd_video_path): 100 | rest_videos[bdd_video.split('.')[0]] = bdd_video 101 | 102 | with open(save_path, 'w') as f: 103 | f.write(json.dumps(rest_videos, indent=2, ensure_ascii=False)) 104 | 
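# Usage sketch (arguments shown are this script's own defaults; adjust paths as needed):
#
#   python data_preprocess/best_view_selection.py \
#       --test-root ./data/test_part \
#       --save-path ./processed_anno/best_view_for_test.json
#
# The resulting JSON maps each scenario id to the single camera file (overhead or
# vehicle view) that the downstream inference scripts treat as its best view.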
-------------------------------------------------------------------------------- /llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
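        # Added note: a zero at [0, 0, 0, 1] (an above-diagonal position) means the additive
        # mask applies no causal bias, so memory_efficient_attention runs without an attn_bias;
        # otherwise LowerTriangularMask() reproduces the causal mask without materializing the
        # full (q_len, kv_len) bias tensor.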
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /data_preprocess/filiter_data_by_area.py: -------------------------------------------------------------------------------- 1 | import os, json, csv 2 | from tqdm import tqdm 3 | 4 | part = 'train' 5 | root_path = './data' 6 | reference_view_path = './data/test_part/view_used_as_main_reference_for_multiview_scenario.csv' 7 | 8 | data_path = './processed_anno/llava_format/wts_bdd_llava_qa_train_stage.json' 9 | save_path = './processed_anno/llava_format/wts_bdd_llava_qa_train_stage_filted.json' 10 | 11 | area_thr = 1000 12 | stage_map = {'prerecognition': 0, 'recognition': 1, 'judgement': 2, 'action': 3, 'avoidance': 4} 13 | 14 | # good_view for overhead 15 | with open(reference_view_path, 'r') as file: 16 | reference_views = {} 17 | reader = csv.reader(file) 18 | next(reader) 19 | for row in reader: 20 | reference_views[row[0]] = row[1:] 21 | 22 | 23 | with open(data_path, 'r') as f: 24 | data_json = json.load(f) 25 | 26 | new_data = [] 27 | for data in tqdm(data_json): 28 | view = data['image'].split('/')[-2] if 'WTS' in data['image'] else data['image'].split('/')[-1].split('_')[0] 29 | if 'WTS' in data['image'] and 'overhead_view' in data['image']: 30 | if data['id'] in reference_views.keys(): # 监控视角过滤 31 | if view + '.mp4' not in reference_views[data['id']]: 32 | print(f"view filter: {data['image']}, {view}") 33 | continue 34 | 35 | satge = data['image'].split('/')[-1].split('.')[0].split('_')[-1] 36 | 37 | pedestrian_box_path, 
vehicle_box_path = '', '' 38 | pedestrian_bbox, vehicle_bbox = '', '' 39 | if 'WTS' in data['image']: 40 | if 'normal_trimmed' not in data['image']: 41 | if 'overhead_view' in data['image']: 42 | pedestrian_box_path = f"{root_path}/WTS/annotations/bbox_annotated/pedestrian/{part}/{data['id']}/overhead_view/{view}_bbox.json" 43 | vehicle_box_path = f"{root_path}/WTS/annotations/bbox_annotated/vehicle/{part}/{data['id']}/overhead_view/{view}_bbox.json" 44 | elif 'vehicle_view' in data['image']: 45 | pedestrian_box_path = f"{root_path}/WTS/annotations/bbox_annotated/pedestrian/{part}/{data['id']}/vehicle_view/{view}_bbox.json" 46 | vehicle_box_path = f"{root_path}/WTS/annotations/bbox_annotated/vehicle/{part}/{data['id']}/vehicle_view/{view}_bbox.json" 47 | else: 48 | if 'overhead_view' in data['image']: 49 | pedestrian_box_path = f"{root_path}/WTS/annotations/bbox_annotated/pedestrian/{part}/normal_trimmed/{data['id']}/overhead_view/{view}_bbox.json" 50 | vehicle_box_path = f"{root_path}/WTS/annotations/bbox_annotated/vehicle/{part}/normal_trimmed/{data['id']}/overhead_view/{view}_bbox.json" 51 | elif 'vehicle_view' in data['image']: 52 | pedestrian_box_path = f"{root_path}/WTS/annotations/bbox_annotated/pedestrian/{part}/normal_trimmed/{data['id']}/vehicle_view/{view}_bbox.json" 53 | vehicle_box_path = f"{root_path}/WTS/annotations/bbox_annotated/vehicle/{part}/normal_trimmed/{data['id']}/vehicle_view/{view}_bbox.json" 54 | 55 | elif 'BDD_PC_5k' in data['image']: 56 | pedestrian_box_path = f"{root_path}/BDD_PC_5k/annotations/bbox_annotated/{part}/{view}_bbox.json" 57 | 58 | try: 59 | pedestrian_bbox = json.load(open(pedestrian_box_path)) 60 | except: 61 | print(f"no pedestrian json filter: {data['image']}") 62 | # miss_data.append(f"no pedestrian json filter: {data['image']}") 63 | try: 64 | vehicle_bbox = json.load(open(vehicle_box_path)) 65 | except: 66 | pass 67 | # print(f"no vehicle json filter: {data['image']}") 68 | # miss_data.append(f"no vehicle json filter: {data['image']}") 69 | 70 | if pedestrian_bbox != '': 71 | pedestrian_box = '' 72 | for single_box in pedestrian_bbox['annotations']: 73 | if str(stage_map[satge]) == single_box['phase_number'] or stage_map[satge] == single_box['phase_number']: 74 | pedestrian_box = single_box 75 | if pedestrian_box != '': 76 | human_area = pedestrian_box['bbox'][2]*pedestrian_box['bbox'][3] 77 | else: 78 | human_area = 0 79 | else: 80 | human_area = 0 81 | # print(f"no human box filter: {data['image']}") 82 | # miss_data.append(f"no human box filter: {data['image']}") 83 | 84 | 85 | if vehicle_bbox != '': 86 | vehicle_box='' 87 | for single_box in vehicle_bbox['annotations']: 88 | if str(stage_map[satge]) == single_box['phase_number'] or stage_map[satge] == single_box['phase_number']: 89 | vehicle_box = single_box 90 | if vehicle_box != '': 91 | vehicle_area = vehicle_box['bbox'][2]*vehicle_box['bbox'][3] 92 | else: 93 | vehicle_area = 0 94 | else: 95 | vehicle_area = 0 96 | # print(f"no vehicle box filter: {data['image']}") 97 | # miss_data.append(f"no vehicle box filter: {data['image']}") 98 | 99 | 100 | if human_area > area_thr and vehicle_area > area_thr: # 人车框都正常 101 | new_data.append(data) 102 | elif human_area > area_thr and vehicle_area == 0: # 只有人,人大于阈值,没有车 103 | new_data.append(data) 104 | elif vehicle_area > area_thr and human_area == 0: # 只有车,车大于阈值,没有人 105 | new_data.append(data) 106 | else: 107 | print(f"area filter:{data['image']}, {human_area}") 108 | 109 | 110 | print(f'num:{len(data_json)} vs {len(new_data)}') 111 | 112 | with 
open(save_path, 'w') as f: 113 | f.write(json.dumps(new_data, indent=2, ensure_ascii=False)) 114 | -------------------------------------------------------------------------------- /llava/serve/batch_inference_block.py: -------------------------------------------------------------------------------- 1 | import os, time, glob, json 2 | from tqdm import tqdm 3 | from multiprocessing import Pool 4 | import multiprocessing as mp 5 | import argparse 6 | import sys 7 | import torch 8 | 9 | from llava.serve.cli_final import infer_once 10 | 11 | from llava.model.builder import load_pretrained_model 12 | from llava.utils import disable_torch_init 13 | from llava.mm_utils import get_model_name_from_path 14 | 15 | 16 | def infer(model_path, model_base, data_list, i, cache_dir, best_view_map, data_path, local_image_data_path): 17 | # os.environ['CUDA_VISIBLE_DEVICES'] = f'{i}' 18 | 19 | torch.manual_seed(1234) 20 | 21 | disable_torch_init() 22 | 23 | model_name = get_model_name_from_path(model_path) 24 | tokenizer, model, processor, context_len = load_pretrained_model(model_path, model_base, model_name, 25 | False, True, 26 | device=f"cuda:{i}") 27 | if model.device == torch.device('cpu'): 28 | print('Detecting model loaded on CPU...') 29 | model = model.to(f"cuda:{i}") 30 | results = {} 31 | 32 | for scneario in tqdm(data_list, desc=f'GPU:{str(i)}'): 33 | scneario_res = [] 34 | for clip in os.listdir(os.path.join(data_path, scneario)): 35 | if os.path.exists(os.path.join(cache_dir, scneario, clip.replace('.jpg', '.json'))): 36 | res = json.load(open(os.path.join(cache_dir, scneario, clip.replace('.jpg', '.json')))) 37 | for key, value in res.items(): 38 | if key == 'labels': 39 | continue 40 | else: 41 | value = value.replace('<|startoftext|> ', '') 42 | value = value.replace('<|im_end|>', '') 43 | else: 44 | res = {} 45 | res['labels'] = [str(clip.replace('.jpg', ''))] 46 | retries = 0 47 | while retries < 5: 48 | try: 49 | global_image_path = os.path.join(data_path, scneario, clip) 50 | best_view = best_view_map[os.path.abspath(global_image_path)] 51 | local_image_path = os.path.join(local_image_data_path, scneario, clip) 52 | res_caption = infer_once(global_image_path, local_image_path, best_view, tokenizer, model, processor, context_len, model_name) 53 | break 54 | except Exception as e: 55 | print("Error: ", e, model.device) 56 | torch.cuda.empty_cache() 57 | retries += 1 58 | if retries >= 5: 59 | break 60 | 61 | res.update(res_caption) 62 | os.makedirs(os.path.join(cache_dir, scneario), exist_ok=True) 63 | with open(os.path.join(cache_dir, scneario, clip.replace('.jpg', '.json')), 'w') as f: 64 | f.write(json.dumps(res, indent=4)) 65 | 66 | scneario_res.append(res) 67 | results[scneario] = scneario_res 68 | 69 | return [results] 70 | 71 | 72 | if __name__ == '__main__': 73 | mp.set_start_method('spawn') 74 | # Create the parser 75 | parser = argparse.ArgumentParser(description='Process the paths and configurations.') 76 | 77 | # Add arguments 78 | parser.add_argument('--data-path', type=str, required=True, 79 | help='Path to the global image datasets directory.') 80 | parser.add_argument('--local-image-data-path', type=str, required=True, 81 | help='Path to the local image datasets directory.') 82 | parser.add_argument('--finetune-model', type=str, default=None, 83 | help='Path to the finetune model directory (if any).') 84 | parser.add_argument('--model-base', type=str, default=None, 85 | help='Path to the model base directory (if any).') 86 | parser.add_argument('--save-path', type=str, 
required=True, 87 | help='Path where the results will be saved.') 88 | parser.add_argument('--num-pool', type=int, default=1, 89 | help='Number of pools to use, equal to the number of GPUs.') 90 | parser.add_argument('--cache-dir', type=str, default="test_tmp_cache", help='Store inference cache.') 91 | parser.add_argument('--best-view-map', type=str, required=True, help='indicate the best view') 92 | # Parse the arguments 93 | args = parser.parse_args() 94 | 95 | 96 | json_data = list(os.listdir(args.data_path)) 97 | best_view_map = json.load(open(args.best_view_map)) 98 | process_num = int(len(json_data)/args.num_pool) + 1 99 | pool = Pool(processes=args.num_pool) 100 | json_data_splits = [json_data[i:i+process_num] for i in range(0, len(json_data), process_num)] 101 | 102 | results = [] 103 | 104 | for i, splits in enumerate(json_data_splits): 105 | # result = infer(finetune_model, splits[:2], i) 106 | result = pool.apply_async(infer, (args.finetune_model, args.model_base, splits, i, args.cache_dir, best_view_map, args.data_path, args.local_image_data_path)) 107 | results.append(result) 108 | pool.close() 109 | pool.join() 110 | 111 | traffic_list = {} 112 | for i in range(len(results)): 113 | curr_dict = results[i].get() 114 | traffic_list.update(curr_dict[0]) 115 | 116 | with open(args.save_path, 'w') as f: 117 | f.write(json.dumps(traffic_list, indent=2, ensure_ascii=False) + '\n') 118 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
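# Added overview comment: this module mirrors llava_llama.py for Mistral backbones.
# LlavaMistralModel mixes LlavaMetaModel (vision tower + mm projector) into MistralModel,
# and LlavaMistralForCausalLM packs image features into the input embeddings via
# prepare_inputs_labels_for_multimodal() before delegating to the standard
# MistralForCausalLM forward()/generate(). The AutoConfig/AutoModelForCausalLM
# registrations at the bottom of the file let from_pretrained() resolve the
# "llava_mistral" model_type.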
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | MistralConfig, MistralModel, MistralForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | from transformers.generation.utils import GenerateOutput 27 | 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | 30 | 31 | class LlavaMistralConfig(MistralConfig): 32 | model_type = "llava_mistral" 33 | 34 | 35 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 36 | config_class = LlavaMistralConfig 37 | 38 | def __init__(self, config: MistralConfig): 39 | super(LlavaMistralModel, self).__init__(config) 40 | 41 | 42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 43 | config_class = LlavaMistralConfig 44 | 45 | def __init__(self, config): 46 | super(MistralForCausalLM, self).__init__(config) 47 | self.model = LlavaMistralModel(config) 48 | 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | 
return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 159 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava_llama" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: 
Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | torch.cuda.empty_cache() 91 | 92 | return super().forward( 93 | input_ids=input_ids, 94 | attention_mask=attention_mask, 95 | position_ids=position_ids, 96 | past_key_values=past_key_values, 97 | inputs_embeds=inputs_embeds, 98 | labels=labels, 99 | use_cache=use_cache, 100 | output_attentions=output_attentions, 101 | output_hidden_states=output_hidden_states, 102 | return_dict=return_dict 103 | ) 104 | 105 | @torch.no_grad() 106 | def generate( 107 | self, 108 | inputs: Optional[torch.Tensor] = None, 109 | images: Optional[torch.Tensor] = None, 110 | image_sizes: Optional[torch.Tensor] = None, 111 | **kwargs, 112 | ) -> Union[GenerateOutput, torch.LongTensor]: 113 | position_ids = kwargs.pop("position_ids", None) 114 | attention_mask = kwargs.pop("attention_mask", None) 115 | if "inputs_embeds" in kwargs: 116 | raise NotImplementedError("`inputs_embeds` is not supported") 117 | 118 | if images is not None: 119 | ( 120 | inputs, 121 | position_ids, 122 | attention_mask, 123 | _, 124 | inputs_embeds, 125 | _ 126 | ) = self.prepare_inputs_labels_for_multimodal( 127 | inputs, 128 | position_ids, 129 | attention_mask, 130 | None, 131 | None, 132 | images, 133 | image_sizes=image_sizes 134 | ) 135 | else: 136 | inputs_embeds = self.get_model().embed_tokens(inputs) 137 | 138 | return super().generate( 139 | position_ids=position_ids, 140 | attention_mask=attention_mask, 141 | inputs_embeds=inputs_embeds, 142 | **kwargs 143 | ) 144 | 145 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 146 | inputs_embeds=None, **kwargs): 147 | images = kwargs.pop("images", None) 148 | image_sizes = kwargs.pop("image_sizes", None) 149 | inputs = super().prepare_inputs_for_generation( 150 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 151 | ) 152 | if images is not None: 153 | inputs['images'] = images 154 | if image_sizes is not None: 155 | inputs['image_sizes'] = image_sizes 156 | return inputs 157 | 158 | AutoConfig.register("llava_llama", LlavaConfig) 159 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 160 | -------------------------------------------------------------------------------- /data_preprocess/extract_wts_test_frame_bbox_anno.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--root', type=str, default='/mnt/data/AICITY2024/', help='data root path') 8 | parser.add_argument('--save-folder', type=str, default='processed_anno', help='dirname for saving json file') 9 | 10 | args = parser.parse_args() 11 | 12 | video_path = os.path.join(args.root, 'WTS_DATASET_PUBLIC_TEST/videos/test/public') 13 | annotation_path = os.path.join(args.root, 'WTS_DATASET_PUBLIC_TEST/annotations/caption/test/public_challenge') 14 | bbox_path = os.path.join(args.root, 'WTS_DATASET_PUBLIC_TEST_BBOX/annotations/bbox_annotated/') 15 | 16 | video_with_bbox_results = dict() 17 | 18 | for item in tqdm(os.listdir(video_path)): 19 | if 'normal' in item: 20 | continue 21 | 22 | 
for view in ['overhead', 'vehicle']: 23 | current_view = os.path.join(video_path, item, f'{view}_view') 24 | 25 | caption_anno_path = os.path.join(annotation_path, item, f'{view}_view', f'{item}_caption.json') 26 | 27 | # vehicle bbox extraction 28 | if view == 'overhead': 29 | assert os.path.exists(caption_anno_path) 30 | try: 31 | vehicle_annotation = json.load(open(caption_anno_path))['event_phase'] 32 | except: 33 | continue 34 | start_time, end_time = None, None 35 | for phase in vehicle_annotation: 36 | if not start_time: 37 | start_time = float(phase['start_time']) 38 | else: 39 | start_time = min(float(phase['start_time']), start_time) 40 | if not end_time: 41 | end_time = float(phase['end_time']) 42 | else: 43 | end_time = max(float(phase['end_time']), end_time) 44 | 45 | for camera in os.listdir(current_view): 46 | camera_base = camera.replace('.mp4', '') 47 | video_with_bbox_results[os.path.join(current_view, camera)] = dict(start_time=start_time, end_time=end_time, ped_bboxes=dict(), veh_bboxes=dict(), phase_number=dict()) 48 | pedestrian_bbox_anno_path = os.path.join(bbox_path, 'pedestrian', 'test/public', item, f'{view}_view', f'{camera_base}_bbox.json') 49 | if os.path.exists(pedestrian_bbox_anno_path): 50 | pedestrian_bbox = json.load(open(pedestrian_bbox_anno_path))['annotations'] 51 | for bbox in pedestrian_bbox: 52 | video_with_bbox_results[os.path.join(current_view, camera)]['ped_bboxes'][bbox['image_id']] = bbox['bbox'] 53 | video_with_bbox_results[os.path.join(current_view, camera)]['phase_number'][bbox['image_id']] = bbox['phase_number'] 54 | 55 | vehicle_bbox_anno_path = os.path.join(bbox_path, 'vehicle', 'test/public', item, f'{view}_view', f'{camera_base}_bbox.json') 56 | if os.path.exists(vehicle_bbox_anno_path): 57 | vehicle_bbox = json.load(open(vehicle_bbox_anno_path))['annotations'] 58 | for bbox in vehicle_bbox: 59 | video_with_bbox_results[os.path.join(current_view, camera)]['veh_bboxes'][bbox['image_id']] = bbox['bbox'] 60 | video_with_bbox_results[os.path.join(current_view, camera)]['phase_number'][bbox['image_id']] = bbox['phase_number'] 61 | 62 | 63 | for item in tqdm(os.listdir(os.path.join(video_path, 'normal_trimmed'))): 64 | ori_time = item 65 | item = f'normal_trimmed/{item}' 66 | for view in ['overhead', 'vehicle']: 67 | current_view = os.path.join(video_path, item, f'{view}_view') 68 | 69 | caption_anno_path = os.path.join(annotation_path, item, f'{view}_view', f'{ori_time}_caption.json') 70 | 71 | # vehicle bbox extraction 72 | if view == 'overhead': 73 | assert os.path.exists(caption_anno_path), caption_anno_path 74 | try: 75 | vehicle_annotation = json.load(open(caption_anno_path))['event_phase'] 76 | except: 77 | continue 78 | start_time, end_time = None, None 79 | for phase in vehicle_annotation: 80 | if not start_time: 81 | start_time = float(phase['start_time']) 82 | else: 83 | start_time = min(float(phase['start_time']), start_time) 84 | if not end_time: 85 | end_time = float(phase['end_time']) 86 | else: 87 | end_time = max(float(phase['end_time']), end_time) 88 | 89 | for camera in os.listdir(current_view): 90 | camera_base = camera.replace('.mp4', '') 91 | video_with_bbox_results[os.path.join(current_view, camera)] = dict(start_time=start_time, end_time=end_time, ped_bboxes=dict(), veh_bboxes=dict(), phase_number=dict()) 92 | pedestrian_bbox_anno_path = os.path.join(bbox_path, 'pedestrian', 'test/public', item, f'{view}_view', f'{camera_base}_bbox.json') 93 | 94 | if os.path.exists(pedestrian_bbox_anno_path): 95 | pedestrian_bbox = 
json.load(open(pedestrian_bbox_anno_path))['annotations'] 96 | for bbox in pedestrian_bbox: 97 | video_with_bbox_results[os.path.join(current_view, camera)]['ped_bboxes'][bbox['image_id']] = bbox['bbox'] 98 | video_with_bbox_results[os.path.join(current_view, camera)]['phase_number'][bbox['image_id']] = bbox['phase_number'] 99 | 100 | vehicle_bbox_anno_path = os.path.join(bbox_path, 'vehicle', 'test/public', item, f'{view}_view', f'{camera_base}_bbox.json') 101 | if os.path.exists(vehicle_bbox_anno_path): 102 | vehicle_bbox = json.load(open(vehicle_bbox_anno_path))['annotations'] 103 | for bbox in vehicle_bbox: 104 | video_with_bbox_results[os.path.join(current_view, camera)]['veh_bboxes'][bbox['image_id']] = bbox['bbox'] 105 | video_with_bbox_results[os.path.join(current_view, camera)]['phase_number'][bbox['image_id']] = bbox['phase_number'] 106 | 107 | os.makedirs(args.save_folder, exist_ok=True) 108 | with open(os.path.join(args.save_folder, 'wts_test_all_video_with_bbox_anno_first_frame.json'), 'w') as f: 109 | f.write(json.dumps(video_with_bbox_results, indent=4)) -------------------------------------------------------------------------------- /data_preprocess/extract_wts_frame_bbox_anno.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--root', type=str, default='/mnt/data/AICITY2024/', help='data root path') 8 | parser.add_argument('--split', type=str, default='train') 9 | parser.add_argument('--save-folder', type=str, default='processed_anno', help='dirname for saving json file') 10 | 11 | args = parser.parse_args() 12 | 13 | video_path = os.path.join(args.root, 'WTS/videos', args.split) 14 | annotation_path = os.path.join(args.root, 'WTS/annotations/caption', args.split) 15 | bbox_path = os.path.join(args.root, 'WTS/annotations/bbox_annotated') 16 | 17 | video_with_bbox_results = dict() 18 | 19 | for item in tqdm(os.listdir(video_path)): 20 | if 'normal' in item: 21 | continue 22 | 23 | for view in ['overhead', 'vehicle']: 24 | current_view = os.path.join(video_path, item, f'{view}_view') 25 | 26 | caption_anno_path = os.path.join(annotation_path, item, f'{view}_view', f'{item}_caption.json') 27 | 28 | # vehicle bbox extraction 29 | if view == 'overhead': 30 | assert os.path.exists(caption_anno_path), f'{caption_anno_path} not exists' 31 | try: 32 | vehicle_annotation = json.load(open(caption_anno_path))['event_phase'] 33 | except: 34 | continue 35 | start_time, end_time = None, None 36 | for phase in vehicle_annotation: 37 | if not start_time: 38 | start_time = float(phase['start_time']) 39 | else: 40 | start_time = min(float(phase['start_time']), start_time) 41 | if not end_time: 42 | end_time = float(phase['end_time']) 43 | else: 44 | end_time = max(float(phase['end_time']), end_time) 45 | 46 | for camera in os.listdir(current_view): 47 | camera_base = camera.replace('.mp4', '') 48 | video_with_bbox_results[os.path.join(current_view, camera)] = dict(start_time=start_time, end_time=end_time, ped_bboxes=dict(), veh_bboxes=dict(), phase_number=dict()) 49 | pedestrian_bbox_anno_path = os.path.join(bbox_path, 'pedestrian', args.split, item, f'{view}_view', f'{camera_base}_bbox.json') 50 | 51 | if os.path.exists(pedestrian_bbox_anno_path): 52 | pedestrian_bbox = json.load(open(pedestrian_bbox_anno_path))['annotations'] 53 | for bbox in pedestrian_bbox: 54 | video_with_bbox_results[os.path.join(current_view, 
camera)]['ped_bboxes'][bbox['image_id']] = bbox['bbox'] 55 | video_with_bbox_results[os.path.join(current_view, camera)]['phase_number'][bbox['image_id']] = bbox['phase_number'] 56 | 57 | vehicle_bbox_anno_path = os.path.join(bbox_path, 'vehicle', args.split, item, f'{view}_view', f'{camera_base}_bbox.json') 58 | if os.path.exists(vehicle_bbox_anno_path): 59 | vehicle_bbox = json.load(open(vehicle_bbox_anno_path))['annotations'] 60 | for bbox in vehicle_bbox: 61 | video_with_bbox_results[os.path.join(current_view, camera)]['veh_bboxes'][bbox['image_id']] = bbox['bbox'] 62 | video_with_bbox_results[os.path.join(current_view, camera)]['phase_number'][bbox['image_id']] = bbox['phase_number'] 63 | 64 | 65 | for item in tqdm(os.listdir(os.path.join(video_path, 'normal_trimmed'))): 66 | ori_time = item 67 | item = f'normal_trimmed/{item}' 68 | for view in ['overhead', 'vehicle']: 69 | current_view = os.path.join(video_path, item, f'{view}_view') 70 | 71 | caption_anno_path = os.path.join(annotation_path, item, f'{view}_view', f'{ori_time}_caption.json') 72 | 73 | # vehicle bbox extraction 74 | if view == 'overhead': 75 | assert os.path.exists(caption_anno_path), caption_anno_path 76 | try: 77 | vehicle_annotation = json.load(open(caption_anno_path))['event_phase'] 78 | except: 79 | continue 80 | start_time, end_time = None, None 81 | for phase in vehicle_annotation: 82 | if not start_time: 83 | start_time = float(phase['start_time']) 84 | else: 85 | start_time = min(float(phase['start_time']), start_time) 86 | if not end_time: 87 | end_time = float(phase['end_time']) 88 | else: 89 | end_time = max(float(phase['end_time']), end_time) 90 | 91 | for camera in os.listdir(current_view): 92 | camera_base = camera.replace('.mp4', '') 93 | video_with_bbox_results[os.path.join(current_view, camera)] = dict(start_time=start_time, end_time=end_time, ped_bboxes=dict(), veh_bboxes=dict(), phase_number=dict()) 94 | pedestrian_bbox_anno_path = os.path.join(bbox_path, 'pedestrian', args.split, item, f'{view}_view', f'{camera_base}_bbox.json') 95 | 96 | if os.path.exists(pedestrian_bbox_anno_path): 97 | pedestrian_bbox = json.load(open(pedestrian_bbox_anno_path))['annotations'] 98 | for bbox in pedestrian_bbox: 99 | video_with_bbox_results[os.path.join(current_view, camera)]['ped_bboxes'][bbox['image_id']] = bbox['bbox'] 100 | video_with_bbox_results[os.path.join(current_view, camera)]['phase_number'][bbox['image_id']] = bbox['phase_number'] 101 | 102 | vehicle_bbox_anno_path = os.path.join(bbox_path, 'vehicle', args.split, item, f'{view}_view', f'{camera_base}_bbox.json') 103 | if os.path.exists(vehicle_bbox_anno_path): 104 | vehicle_bbox = json.load(open(vehicle_bbox_anno_path))['annotations'] 105 | for bbox in vehicle_bbox: 106 | video_with_bbox_results[os.path.join(current_view, camera)]['veh_bboxes'][bbox['image_id']] = bbox['bbox'] 107 | video_with_bbox_results[os.path.join(current_view, camera)]['phase_number'][bbox['image_id']] = bbox['phase_number'] 108 | 109 | os.makedirs(args.save_folder, exist_ok=True) 110 | with open(os.path.join(args.save_folder, f'wts_{args.split}_all_video_with_bbox_anno_first_frame.json'), 'w') as f: 111 | f.write(json.dumps(video_with_bbox_results, indent=4)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AICITY2024_Track2_AliOpenTrek_CityLLaVA 2 | 3 | 🏆 **The 1st Place** Solution to The 8th NVIDIA AI City Challenge (CVPR 2024 workshop) Track 2: 
[CityLLaVA: Efficient Fine-Tuning for VLMs in City Scenario](https://arxiv.org/abs/2405.03194).
4 | 
5 | ![1713757322703](image/README/1713757322703.png)
6 | 
7 | ## Leaderboard
8 | 
9 | | **Team Name** | **MRR Score** | **Rank** |
10 | | :-------------------------: | :-----------------: | :------------: |
11 | | **AliOpenTrek (Ours)** | **33.4308** | **1** |
12 | | AIO_ISC | 32.8877 | 2 |
13 | | Lighthouse | 32.3006 | 3 |
14 | 
15 | ## Prepare
16 | 
17 | 1. Install Package
18 | 
19 | ```Shell
20 | conda create -n cityllava python=3.10 -y
21 | conda activate cityllava
22 | cd AICITY2024_Track2_AliOpenTrek_CityLLaVA/
23 | pip install --upgrade pip # enable PEP 660 support
24 | pip install -e .
25 | pip install flash-attn --no-build-isolation
26 | ```
27 | 
28 | ## Structure
29 | 
30 | ### Data Preparation
31 | 
32 | First, change into the `data_preprocess` directory and create the `data` directory.
33 | 
34 | ```
35 | cd data_preprocess
36 | mkdir ./data
37 | ```
38 | 
39 | Please download the [wts-dataset](https://github.com/woven-visionai/wts-dataset) and put the datasets under `./data`. After unzipping the datasets, the directory structure should look like this:
40 | 
41 | ```
42 | .
43 | ├── data
44 | │   ├── BDD_PC_5k
45 | │   │   ├── annotations
46 | │   │   │   ├── bbox_annotated
47 | │   │   │   ├── bbox_generated
48 | │   │   │   └── caption
49 | │   │   └── videos
50 | │   ├── WTS
51 | │   │   ├── annotations
52 | │   │   │   ├── bbox_annotated
53 | │   │   │   ├── bbox_generated
54 | │   │   │   └── caption
55 | │   │   └── videos
56 | │   └── test_part
57 | │       ├── view_used_as_main_reference_for_multiview_scenario.csv
58 | │       ├── WTS_DATASET_PUBLIC_TEST
59 | │       └── WTS_DATASET_PUBLIC_TEST_BBOX
60 | └── ...  # python and shell scripts
61 | ```
62 | 
63 | Then run the following script to process the test data:
64 | 
65 | ```
66 | bash prepare_data_test.sh
67 | ```
68 | After this script is executed, all the test data is prepared. You can then download the fine-tuned model and run the inference step directly.
69 | 
70 | Run the following script to process the training data:
71 | 
72 | ```
73 | bash prepare_data_train.sh
74 | ```
75 | Note that an OpenAI or Qwen API key is required by `prepare_data_train.sh`; modify the `API_KEY` in this script before running it.
76 | 
77 | After the execution, the folder structure should look like this:
78 | 
79 | ```
80 | .
81 | ├── data
82 | │   ├── BDD_PC_5k
83 | │   │   ├── annotations
84 | │   │   │   ├── bbox_annotated
85 | │   │   │   ├── bbox_generated
86 | │   │   │   └── caption
87 | │   │   ├── bbox_global  # BDD global views
88 | │   │   │   ├── train
89 | │   │   │   └── val
90 | │   │   ├── bbox_local  # BDD local views
91 | │   │   │   ├── train
92 | │   │   │   └── val
93 | │   │   └── videos
94 | │   ├── WTS
95 | │   │   ├── annotations
96 | │   │   │   ├── bbox_annotated
97 | │   │   │   ├── bbox_generated
98 | │   │   │   └── caption
99 | │   │   ├── bbox_global  # WTS global views
100 | │   │   │   ├── train
101 | │   │   │   └── val
102 | │   │   ├── bbox_local  # WTS local views
103 | │   │   │   ├── train
104 | │   │   │   └── val
105 | │   │   └── videos
106 | │   └── test_part
107 | │       ├── view_used_as_main_reference_for_multiview_scenario.csv
108 | │       ├── WTS_DATASET_PUBLIC_TEST
109 | │       │   ├── bbox_global/test/public  # WTS Test Images
110 | │       │   ├── bbox_local/test/public
111 | │       │   └── external/BDD_PC_5K
112 | │       │       ├── bbox_global/test/public  # BDD Test Images
113 | │       │       └── bbox_local/test/public
114 | │       └── WTS_DATASET_PUBLIC_TEST_BBOX
115 | ├── processed_anno
116 | │   ├── frame_bbox_anno
117 | │   │   ├── bdd_test_all_video_with_bbox_anno_first_frame.json
118 | │   │   ├── bdd_train_all_video_with_bbox_anno_first_frame.json
119 | │   │   ├── bdd_val_all_video_with_bbox_anno_first_frame.json
120 | │   │   ├── wts_test_all_video_with_bbox_anno_first_frame.json
121 | │   │   ├── wts_train_all_video_with_bbox_anno_first_frame.json
122 | │   │   └── wts_val_all_video_with_bbox_anno_first_frame.json
123 | │   ├── llava_format
124 | │   │   ├── wts_bdd_train.json
125 | │   │   └── wts_bdd_val.json
126 | │   ├── best_view_for_test.json
127 | │   └── perspective_test_images.json
128 | └── ...  # python and shell scripts
129 | ```
130 | 
131 | The processed annotations can then be found under `./processed_anno`, and the training json is:
132 | 
133 | ```
134 | './processed_anno/llava_format/wts_bdd_llava_qa_train_stage_filted_checked.json'
135 | ```
136 | 
137 | ## Block-Expansion
138 | 
139 | We use block expansion to fine-tune the VLMs. 8~16 blocks are suggested to balance performance and efficiency. We add 12 blocks to the original llava-1.6-34b. The llava-1.6-34b-12block model can be created with these steps:
140 | 
141 | 1. Download the [llava-1.6-34b](https://huggingface.co/liuhaotian/llava-v1.6-34b) model to `./models`, and add the blocks with this script:
142 | 
143 | ```
144 | python block_expansion_llava_1_6.py
145 | ```
146 | 
147 | 2. Copy the `*.json` files and `tokenizer.model` from `./models/llava-v1.6-34b` to `./models/llava-v1.6-34b-12block`;
148 | 3. Set `num_hidden_layers=72` (new_layer_num = original_layer_num + added_block_num) in the `config.json` of the llava-1.6-34b-12block model. A minimal illustrative sketch of this expansion step is included further down in this README.
149 | 
150 | ## Train
151 | 
152 | We use 8xA100 GPUs for fine-tuning. The training process takes approximately 8 hours with this script:
153 | 
154 | ```
155 | bash scripts/finetune_block_bigsmall.sh
156 | ```
157 | 
158 | The fine-tuned model can be downloaded [here](https://modelscope.cn/models/AliOpenTrek/CityLLaVA).
159 | 
160 | ## Inference
161 | 
162 | First, check the parameters defined in `./scripts/inference.sh` and ensure that all essential files and models exist.
163 | 
164 | Now you can run inference on the WTS test set:
165 | 
166 | ```
167 | bash scripts/inference.sh
168 | ```
169 | 
170 | ## Evaluation
171 | 
172 | We use the [wts-dataset](https://github.com/woven-visionai/wts-dataset) for evaluation.
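## Appendix: Block-Expansion Sketch

As referenced in the Block-Expansion section above, the snippet below is a minimal, illustrative sketch of what the expansion step boils down to; it is *not* the repository's `block_expansion_llava_1_6.py`. The 60-layer base, the interval of one new block per five original blocks (60 / 5 = 12 added blocks, giving the 72 layers set in `config.json`), and the checkpoint handling in the commented usage lines are assumptions for illustration only. Each inserted block copies the block before it and zeroes its output projections (`o_proj`, `down_proj`), so the expanded model initially computes the same function as the original (LLaMA-Pro style).

```python
import re
import torch

def expand_blocks(state_dict, n_orig_layers=60, interval=5):
    """Insert one identity-initialized copy after every `interval` original decoder blocks."""
    # Map each original layer index to its slot in the expanded model, leaving a
    # gap after every `interval`-th layer for the inserted copy.
    layer_map, new_idx = {}, 0
    for old_idx in range(n_orig_layers):
        layer_map[old_idx] = new_idx
        new_idx += 1
        if (old_idx + 1) % interval == 0:
            new_idx += 1

    pat = re.compile(r"model\.layers\.(\d+)\.")
    expanded = {}
    for key, value in state_dict.items():
        m = pat.match(key)
        if m is None:
            # Non-decoder weights (embeddings, final norm, lm_head, vision tower, projector).
            expanded[key] = value
            continue
        old_idx = int(m.group(1))
        mapped = layer_map[old_idx]
        expanded[pat.sub(f"model.layers.{mapped}.", key, count=1)] = value
        if (old_idx + 1) % interval == 0:
            # The inserted block duplicates the previous block; zeroing o_proj/down_proj
            # makes its residual contribution start at zero (an identity block).
            copy = value.clone()
            if "o_proj.weight" in key or "down_proj.weight" in key:
                copy.zero_()
            expanded[pat.sub(f"model.layers.{mapped + 1}.", key, count=1)] = copy
    return expanded

# Hypothetical usage (the real llava-v1.6-34b checkpoint is sharded, so the actual
# script has to iterate over the shards instead of a single file):
#   sd = torch.load("models/llava-v1.6-34b/pytorch_model.bin", map_location="cpu")
#   torch.save(expand_blocks(sd), "models/llava-v1.6-34b-12block/pytorch_model.bin")
```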
173 | 174 | ## Citation 175 | 176 | If you find CityLLaVA useful for your research and applications, please cite using this BibTeX: 177 | 178 | ```bibtex 179 | @misc{duan2024cityllava, 180 | title={CityLLaVA: Efficient Fine-Tuning for VLMs in City Scenario}, 181 | url={https://github.com/qingchunlizhi/AICITY2024_Track2_AliOpenTrek_CityLLaVA}, 182 | author={Zhizhao Duan, Hao Cheng, Duo Xu, Xi Wu, Xiangxie Zhang, Xi Ye, and Zhen Xie}, 183 | year={2024}, 184 | eprint={2405.03194}, 185 | archivePrefix={arXiv}, 186 | primaryClass={cs.CV} 187 | } 188 | ``` 189 | 190 | ## Acknowledgement 191 | 192 | - CityLLaVA is built with reference to the code of the following projects: [LLaVA](https://github.com/haotian-liu/LLaVA) and [LLaMA-Pro](https://github.com/TencentARC/LLaMA-Pro.git). Thanks for their awesome work! 193 | -------------------------------------------------------------------------------- /data_preprocess/generate_test_frames.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import cv2 4 | import shutil 5 | from tqdm import tqdm 6 | 7 | number_phrase_map = { 8 | 'prerecognition': '0', 9 | 'recognition': '1', 10 | 'judgement': '2', 11 | 'action': '3', 12 | 'avoidance': '4' 13 | } 14 | 15 | phrase_number_map = {v:k for k, v in number_phrase_map.items()} 16 | 17 | import argparse 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--root', type=str, default='/mnt/data/AICITY2024/image_for_test', help='data root path') 21 | parser.add_argument('--best-view-anno', type=str, required=True, help='best view anno for selecting frames for test') 22 | parser.add_argument('--wts-test-folder', type=str, required=True) 23 | parser.add_argument('--bdd-test-folder', type=str, required=True) 24 | parser.add_argument('--save-folder', type=str, default='./processed_anno') 25 | 26 | args = parser.parse_args() 27 | 28 | best_view = json.load(open(args.best_view_anno)) 29 | root = os.path.join(args.root, 'bbox_global') 30 | 31 | camera_path_mapping = dict() 32 | 33 | global_image_path = os.path.join(args.wts_test_folder) 34 | for event in os.listdir(global_image_path): 35 | if 'normal_trimmed' in event: 36 | continue 37 | for view in os.listdir(os.path.join(global_image_path, event)): 38 | parent_path = os.path.join(global_image_path, event, view) 39 | for camera in os.listdir(parent_path): 40 | camera_path_mapping[camera] = os.path.join(parent_path, camera) 41 | 42 | 43 | for event in os.listdir(os.path.join(global_image_path, 'normal_trimmed')): 44 | for view in os.listdir(os.path.join(global_image_path, 'normal_trimmed', event)): 45 | parent_path = os.path.join(global_image_path, 'normal_trimmed', event, view) 46 | for camera in os.listdir(parent_path): 47 | camera_path_mapping[camera] = os.path.join(parent_path, camera) 48 | 49 | perspective = dict() 50 | for key, value in tqdm(best_view.items()): 51 | os.makedirs(os.path.join(root, key), exist_ok=True) 52 | os.makedirs(os.path.join(root.replace('bbox_global', 'bbox_local'), key), exist_ok=True) 53 | for label, segment in phrase_number_map.items(): 54 | if 'video' in key: 55 | image_name = f'{key}_{segment}.jpg' 56 | try: 57 | shutil.copy(os.path.join(args.bdd_test_folder, image_name), os.path.join(root, key, f'{label}.jpg')) 58 | shutil.copy(os.path.join(args.bdd_test_folder.replace('bbox_global', 'bbox_local'), image_name), os.path.join(root.replace('bbox_global', 'bbox_local'), key, f'{label}.jpg')) 59 | except: 60 | print(f'{key}_{segment} not exist!') 61 | else: 62 | 
image_name = f'{label}_{segment}.jpg' 63 | try: 64 | shutil.copy(os.path.join(camera_path_mapping[value.replace('.mp4', '')], image_name), os.path.join(root, key, f'{label}.jpg')) 65 | shutil.copy(os.path.join(camera_path_mapping[value.replace('.mp4', '')], image_name).replace('bbox_global', 'bbox_local'), os.path.join(root.replace('bbox_global', 'bbox_local'), key, f'{label}.jpg')) 66 | except: 67 | print(f'{value.replace(".mp4", "")}/{image_name} not exist!') 68 | 69 | if os.path.exists(os.path.join(root, key, f'{label}.jpg')): 70 | perspective[os.path.abspath(os.path.join(root, key, f'{label}.jpg'))] = 'vehicle' if 'vehicle' in key or 'video' in key else 'overhead' 71 | 72 | 73 | def find_closest_number(target, arr): 74 | min_diff = float('inf') 75 | closest_num = None 76 | for num in arr: 77 | diff = abs(target - int(num)) 78 | 79 | if diff < min_diff: 80 | min_diff = diff 81 | closest_num = num 82 | 83 | return closest_num 84 | 85 | no_bbox_frames_ones = set() 86 | for item in os.listdir(root): 87 | images = os.listdir(os.path.join(root, item)) 88 | if len(images) == 0: 89 | no_bbox_frames_ones.add(item) 90 | continue 91 | if len(images) != 5: 92 | for image in ['0.jpg', '1.jpg', '2.jpg', '3.jpg', '4.jpg']: 93 | if image not in images: 94 | closest_image = find_closest_number(int(image.split('.')[0]), [int(x.split('.')[0]) for x in images]) 95 | source_path = os.path.join(root, item, f'{closest_image}.jpg') 96 | target_path = os.path.join(root, item, image) 97 | shutil.copy(source_path, target_path) 98 | 99 | perspective[os.path.abspath(target_path)] = perspective[os.path.abspath(source_path)] 100 | 101 | source_path = os.path.join(root.replace('bbox_global', 'bbox_local'), item, f'{closest_image}.jpg') 102 | target_path = os.path.join(root.replace('bbox_global', 'bbox_local'), item, image) 103 | shutil.copy(source_path, target_path) 104 | 105 | assert len(os.listdir(os.path.join(root, item))) == 5 106 | 107 | 108 | def extract_frames(source_video, source_anno, save_folder): 109 | with open(source_anno, 'r') as f: 110 | data = json.load(f) 111 | 112 | cap = cv2.VideoCapture(source_video) 113 | fps = cap.get(cv2.CAP_PROP_FPS) 114 | 115 | for event in data['event_phase']: 116 | label = event['labels'][0] 117 | start_time = float(event['start_time']) 118 | 119 | frame_number = int(start_time * fps) 120 | cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number) 121 | 122 | ret, frame = cap.read() 123 | if ret: 124 | cv2.imwrite(f'{save_folder}/{label}.jpg', frame) 125 | cv2.imwrite(f'{save_folder.replace("bbox_global", "bbox_local")}/{label}.jpg', frame) 126 | else: 127 | print(f"Error: Unable to extract frame for event {label}, source video: {source_video}") 128 | 129 | cap.release() 130 | 131 | for item in no_bbox_frames_ones: 132 | if 'video' in item: 133 | source_video = f'{args.bdd_test_folder.replace("bbox_global", "videos")}/{item}.mp4' 134 | source_anno = f'{args.bdd_test_folder.replace("bbox_global", "annotations/caption")}_challenge/{item}_caption.json' 135 | view = 'vehicle' 136 | else: 137 | best_view_video = best_view[item] 138 | if 'normal' in best_view_video: 139 | if 'vehicle' in best_view_video: 140 | source_video = f'{args.wts_test_folder.replace("bbox_global", "videos")}/normal_trimmed/{item}/vehicle_view/{best_view_video}' 141 | source_anno = f'{args.wts_test_folder.replace("bbox_global", "annotations/caption")}_challenge/normal_trimmed/{item}/vehicle_view/{item}_caption.json' 142 | view = 'vehicle' 143 | else: 144 | source_video = f'{args.wts_test_folder.replace("bbox_global", 
"videos")}/normal_trimmed/{item}/overhead_view/{best_view_video}' 145 | source_anno = f'{args.wts_test_folder.replace("bbox_global", "annotations/caption")}_challenge/normal_trimmed/{item}/overhead_view/{item}_caption.json' 146 | view = 'overhead' 147 | else: 148 | if 'vehicle' in best_view_video: 149 | source_video = f'{args.wts_test_folder.replace("bbox_global", "videos")}/{item}/vehicle_view/{best_view_video}' 150 | source_anno = f'{args.wts_test_folder.replace("bbox_global", "annotations/caption")}_challenge/{item}/vehicle_view/{item}_caption.json' 151 | view = 'vehicle' 152 | else: 153 | source_video = f'{args.wts_test_folder.replace("bbox_global", "videos")}/{item}/overhead_view/{best_view_video}' 154 | source_anno = f'{args.wts_test_folder.replace("bbox_global", "annotations/caption")}_challenge/{item}/overhead_view/{item}_caption.json' 155 | view = 'overhead' 156 | 157 | save_folder = os.path.join(root, item) 158 | extract_frames(source_video, source_anno, save_folder) 159 | for image in ['0.jpg', '1.jpg', '2.jpg', '3.jpg', '4.jpg']: 160 | target_path = os.path.abspath(os.path.join(root, item, image)) 161 | assert os.path.exists(target_path) 162 | perspective[target_path] = view 163 | 164 | for item in os.listdir(root): 165 | images = os.listdir(os.path.join(root, item)) 166 | if len(images) != 5: 167 | print(item) 168 | 169 | with open(os.path.join(args.save_folder, 'perspective_test_images.json'), 'w') as f: 170 | f.write(json.dumps(perspective, indent=4)) -------------------------------------------------------------------------------- /data_preprocess/shortQA_split.py: -------------------------------------------------------------------------------- 1 | from http import HTTPStatus 2 | import json 3 | import tqdm 4 | import multiprocessing 5 | import os 6 | import argparse 7 | 8 | import dashscope 9 | from openai import OpenAI 10 | 11 | # Short QA Construction 12 | # 1) utilizes llm to categorize each sentence of the descriptions into predefined dimensions. 13 | 14 | 15 | # * here we use qwen (dashscope api) instead of gpt-4 or chat-gpt, you can easily change it with openAI api 16 | # For prerequisites running the following sample, visit https://help.aliyun.com/document_detail/611472.html 17 | def call_with_messages(content, model_type, key): 18 | if model_type == 'Qwen': 19 | dashscope.api_key = key 20 | messages = [{'role': 'system', 'content': 'You are a helpful assistant.'}, 21 | {'role': 'user', 'content': content 22 | }] 23 | 24 | response = dashscope.Generation.call( 25 | dashscope.Generation.Models.qwen_plus, 26 | messages=messages, 27 | result_format='message', # set the result to be "message" format. 
28 | ) 29 | if response.status_code == HTTPStatus.OK: 30 | # print(response) 31 | response = response.output['choices'][0]['message']['content'] 32 | print(response) 33 | return response 34 | 35 | else: 36 | print('Request id: %s, Status code: %s, error code: %s, error message: %s' % ( 37 | response.request_id, response.status_code, 38 | response.code, response.message 39 | )) 40 | return None 41 | 42 | elif model_type == 'Openai': 43 | client = OpenAI(api_key = key) 44 | chat_completion = client.chat.completions.create( 45 | messages=[ 46 | { 47 | "role": "user", 48 | "content": content, 49 | } 50 | ], 51 | model="gpt-4", 52 | ) 53 | response = chat_completion.choices[0].message.content 54 | return response 55 | 56 | 57 | def classify_single_caption(caption_text, caption_type, model_type, api_key): 58 | # 将长描述按照句号进行拆分 59 | caption_text += ' ' 60 | caption_list = caption_text.split('. ') 61 | 62 | for i, c in enumerate(caption_list): 63 | if len(caption_list[i]) > 0: 64 | caption_list[i] = "{}.{}.".format(str(i + 1), c) 65 | new_caption_text = '\n'.join(caption_list) 66 | # print(new_caption_text) 67 | 68 | if caption_type == "pedestrian": 69 | content = ("Please select the most appropriate label for each descriptive text from the following options, and format the output by providing the text index followed by the letter a, b, c, d, or e. Each selection should be on a new line.\n" 70 | "Option a. Description of the pedestrian's age, height and clothing.\n" 71 | "Option b. Description of the orientation and relative position relationship between pedestrians and vehicles.\n" 72 | "Option c. Description of the pedestrian's line of sight direction and movement status. \n" 73 | "Option d. Pedestrians' surrounding environment, weather conditions, and road conditions. \n" 74 | "Option e. A description of whether pedestrians have potential risks or accidents, summarizing the pedestrian situation. \n" 75 | "here is the descriptive text: \n" + new_caption_text) 76 | # Previous versions used 77 | content_cn = ("请为每一条描述文本分别从以下选项中挑选一个最符合的主题,最终按行输出文本序号和abcde其中一个字母。\n" 78 | "选项a.行人的年龄、身高等基本特征和穿着描述。\n" 79 | "选项b.行人的与车辆的朝向、相对位置关系的描述。\n" 80 | "选项c.行人的视线方向和运动状态。\n" 81 | "选项d.行人所处周边环境情况、天气情况、道路情况。\n" 82 | "选项e.行人是否存在潜在风险或是否发生事故的相关描述,对行人情况的总结概括。\n" 83 | "描述文本: \n" + new_caption_text) 84 | else: 85 | content = ("Please select the most appropriate label for each descriptive text from the following options, and format the output by providing the text index followed by the letter a, b, c, d, or e. Each selection should be on a new line.\n" 86 | "Option a. Description of the orientation and relative position relationship between pedestrians and vehicles. \n" 87 | "Option b. Description of the vehicle's driving status and speed. \n" 88 | "Option c. Description of the pedestrian' age, height and clothing. \n" 89 | "Option d. Description of the surrounding environment, weather conditions, and road conditions. \n" 90 | "Option e. 
A summary of the vehicle's situation.\n" 91 | "here is the descriptive text: \n" + new_caption_text) 92 | # Previous versions used 93 | content_cn = ("请为每一条描述文本分别从以下选项中挑选一个最符合的主题,最终按行输出文本序号和abcd其中一个字母。\n" 94 | "选项a.车辆与行人的相对位置关系等描述\n" 95 | "选项b.车辆的行驶状态、速度的描述。\n" 96 | "选项c.对行人的年龄、身高等基本特征和穿着描述。\n" 97 | "选项d.车辆所处周边环境情况、天气情况、道路情况。\n" 98 | "选项e.对车辆情况的总结性概括。\n" 99 | "描述文本: \n" + new_caption_text) 100 | 101 | print(content) 102 | response = call_with_messages(content, model_type, api_key) 103 | return response 104 | 105 | def classify_process(input_data_list, save_file, caption_type, model_type, api_key): 106 | w = open(save_file, 'w', encoding='utf-8') 107 | for data in tqdm.tqdm(input_data_list): 108 | id = data['id'] 109 | try: 110 | conversations = data['conversations'] 111 | pedestrian_caption = conversations[1]['value'] 112 | vehicle_caption = conversations[3]['value'] 113 | 114 | if caption_type == "vehicle": 115 | # 描述拆分 116 | response = classify_single_caption(vehicle_caption, "vehicle", model_type, api_key) 117 | data['vehicle_response'] = response 118 | 119 | else: 120 | response = classify_single_caption(pedestrian_caption, "pedestrian", model_type, api_key) 121 | data['pedestrian_response'] = response 122 | 123 | w.write(json.dumps(data, ensure_ascii=False) + '\n') 124 | w.flush() 125 | except Exception as e: 126 | print('{}, {}'.format(id, e)) 127 | 128 | if __name__ == '__main__': 129 | # test_caption = "The pedestrian is a male in his 10s, with a height of 160 cm. He is wearing a yellow T-shirt and black slacks. It is a weekday in an urban area with clear weather and dark brightness. The road surface is dry and level, made of asphalt. The pedestrian is standing diagonally to the left in front of a moving vehicle, which is far away. His body is perpendicular to the vehicle and to the right. His line of sight indicates that he is crossing the road. He is closely watching his destination while unaware of the vehicle. The pedestrian's speed is slow, and he intends to cross immediately in front of or behind the vehicle. The road he is on is a main road with one-way traffic and two lanes. 
Sidewalks are present on both sides" 130 | # classify_single_caption(test_caption, "pedestrian") 131 | 132 | args = argparse.ArgumentParser() 133 | args.add_argument('--model', type=str, default='Qwen', help='Choose your LLM') 134 | args.add_argument('--api-key', type=str, required=True, help='Your API key for chosen LLM') 135 | 136 | input_file = "./processed_anno/llava_format/wts_bdd_train.json" 137 | save_file = "./processed_anno/caption_split/caption_split.json" 138 | 139 | if not os.path.exists("./processed_anno/caption_split"): 140 | os.makedirs("./processed_anno/caption_split") 141 | 142 | num_works = 1 143 | endata_multiprocess = [] 144 | with open(input_file, 'r') as f: 145 | input_data_json = json.load(f) 146 | # input_data_json = input_data_json[0:10] 147 | for i in range(num_works): 148 | endata_multiprocess.append([]) 149 | for i in range(len(input_data_json)): 150 | endata_multiprocess[i%num_works].append(input_data_json[i]) 151 | 152 | pool = multiprocessing.Pool(processes=num_works) 153 | tasks = [] 154 | for i in range(num_works): 155 | tasks.append((endata_multiprocess[i], save_file.replace(".json", "_{}_{}.json".format("vehicle", i)), "vehicle", args.model, args.api_key)) 156 | tasks.append((endata_multiprocess[i], save_file.replace(".json", "_{}_{}.json".format("pedestrian", i)), "pedestrian", args.model, args.api_key)) 157 | pool.starmap(classify_process, tasks) 158 | pool.close() 159 | pool.join() 160 | 161 | -------------------------------------------------------------------------------- /llava/serve/cli_final.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | import json 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | phase_number_map = { 18 | '0': 'prerecognition', 19 | '1': 'recognition', 20 | '2': 'judgement', 21 | '3': 'action', 22 | '4': 'avoidance' 23 | } 24 | 25 | def load_image(image_file): 26 | if image_file.startswith('http://') or image_file.startswith('https://'): 27 | response = requests.get(image_file) 28 | image = Image.open(BytesIO(response.content)).convert('RGB') 29 | else: 30 | image = Image.open(image_file).convert('RGB') 31 | return image 32 | 33 | 34 | def infer_once(input_file, local_image_path, best_view, tokenizer, model, processor, context_len, model_name, ori_conv_mode='chatml_direct'): 35 | # Model 36 | if "llama-2" in model_name.lower(): 37 | conv_mode = "llava_llama_2" 38 | elif "mistral" in model_name.lower(): 39 | conv_mode = "mistral_instruct" 40 | elif "v1.6-34b" in model_name.lower(): 41 | conv_mode = "chatml_direct" 42 | elif "v1" in model_name.lower(): 43 | conv_mode = "llava_v1" 44 | elif "mpt" in model_name.lower(): 45 | conv_mode = "mpt" 46 | else: 47 | conv_mode = "llava_v0" 48 | 49 | if ori_conv_mode is not None and conv_mode != ori_conv_mode: 50 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, ori_conv_mode, ori_conv_mode)) 51 | else: 52 | ori_conv_mode = 
conv_mode 53 | print('conv_model: ', ori_conv_mode, model.device) 54 | conv = conv_templates[ori_conv_mode].copy() 55 | ######## notice!!!!! 56 | # conv.system = """<|im_start|>system 57 | # Answer the questions. Notice that you should provide answers as detailed as possible.""" 58 | if "mpt" in model_name.lower(): 59 | roles = ('user', 'assistant') 60 | else: 61 | roles = conv.roles 62 | image = load_image(input_file) 63 | image_size = image.size 64 | # Similar operation in model_worker.py 65 | # local_image_path = input_file.replace('bbox_image_val_100', 'bbox_image_val_cropped_scale1_5_100') 66 | # local_image_path = input_file.replace('bbox_image_test_new', 'bbox_image_test_cropped_scale1_5') 67 | local_image = Image.open(local_image_path).convert('RGB') 68 | # Similar operation in model_worker.py 69 | image_tensor = process_images([image], processor, model.config, [local_image]) 70 | if type(image_tensor) is list: 71 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 72 | else: 73 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 74 | 75 | # best_view = best_view_test[input_file.split('/')[-2]] 76 | # if 'BDD' in best_view: 77 | stage = phase_number_map[str(os.path.basename(input_file).split('.')[0])] 78 | if 'vehicle' in best_view: 79 | # if 'vehicle' in best_view_test[input_file.replace('bbox_image_val_100', 'bbox_image_val_cropped')]: 80 | guide_prompt = f"This is an image in '{stage}' stage. " 81 | 82 | questions = ['This picture shows the relationship between the pedestrian in the green bounding box and the ego-vehicle. Describe the pedestrian in the green bounding box or the pedestrian closest to the vehicle based on age, height, clothing, line of sight, relative position to the vehicle, movement status, weather conditions and road environment.', 'This picture shows the relationship between the ego-vehicle and the pedestrian in the green bounding box. Describe the ego-vehicle based on the relative position to the pedestrian, driving status, weather conditions and road environment. And describe the age, height, clothing of the pedestrian.'] 83 | 84 | questions[1] = guide_prompt + '\n' + questions[1] 85 | questions[0] = guide_prompt + '\n' + questions[0] 86 | else: 87 | guide_prompt = f"This is an image in '{stage}' stage. " 88 | 89 | questions = ['This picture shows the relationship between the pedestrian in the green bounding box and the vehicle in the blue bounding box. Describe the pedestrian in the green bounding box or the pedestrian closest to the vehicle based on age, height, clothing, line of sight, relative position to the vehicle, movement status, weather conditions and road environment.', 'This picture shows the relationship between the vehicle in the blue bounding box and the pedestrian in the green bounding box. Describe the vehicle in the blue bounding box or the vehicle closest to the pedestrian based on the relative position to the pedestrian, driving status, weather conditions and road environment. 
And describe the age, height, clothing of the pedestrian.'] 90 | 91 | questions[1] = guide_prompt + '\n' + questions[1] 92 | questions[0] = guide_prompt + '\n' + questions[0] 93 | 94 | 95 | keys = ['caption_pedestrian', 'caption_vehicle'] 96 | keys = keys[::-1] 97 | 98 | questions = questions[::-1] 99 | 100 | # previous_prompt = [ 101 | # "Describe the age, height and clothing of the pedestrian in the green box.", 102 | # "Describe the position of the pedestrian in the green box relative to the vehicle.", 103 | # "Describe the line of sight and movement status of the pedestrian in the green box.", 104 | # "Describe the weather conditions and road environment.", 105 | # "Describe the position of the vehicle in the blue box relative to the pedestrian in the green box.", 106 | # "Describe the driving status of the vehicle in the blue box." 107 | # ] 108 | 109 | # questions = previous_prompt + questions 110 | # print(questions, best_view_test[input_file.replace('bbox_image_test_new', 'bbox_image_test_cropped')]) 111 | 112 | res = dict() 113 | for k in range(len(questions)): 114 | inp = questions[k] 115 | 116 | if image is not None: 117 | # first message 118 | if model.config.mm_use_im_start_end: 119 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 120 | else: 121 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 122 | conv.append_message(conv.roles[0], inp) 123 | image = None 124 | else: 125 | # later messages 126 | conv.append_message(conv.roles[0], inp) 127 | conv.append_message(conv.roles[1], None) 128 | prompt = conv.get_prompt() 129 | 130 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 131 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 132 | keywords = [stop_str] 133 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 134 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 135 | 136 | with torch.inference_mode(): 137 | output_ids = model.generate( 138 | input_ids, 139 | images=image_tensor, 140 | image_sizes=[image_size], 141 | do_sample=True, 142 | temperature=0.2, 143 | # num_beams=3, 144 | max_new_tokens=512, 145 | streamer=streamer, 146 | stopping_criteria=[stopping_criteria], 147 | use_cache=True) 148 | 149 | outputs = tokenizer.decode(output_ids[0]).strip() 150 | conv.messages[-1][-1] = outputs 151 | # if k >= 6: 152 | # res[keys[k - 6]] = outputs.replace('<|startoftext|> ', '').replace('<|im_end|>', '') 153 | res[keys[k]] = outputs.replace('<|startoftext|> ', '').replace('<|im_end|>', '') 154 | 155 | return res 156 | 157 | 158 | # if __name__ == "__main__": 159 | # parser = argparse.ArgumentParser() 160 | # parser.add_argument("--model-path", type=str, default="/mnt/workspace/workgroup/chengxiang/work_dirs/llava1_6-34b-aicity-0312-lora") 161 | # parser.add_argument("--model-base", type=str, default="/mnt/workspace/workgroup/chengxiang/models/llava-v1.6-34b") 162 | # parser.add_argument("--image-file", type=str, default='/mnt/workspace/workgroup/chenghao/video_analysis/dataset/BDD_PC_5k/bbox_visualization/test/video57_4.jpg') 163 | # parser.add_argument("--device", type=str, default="cuda") 164 | # parser.add_argument("--conv-mode", type=str, default="chatml_direct") 165 | # parser.add_argument("--temperature", type=float, default=0.2) 166 | # parser.add_argument("--max-new-tokens", type=int, default=512) 167 | # parser.add_argument("--load-8bit", default=False) 168 | # 
parser.add_argument("--load-4bit", default=False) 169 | # parser.add_argument("--debug", default=True) 170 | # args = parser.parse_args() 171 | # main(args) 172 | -------------------------------------------------------------------------------- /data_preprocess/draw_bbox_on_frame.py: -------------------------------------------------------------------------------- 1 | from decord import VideoReader 2 | import cv2 3 | import numpy as np 4 | import json 5 | import os 6 | from tqdm import tqdm 7 | from multiprocessing import Pool 8 | import copy 9 | import argparse 10 | 11 | phase_number_map = { 12 | '0': 'prerecognition', 13 | '1': 'recognition', 14 | '2': 'judgement', 15 | '3': 'action', 16 | '4': 'avoidance' 17 | } 18 | 19 | 20 | def extract_frames(video_path, frame_indices, original_frame_indices): 21 | vr = VideoReader(video_path) 22 | if frame_indices[-1] == len(vr): 23 | frame_indices[-1] = len(vr) - 1 24 | frames = {ori_idx: vr[frame_idx].asnumpy() for frame_idx, ori_idx in zip(frame_indices, original_frame_indices)} 25 | return frames 26 | 27 | 28 | def draw_and_save_bboxes(key, frames, ped_bboxes, veh_bboxes, phase_numbers, phase_number_map): 29 | for frame_id, frame_np in frames.items(): 30 | frame = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR) 31 | if str(frame_id) in ped_bboxes: 32 | bbox = ped_bboxes[str(frame_id)] 33 | xmin, ymin, width, height = bbox 34 | xmax = xmin + width 35 | ymax = ymin + height 36 | cv2.rectangle(frame, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color=(0, 255, 0), thickness=4) 37 | if str(frame_id) in veh_bboxes: 38 | bbox = veh_bboxes[str(frame_id)] 39 | xmin, ymin, width, height = bbox 40 | xmax = xmin + width 41 | ymax = ymin + height 42 | cv2.rectangle(frame, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color=(255, 0, 0), thickness=4) 43 | 44 | phase_number = phase_numbers.get(str(frame_id), "") 45 | if str(phase_number): 46 | if 'BDD' in key: 47 | file_name = key.replace('.mp4', f'_{phase_number_map[str(phase_number)]}.jpg').replace('/videos', '/bbox_global') 48 | dirname = os.path.dirname(file_name) 49 | os.makedirs(dirname, exist_ok=True) 50 | else: 51 | key = key.replace('.mp4', '/').replace('/videos', '/bbox_global') 52 | os.makedirs(key, exist_ok=True) 53 | file_name = f"{key}{phase_number}_{phase_number_map[str(phase_number)]}.jpg" 54 | 55 | cv2.imwrite(file_name, frame) 56 | 57 | 58 | def enlarge_bbox(bbox, scale=1.2): 59 | xmin, ymin, width, height = bbox 60 | center_x, center_y = xmin + width / 2, ymin + height / 2 61 | 62 | new_width = width * scale 63 | new_height = height * scale 64 | 65 | new_xmin = center_x - new_width / 2 66 | new_ymin = center_y - new_height / 2 67 | 68 | return new_xmin, new_ymin, new_width, new_height 69 | 70 | 71 | def enlarge_bbox_square(bbox, scale=1.2): 72 | xmin, ymin, width, height = bbox 73 | center_x, center_y = xmin + width / 2, ymin + height / 2 74 | 75 | new_width = width * scale 76 | new_height = height * scale 77 | 78 | new_height, new_width = max(new_width, new_height), max(new_width, new_height) # Not used when draw bbox 79 | 80 | new_xmin = center_x - new_width / 2 81 | new_ymin = center_y - new_height / 2 82 | 83 | return new_xmin, new_ymin, new_width, new_height 84 | 85 | 86 | def calculate_combined_bbox(bbox1, bbox2): 87 | xmin = min(bbox1[0], bbox2[0]) 88 | ymin = min(bbox1[1], bbox2[1]) 89 | xmax = max(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2]) 90 | ymax = max(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3]) 91 | 92 | return xmin, ymin, xmax - xmin, ymax - ymin 93 | 94 | 95 | def 
constrain_bbox_within_frame(bbox, frame_shape): 96 | xmin, ymin, xmax, ymax = bbox 97 | xmin = max(0, int(xmin)) 98 | ymin = max(0, int(ymin)) 99 | xmax = min(frame_shape[1], int(xmax)) 100 | ymax = min(frame_shape[0], int(ymax)) 101 | return xmin, ymin, xmax, ymax 102 | 103 | 104 | def draw_and_save_bboxes_scale_version(key, frames, ped_bboxes, veh_bboxes, phase_numbers, phase_number_map, scale=1.5): 105 | for frame_id, frame_np in frames.items(): 106 | frame = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR) 107 | combined_bbox = None 108 | 109 | # Enlarge and draw pedestrian bbox 110 | if str(frame_id) in ped_bboxes: 111 | bbox = enlarge_bbox(ped_bboxes[str(frame_id)]) 112 | xmin, ymin, width, height = bbox 113 | xmax = xmin + width 114 | ymax = ymin + height 115 | xmin, ymin, xmax, ymax = constrain_bbox_within_frame((xmin, ymin, xmax, ymax), frame.shape) 116 | combined_bbox = (xmin, ymin, xmax - xmin, ymax - ymin) 117 | cv2.rectangle(frame, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color=(0, 255, 0), thickness=3) 118 | 119 | # Enlarge and draw vehicle bbox 120 | if str(frame_id) in veh_bboxes: 121 | bbox = enlarge_bbox(veh_bboxes[str(frame_id)]) 122 | xmin, ymin, width, height = bbox 123 | xmax = xmin + width 124 | ymax = ymin + height 125 | xmin, ymin, xmax, ymax = constrain_bbox_within_frame((xmin, ymin, xmax, ymax), frame.shape) 126 | if combined_bbox is not None: 127 | combined_bbox = calculate_combined_bbox(combined_bbox, (xmin, ymin, xmax - xmin, ymax - ymin)) 128 | else: 129 | combined_bbox = (xmin, ymin, xmax - xmin, ymax - ymin) 130 | cv2.rectangle(frame, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color=(255, 0, 0), thickness=3) 131 | 132 | # Enlarge the combined bbox 133 | if combined_bbox is not None: 134 | min_area = 0.1 135 | max_area = 0.6 136 | area_ratio = (combined_bbox[-2] * combined_bbox[-1]) / (frame.shape[0] * frame.shape[1]) 137 | try: 138 | if combined_bbox[-2] / combined_bbox[-1] > 4 or combined_bbox[-1] / combined_bbox[-2] > 4: 139 | width_ratio, height_ratio = combined_bbox[-2] / frame.shape[1], combined_bbox[-1] / frame.shape[0] 140 | area_ratio = max(width_ratio, height_ratio) 141 | except: 142 | print(f"[WARRNING]: Zero detected: {combined_bbox}") 143 | 144 | min_scale = 1.0 145 | max_scale = 3.0 146 | 147 | ratio = min(max_area, max(min_area, area_ratio)) 148 | # print(ratio) 149 | # scale = -4 * ratio + 3.4 150 | 151 | combined_bbox = enlarge_bbox_square(combined_bbox, scale=scale) 152 | xmin, ymin, width, height = combined_bbox 153 | xmax, ymax = int(xmin + width), int(ymin + height) 154 | xmin, ymin = int(xmin), int(ymin) 155 | xmin, ymin, xmax, ymax = constrain_bbox_within_frame((xmin, ymin, xmax, ymax), frame.shape) 156 | cropped_frame = frame[ymin:ymax, xmin:xmax] 157 | else: 158 | cropped_frame = frame 159 | 160 | 161 | # Get the corresponding phase number 162 | if str(frame_id) in phase_numbers: 163 | phase_number = phase_numbers[str(frame_id)] 164 | else: 165 | phase_number = '' 166 | 167 | if str(phase_number): 168 | if 'BDD' in key: 169 | file_name = key.replace('.mp4', f'_{phase_number_map[str(phase_number)]}.jpg').replace('/videos', '/bbox_local') 170 | dirname = os.path.dirname(file_name) 171 | os.makedirs(dirname, exist_ok=True) 172 | else: 173 | key = key.replace('.mp4', '/').replace('/videos', '/bbox_local') 174 | os.makedirs(key, exist_ok=True) 175 | file_name = f"{key}{phase_number}_{phase_number_map[str(phase_number)]}.jpg" 176 | 177 | if cropped_frame.size > 0: 178 | cv2.imwrite(file_name, cropped_frame) 179 | else: 180 | 
print(cropped_frame.shape) 181 | print(f"Empty cropped frame for frame ID {key} {frame_id} {ped_bboxes[str(frame_id)]} {combined_bbox}. Skipping save.") 182 | 183 | 184 | def process_video(args): 185 | video_path, data, phase_number_map, scale = args 186 | frame_indices = list(map(int, data["phase_number"].keys())) 187 | if len(frame_indices) == 0: 188 | return 189 | frame_indices_process = copy.deepcopy(frame_indices) 190 | if 'fps' in data: 191 | if float(data['fps']) > 40.0: 192 | for i in range(len(frame_indices)): 193 | frame_indices_process[i] = frame_indices_process[i] // 2 194 | frames = extract_frames(video_path, frame_indices_process, frame_indices) 195 | draw_and_save_bboxes( 196 | video_path, 197 | frames, 198 | data["ped_bboxes"], 199 | data["veh_bboxes"], 200 | data["phase_number"], 201 | phase_number_map, 202 | ) 203 | draw_and_save_bboxes_scale_version( 204 | video_path, 205 | frames, 206 | data["ped_bboxes"], 207 | data["veh_bboxes"], 208 | data["phase_number"], 209 | phase_number_map, 210 | scale 211 | ) 212 | 213 | 214 | if __name__ == '__main__': 215 | parser = argparse.ArgumentParser() 216 | parser.add_argument('--anno', type=str, help='File with bbox anno') 217 | parser.add_argument('--worker', type=int, default=1, help='process num (CPU count)') 218 | parser.add_argument('--scale', type=float, default=1.5, help='scale up coefficient') 219 | args = parser.parse_args() 220 | anno = json.load(open(args.anno)) 221 | num_processes = args.worker 222 | with Pool(processes=num_processes) as pool: 223 | jobs = [] 224 | for video_path, data in tqdm(anno.items()): 225 | job = (video_path, data, phase_number_map, args.scale) 226 | jobs.append(job) 227 | results = list(tqdm(pool.imap(process_video, jobs), total=len(jobs))) -------------------------------------------------------------------------------- /llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
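# load_pretrained_model() below assembles the tokenizer, the LLaVA language model
# (plain checkpoints, LoRA checkpoints merged against a base model, projector-only
# checkpoints, and the MPT / Mistral variants), the CLIP vision tower and its image
# processor. Optional 8-/4-bit quantization and flash-attention-2 are enabled via the
# corresponding keyword arguments. It returns (tokenizer, model, image_processor, context_len).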
14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from llava.model import * 23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", zero_stage=2, use_flash_attn2=False, **kwargs): 27 | kwargs = {"device_map": device, "use_safetensors": True, **kwargs} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | 45 | if use_flash_attn2: 46 | kwargs['attn_implementation'] = 'flash_attention_2' 47 | 48 | #import pdb; pdb.set_trace() 49 | 50 | if 'llava' in model_name.lower(): 51 | # Load LLaVA model 52 | if 'lora' in model_name.lower() and model_base is None: 53 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 54 | if 'lora' in model_name.lower() and model_base is not None: 55 | 56 | from llava.model.language_model.llava_llama import LlavaConfig 57 | lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) 58 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 59 | print('Loading LLaVA from base model...') 60 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 61 | 62 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 63 | if model.lm_head.weight.shape[0] != token_num: 64 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 65 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 66 | 67 | print('Loading additional LLaVA weights...') 68 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 69 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 70 | else: 71 | # this is probably from HF Hub 72 | from huggingface_hub import hf_hub_download 73 | def load_from_hf(repo_id, filename, subfolder=None): 74 | cache_file = hf_hub_download( 75 | repo_id=repo_id, 76 | filename=filename, 77 | subfolder=subfolder) 78 | return torch.load(cache_file, map_location='cpu') 79 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 80 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 81 | if any(k.startswith('model.model.') for k in non_lora_trainables): 82 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 83 | model.load_state_dict(non_lora_trainables, strict=False) 84 | 85 | from peft import PeftModel 86 | print('Loading LoRA weights...') 87 | model = PeftModel.from_pretrained(model, model_path) 88 | 89 | print('Merging 
LoRA weights...') 90 | model = model.merge_and_unload() 91 | print('Model is loaded...') 92 | elif model_base is not None: 93 | # this may be mm projector only 94 | print('Loading LLaVA from base model...') 95 | if 'mpt' in model_name.lower(): 96 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 97 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 98 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 99 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 100 | model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 101 | else: 102 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 103 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 104 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 105 | 106 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 107 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 108 | model.load_state_dict(mm_projector_weights, strict=False) 109 | else: 110 | if 'mpt' in model_name.lower(): 111 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 112 | model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 113 | elif 'mistral' in model_name.lower(): 114 | tokenizer = AutoTokenizer.from_pretrained(model_path) 115 | model = LlavaMistralForCausalLM.from_pretrained( 116 | model_path, 117 | low_cpu_mem_usage=True, 118 | **kwargs 119 | ) 120 | else: 121 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 122 | if zero_stage > 2: 123 | del kwargs['device_map'] 124 | del kwargs['use_safetensors'] 125 | print(kwargs) 126 | model = LlavaLlamaForCausalLM.from_pretrained( 127 | model_path, **kwargs 128 | ) 129 | else: 130 | model = LlavaLlamaForCausalLM.from_pretrained( 131 | model_path, 132 | low_cpu_mem_usage=True, 133 | **kwargs 134 | ) 135 | else: 136 | # Load language model 137 | if model_base is not None: 138 | # PEFT model 139 | from peft import PeftModel 140 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 141 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 142 | print(f"Loading LoRA weights from {model_path}") 143 | model = PeftModel.from_pretrained(model, model_path) 144 | print(f"Merging weights") 145 | model = model.merge_and_unload() 146 | print('Convert to FP16...') 147 | model.to(torch.float16) 148 | else: 149 | use_fast = False 150 | if 'mpt' in model_name.lower(): 151 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 152 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 153 | else: 154 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 155 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 156 | 157 | image_processor = None 158 | 159 | if 'llava' in model_name.lower(): 160 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 161 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 162 | if mm_use_im_patch_token: 163 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 164 | if mm_use_im_start_end: 165 | 
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 166 | model.resize_token_embeddings(len(tokenizer)) 167 | 168 | vision_tower = model.get_vision_tower() 169 | if not vision_tower.is_loaded: 170 | vision_tower.load_model(device_map=device_map) 171 | if device_map != 'auto': 172 | vision_tower.to(device=device_map, dtype=torch.float16) 173 | image_processor = vision_tower.image_processor 174 | 175 | if hasattr(model.config, "max_sequence_length"): 176 | context_len = model.config.max_sequence_length 177 | else: 178 | context_len = 2048 179 | 180 | return tokenizer, model, image_processor, context_len 181 | -------------------------------------------------------------------------------- /data_preprocess/transform_llava_format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--root', type=str, default='/mnt/data/AICITY2024/', help='data root path') 8 | parser.add_argument('--split', type=str, default='train') 9 | parser.add_argument('--save-folder', type=str, default='processed_anno', help='dirname for saving json file') 10 | parser.add_argument('--wts-global-image-path', type=str, required=True, help='root path for wts global images') 11 | parser.add_argument('--bdd-global-image-path', type=str, required=True, help='root path for bdd global images') 12 | 13 | args = parser.parse_args() 14 | 15 | root = args.root 16 | 17 | phrase_number_map = { 18 | '0': 'prerecognition', 19 | '1': 'recognition', 20 | '2': 'judgement', 21 | '3': 'action', 22 | '4': 'avoidance' 23 | } 24 | number_phrase_map = {v: k for k, v in phrase_number_map.items()} 25 | 26 | camera_path_mapping = dict() 27 | 28 | wts_anno_path = os.path.join(root, 'WTS/annotations/caption', args.split) 29 | bdd_anno_path = os.path.join(root, 'BDD_PC_5k/annotations/caption', args.split) 30 | 31 | train_samples = list() 32 | 33 | overhead = 'overhead_view' 34 | vehicle = 'vehicle_view' 35 | 36 | for item in os.listdir(wts_anno_path): 37 | overhead_flag, vehicle_flag = True, True 38 | try: 39 | overhead_view = json.load(open(f'{wts_anno_path}/{item}/{overhead}/{item}_caption.json')) 40 | except: 41 | overhead_flag = False 42 | try: 43 | vehicle_view = json.load(open(f'{wts_anno_path}/{item}/{vehicle}/{item}_caption.json')) 44 | except: 45 | vehicle_flag = False 46 | sample_id = item 47 | 48 | if overhead_flag: 49 | for event in overhead_view['event_phase']: 50 | cur_data = dict() 51 | cur_data['id'] = sample_id 52 | cur_data['segment'] = phrase_number_map[event['labels'][0]] 53 | cur_data['view'] = 'overhead' 54 | cur_data['start_time'] = event['start_time'] 55 | cur_data['end_time'] = event['end_time'] 56 | cur_data['conversations'] = list() 57 | 58 | cur_data['conversations'].append({ 59 | 'from': 'human', 60 | 'value': '\nPlease describe the interested pedestrian in the video.' 61 | }) 62 | 63 | cur_data['conversations'].append({ 64 | 'from': 'gpt', 65 | 'value': event['caption_pedestrian'] 66 | }) 67 | 68 | cur_data['conversations'].append({ 69 | 'from': 'human', 70 | 'value': 'Please describe the interested vehicle in the video.' 
71 | }) 72 | 73 | cur_data['conversations'].append({ 74 | 'from': 'gpt', 75 | 'value': event['caption_vehicle'] 76 | }) 77 | 78 | for image in overhead_view['overhead_videos']: 79 | cur_data['image'] = image 80 | train_samples.append(copy.deepcopy(cur_data)) 81 | 82 | if vehicle_flag: 83 | for event in vehicle_view['event_phase']: 84 | cur_data = dict() 85 | cur_data['id'] = sample_id 86 | cur_data['segment'] = phrase_number_map[event['labels'][0]] 87 | cur_data['view'] = 'vehicle' 88 | cur_data['start_time'] = event['start_time'] 89 | cur_data['end_time'] = event['end_time'] 90 | cur_data['conversations'] = list() 91 | 92 | cur_data['conversations'].append({ 93 | 'from': 'human', 94 | 'value': '\nPlease describe the interested pedestrian in the video.' 95 | }) 96 | 97 | cur_data['conversations'].append({ 98 | 'from': 'gpt', 99 | 'value': event['caption_pedestrian'] 100 | }) 101 | 102 | cur_data['conversations'].append({ 103 | 'from': 'human', 104 | 'value': 'Please describe the interested vehicle in the video.' 105 | }) 106 | 107 | cur_data['conversations'].append({ 108 | 'from': 'gpt', 109 | 'value': event['caption_vehicle'] 110 | }) 111 | 112 | cur_data['image'] = vehicle_view['vehicle_view'] 113 | train_samples.append(cur_data) 114 | 115 | for item in os.listdir(f'{wts_anno_path}/normal_trimmed'): 116 | overhead_flag, vehicle_flag = True, True 117 | try: 118 | overhead_view = json.load(open(f'{wts_anno_path}/normal_trimmed/{item}/{overhead}/{item}_caption.json')) 119 | except: 120 | overhead_flag = False 121 | try: 122 | vehicle_view = json.load(open(f'{wts_anno_path}/normal_trimmed/{item}/{vehicle}/{item}_caption.json')) 123 | except: 124 | vehicle_flag = False 125 | sample_id = item 126 | 127 | if overhead_flag: 128 | for event in overhead_view['event_phase']: 129 | cur_data = dict() 130 | cur_data['id'] = sample_id 131 | cur_data['segment'] = phrase_number_map[event['labels'][0]] 132 | cur_data['view'] = 'overhead' 133 | cur_data['start_time'] = event['start_time'] 134 | cur_data['end_time'] = event['end_time'] 135 | cur_data['conversations'] = list() 136 | 137 | cur_data['conversations'].append({ 138 | 'from': 'human', 139 | 'value': '\nPlease describe the interested pedestrian in the video.' 140 | }) 141 | 142 | cur_data['conversations'].append({ 143 | 'from': 'gpt', 144 | 'value': event['caption_pedestrian'] 145 | }) 146 | 147 | cur_data['conversations'].append({ 148 | 'from': 'human', 149 | 'value': 'Please describe the interested vehicle in the video.' 150 | }) 151 | 152 | cur_data['conversations'].append({ 153 | 'from': 'gpt', 154 | 'value': event['caption_vehicle'] 155 | }) 156 | 157 | for image in overhead_view['overhead_videos']: 158 | cur_data['image'] = image 159 | train_samples.append(copy.deepcopy(cur_data)) 160 | 161 | if vehicle_flag: 162 | for event in vehicle_view['event_phase']: 163 | cur_data = dict() 164 | cur_data['id'] = sample_id 165 | cur_data['segment'] = phrase_number_map[event['labels'][0]] 166 | cur_data['view'] = 'vehicle' 167 | cur_data['start_time'] = event['start_time'] 168 | cur_data['end_time'] = event['end_time'] 169 | cur_data['conversations'] = list() 170 | 171 | cur_data['conversations'].append({ 172 | 'from': 'human', 173 | 'value': '\nPlease describe the interested pedestrian in the video.' 
174 | }) 175 | 176 | cur_data['conversations'].append({ 177 | 'from': 'gpt', 178 | 'value': event['caption_pedestrian'] 179 | }) 180 | 181 | cur_data['conversations'].append({ 182 | 'from': 'human', 183 | 'value': 'Please describe the interested vehicle in the video.' 184 | }) 185 | 186 | cur_data['conversations'].append({ 187 | 'from': 'gpt', 188 | 'value': event['caption_vehicle'] 189 | }) 190 | 191 | cur_data['image'] = vehicle_view['vehicle_view'] 192 | train_samples.append(cur_data) 193 | 194 | for item in os.listdir(bdd_anno_path): 195 | captions = json.load(open(f'{bdd_anno_path}/{item}')) 196 | sample_id = captions['id'] 197 | for event in captions['event_phase']: 198 | cur_data = dict() 199 | cur_data['id'] = sample_id 200 | cur_data['segment'] = event['labels'][0] 201 | cur_data['start_time'] = event['start_time'] 202 | cur_data['end_time'] = event['end_time'] 203 | cur_data['conversations'] = list() 204 | 205 | cur_data['conversations'].append({ 206 | 'from': 'human', 207 | 'value': '\nPlease describe the interested pedestrian in the video.' 208 | }) 209 | 210 | cur_data['conversations'].append({ 211 | 'from': 'gpt', 212 | 'value': event['caption_pedestrian'] 213 | }) 214 | 215 | cur_data['conversations'].append({ 216 | 'from': 'human', 217 | 'value': 'Please describe the interested vehicle in the video.' 218 | }) 219 | 220 | cur_data['conversations'].append({ 221 | 'from': 'gpt', 222 | 'value': event['caption_vehicle'] 223 | }) 224 | 225 | cur_data['image'] = captions['video_name'] 226 | camera_path_mapping[cur_data['image'].replace('.mp4', '')] = os.path.join(args.bdd_global_image_path, args.split) 227 | train_samples.append(cur_data) 228 | 229 | 230 | global_image_path = os.path.join(args.wts_global_image_path, args.split) 231 | for event in os.listdir(global_image_path): 232 | if 'normal_trimmed' in event: 233 | continue 234 | for view in os.listdir(os.path.join(global_image_path, event)): 235 | parent_path = os.path.join(global_image_path, event, view) 236 | for camera in os.listdir(parent_path): 237 | camera_path_mapping[camera] = os.path.join(parent_path, camera) 238 | 239 | 240 | for event in os.listdir(os.path.join(global_image_path, 'normal_trimmed')): 241 | for view in os.listdir(os.path.join(global_image_path, 'normal_trimmed', event)): 242 | parent_path = os.path.join(global_image_path, 'normal_trimmed', event, view) 243 | for camera in os.listdir(parent_path): 244 | camera_path_mapping[camera] = os.path.join(parent_path, camera) 245 | 246 | 247 | reserved_train_samples = list() 248 | for item in train_samples: 249 | image = item['image'].replace('.mp4', '') 250 | segment = item['segment'] 251 | 252 | if 'video' in image: 253 | train_image_name = f'{image}_{segment}.jpg' 254 | else: 255 | train_image_name = f'{number_phrase_map[segment]}_{segment}.jpg' 256 | 257 | if image in camera_path_mapping: 258 | item['image'] = os.path.join(camera_path_mapping[image], train_image_name) 259 | if os.path.exists(item['image']): 260 | item['image'] = item['image'].replace('./data/', '') 261 | reserved_train_samples.append(item) 262 | 263 | 264 | os.makedirs(args.save_folder, exist_ok=True) 265 | with open(os.path.join(args.save_folder, f'wts_bdd_{args.split}.json'), 'w+') as f: 266 | f.write(json.dumps(reserved_train_samples, indent=4)) -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import 
base64 4 | import torch 5 | import math 6 | import ast 7 | 8 | from transformers import StoppingCriteria 9 | from llava.constants import IMAGE_TOKEN_INDEX 10 | 11 | 12 | def select_best_resolution(original_size, possible_resolutions): 13 | """ 14 | Selects the best resolution from a list of possible resolutions based on the original size. 15 | 16 | Args: 17 | original_size (tuple): The original size of the image in the format (width, height). 18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 19 | 20 | Returns: 21 | tuple: The best fit resolution in the format (width, height). 22 | """ 23 | original_width, original_height = original_size 24 | best_fit = None 25 | max_effective_resolution = 0 26 | min_wasted_resolution = float('inf') 27 | 28 | for width, height in possible_resolutions: 29 | scale = min(width / original_width, height / original_height) 30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 32 | wasted_resolution = (width * height) - effective_resolution 33 | 34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 35 | max_effective_resolution = effective_resolution 36 | min_wasted_resolution = wasted_resolution 37 | best_fit = (width, height) 38 | 39 | return best_fit 40 | 41 | 42 | def resize_and_pad_image(image, target_resolution): 43 | """ 44 | Resize and pad an image to a target resolution while maintaining aspect ratio. 45 | 46 | Args: 47 | image (PIL.Image.Image): The input image. 48 | target_resolution (tuple): The target resolution (width, height) of the image. 49 | 50 | Returns: 51 | PIL.Image.Image: The resized and padded image. 52 | """ 53 | original_width, original_height = image.size 54 | target_width, target_height = target_resolution 55 | 56 | scale_w = target_width / original_width 57 | scale_h = target_height / original_height 58 | 59 | if scale_w < scale_h: 60 | new_width = target_width 61 | new_height = min(math.ceil(original_height * scale_w), target_height) 62 | else: 63 | new_height = target_height 64 | new_width = min(math.ceil(original_width * scale_h), target_width) 65 | 66 | # Resize the image 67 | resized_image = image.resize((new_width, new_height)) 68 | 69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | paste_x = (target_width - new_width) // 2 71 | paste_y = (target_height - new_height) // 2 72 | new_image.paste(resized_image, (paste_x, paste_y)) 73 | 74 | return new_image 75 | 76 | 77 | def divide_to_patches(image, patch_size): 78 | """ 79 | Divides an image into patches of a specified size. 80 | 81 | Args: 82 | image (PIL.Image.Image): The input image. 83 | patch_size (int): The size of each patch. 84 | 85 | Returns: 86 | list: A list of PIL.Image.Image objects representing the patches. 87 | """ 88 | patches = [] 89 | width, height = image.size 90 | for i in range(0, height, patch_size): 91 | for j in range(0, width, patch_size): 92 | box = (j, i, j + patch_size, i + patch_size) 93 | patch = image.crop(box) 94 | patches.append(patch) 95 | 96 | return patches 97 | 98 | 99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 100 | """ 101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 
102 | 103 | Args: 104 | image_size (tuple): The size of the input image in the format (width, height). 105 | grid_pinpoints (str): A string representation of a list of possible resolutions. 106 | patch_size (int): The size of each image patch. 107 | 108 | Returns: 109 | tuple: The shape of the image patch grid in the format (width, height). 110 | """ 111 | if type(grid_pinpoints) is list: 112 | possible_resolutions = grid_pinpoints 113 | else: 114 | possible_resolutions = ast.literal_eval(grid_pinpoints) 115 | width, height = select_best_resolution(image_size, possible_resolutions) 116 | return width // patch_size, height // patch_size 117 | 118 | 119 | def process_anyres_image(image, processor, grid_pinpoints): 120 | """ 121 | Process an image with variable resolutions. 122 | 123 | Args: 124 | image (PIL.Image.Image): The input image to be processed. 125 | processor: The image processor object. 126 | grid_pinpoints (str): A string representation of a list of possible resolutions. 127 | 128 | Returns: 129 | torch.Tensor: A tensor containing the processed image patches. 130 | """ 131 | if type(grid_pinpoints) is list: 132 | possible_resolutions = grid_pinpoints 133 | else: 134 | possible_resolutions = ast.literal_eval(grid_pinpoints) 135 | best_resolution = select_best_resolution(image.size, possible_resolutions) 136 | image_padded = resize_and_pad_image(image, best_resolution) 137 | 138 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 139 | 140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 141 | 142 | image_patches = [image_original_resize] + patches 143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 144 | for image_patch in image_patches] 145 | return torch.stack(image_patches, dim=0) 146 | 147 | def process_anyres_image_aicity(image, processor, local_image): 148 | """ 149 | Process an image with variable resolutions. 150 | 151 | Args: 152 | image (PIL.Image.Image): The input image to be processed. 153 | processor: The image processor object. 154 | grid_pinpoints (str): A string representation of a list of possible resolutions. 155 | 156 | Returns: 157 | torch.Tensor: A tensor containing the processed image patches. 
158 |     """
159 |     image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
160 |     local_image_resize = local_image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
161 |
162 |     # image_patches = [image_original_resize] + [local_image_resize]
163 |     image_patches = [local_image_resize] + [image_original_resize]
164 |     image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
165 |                      for image_patch in image_patches]
166 |     return torch.stack(image_patches, dim=0)
167 |
168 |
169 | def load_image_from_base64(image):
170 |     return Image.open(BytesIO(base64.b64decode(image)))
171 |
172 |
173 | def expand2square(pil_img, background_color):
174 |     width, height = pil_img.size
175 |     if width == height:
176 |         return pil_img
177 |     elif width > height:
178 |         result = Image.new(pil_img.mode, (width, width), background_color)
179 |         result.paste(pil_img, (0, (width - height) // 2))
180 |         return result
181 |     else:
182 |         result = Image.new(pil_img.mode, (height, height), background_color)
183 |         result.paste(pil_img, ((height - width) // 2, 0))
184 |         return result
185 |
186 |
187 | def process_images(images, image_processor, model_cfg, local_image=None):
188 |     image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
189 |     new_images = []
190 |     if image_aspect_ratio == 'pad':
191 |         for image in images:
192 |             image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
193 |             image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
194 |             new_images.append(image)
195 |     elif image_aspect_ratio == "anyres":
196 |         for image in images:
197 |             image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
198 |             new_images.append(image)
199 |     elif image_aspect_ratio == "bigsmall":
200 |         for i, image in enumerate(images):
201 |             image = process_anyres_image_aicity(image, image_processor, local_image[i])
202 |             new_images.append(image)
203 |     else:
204 |         return image_processor(images, return_tensors='pt')['pixel_values']
205 |     if all(x.shape == new_images[0].shape for x in new_images):
206 |         new_images = torch.stack(new_images, dim=0)
207 |     return new_images
208 |
209 |
210 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
211 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
212 |
213 |     def insert_separator(X, sep):
214 |         return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
215 |
216 |     input_ids = []
217 |     offset = 0
218 |     if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
219 |         offset = 1
220 |         input_ids.append(prompt_chunks[0][0])
221 |
222 |     for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
223 |         input_ids.extend(x[offset:])
224 |
225 |     if return_tensors is not None:
226 |         if return_tensors == 'pt':
227 |             return torch.tensor(input_ids, dtype=torch.long)
228 |         raise ValueError(f'Unsupported tensor type: {return_tensors}')
229 |     return input_ids
230 |
231 |
232 | def get_model_name_from_path(model_path):
233 |     model_path = model_path.strip("/")
234 |     model_paths = model_path.split("/")
235 |     if model_paths[-1].startswith('checkpoint-'):
236 |         return model_paths[-2] + "_" + model_paths[-1]
237 |     else:
238 |         return model_paths[-1]
239 |
240 | class KeywordsStoppingCriteria(StoppingCriteria):
241 |     def __init__(self, keywords, tokenizer, input_ids):
242 |
self.keywords = keywords 243 | self.keyword_ids = [] 244 | self.max_keyword_len = 0 245 | for keyword in keywords: 246 | cur_keyword_ids = tokenizer(keyword).input_ids 247 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 248 | cur_keyword_ids = cur_keyword_ids[1:] 249 | if len(cur_keyword_ids) > self.max_keyword_len: 250 | self.max_keyword_len = len(cur_keyword_ids) 251 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 252 | self.tokenizer = tokenizer 253 | self.start_len = input_ids.shape[1] 254 | 255 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 256 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 257 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 258 | for keyword_id in self.keyword_ids: 259 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 260 | if torch.equal(truncated_output_ids, keyword_id): 261 | return True 262 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 263 | for keyword in self.keywords: 264 | if keyword in outputs: 265 | return True 266 | return False 267 | 268 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 269 | outputs = [] 270 | for i in range(output_ids.shape[0]): 271 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 272 | return all(outputs) 273 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /data_preprocess/shortQA_merge.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tqdm 3 | import copy 4 | 5 | # Short QA Construction 6 | # 2)the sentences within the same dimension are concatenated to form a cohesive segment. 7 | # 3)for short QA, some filtering and sampling operations will be carried out to improve data quality and distribution. 
8 | # 4) use high-quality prompt after "3.2.3 Textual Prompt Engineering" to construct fullQA dataset.
9 |
10 | def load_jsonl(filename):
11 |     with open(filename, "r", encoding="utf-8") as f:
12 |         return [json.loads(l.strip("\n")) for l in f.readlines()]
13 |
14 | def shortQA_merge_pedestrian():
15 |     input_file_list = ["./processed_anno/caption_split/caption_split_pedestrian_0.json"
16 |                        ]
17 |     save_file = "./processed_anno/caption_split/caption_split_pedestrian_merge.json"
18 |     question_list = [
19 |         "Describe the age, height and clothing of the pedestrian in the green box.",
20 |         "Describe the position of the pedestrian in the green box relative to the vehicle.",
21 |         "Describe the line of sight and movement status of the pedestrian in the green box.",
22 |         "Describe the weather conditions and road environment."
23 |     ]
24 |
25 |     result = []
26 |
27 |     for input_file in input_file_list:
28 |         data_list = load_jsonl(input_file)
29 |         for data in data_list:
30 |             id = data['id']
31 |             conversations = data['conversations']
32 |             pedestrian_caption_ori = conversations[1]['value']
33 |
34 |             # split the description into single sentences
35 |             pedestrian_caption = pedestrian_caption_ori + ' '
36 |             pedestrian_caption_list = pedestrian_caption.split('. ')
37 |             pedestrian_caption_list = [sentence.strip() + '.' for sentence in pedestrian_caption_list if sentence]
38 |
39 |             pedestrian_res = [[], [], [], []]
40 |             pedestrian_response = data['pedestrian_response']
41 |             pedestrian_response_split = pedestrian_response.split('\n')
42 |             for pedestrian_response_data in pedestrian_response_split:
43 |                 pedestrian_num, pedestrian_class = pedestrian_response_data.split('.')
44 |                 pedestrian_num = int(pedestrian_num) - 1
45 |                 pedestrian_class = pedestrian_class.lower()
46 |                 if pedestrian_num >= len(pedestrian_caption_list):
47 |                     continue
48 |                 if 'a' in pedestrian_class:
49 |                     pedestrian_res[0].append(pedestrian_caption_list[pedestrian_num])
50 |                 elif 'b' in pedestrian_class:
51 |                     pedestrian_res[1].append(pedestrian_caption_list[pedestrian_num])
52 |                 elif 'c' in pedestrian_class:
53 |                     pedestrian_res[2].append(pedestrian_caption_list[pedestrian_num])
54 |                 elif 'd' in pedestrian_class:
55 |                     pedestrian_res[3].append(pedestrian_caption_list[pedestrian_num])
56 |
57 |             for i in range(len(pedestrian_res)):
58 |                 if len(pedestrian_res[i]) == 0:
59 |                     continue
60 |                 conversations_new = []
61 |                 pedestrian_res[i] = ' '.join(pedestrian_res[i])
62 |                 conversations_new.append({
63 |                     "from": "human",
64 |                     "value": "<image>\n" + question_list[i]
65 |                 })
66 |                 conversations_new.append({
67 |                     "from": "gpt",
68 |                     "value": pedestrian_res[i]
69 |                 })
70 |                 data['conversations'] = conversations_new
71 |                 data['full_flag'] = 0
72 |                 data_new = copy.deepcopy(data)
73 |                 result.append(data_new)
74 |
75 |             conversations_new = []
76 |             conversations_new.append({
77 |                 "from": "human",
78 |                 "value": "<image>\nThis picture shows the relationship between the pedestrian in the green box and the vehicle in the blue box. Describe the pedestrian in the green box or the pedestrian closest to the vehicle based on age, height, clothing, line of sight, relative position to the vehicle, movement status, weather conditions and road environment."
79 |             })
80 |             conversations_new.append({
81 |                 "from": "gpt",
82 |                 "value": pedestrian_caption_ori
83 |             })
84 |             data['conversations'] = conversations_new
85 |             data['full_flag'] = 1
86 |             data_new = copy.deepcopy(data)
87 |             result.append(data_new)
88 |     print("len of shortQA_merge_pedestrian data is ", len(result))
89 |     with open(save_file, "w", encoding="utf-8") as f:
90 |         json.dump(result, f, ensure_ascii=False, indent=4)
91 |
92 | def shortQA_merge_vehicle():
93 |     input_file_list = ["./processed_anno/caption_split/caption_split_vehicle_0.json"]
94 |     save_file = "./processed_anno/caption_split/caption_split_vehicle_merge.json"
95 |     question_list = [
96 |         "Describe the position of the vehicle in the blue box relative to the pedestrian in the green box.",
97 |         "Describe the driving status of the vehicle in the blue box.",
98 |         "Describe the attributes of the pedestrian with the green box.",  # not used
99 |         "Describe the weather conditions and road environment."  # not used
100 |     ]
101 |
102 |     result = []
103 |
104 |     for input_file in input_file_list:
105 |         data_list = load_jsonl(input_file)
106 |         for data in data_list:
107 |             id = data['id']
108 |             conversations = data['conversations']
109 |             vehicle_caption_ori = conversations[3]['value']
110 |
111 |             # split description
112 |             vehicle_caption = vehicle_caption_ori + ' '
113 |             pedestrian_caption_list = vehicle_caption.split('. ')
114 |             pedestrian_caption_list = [sentence.strip() + '.' for sentence in pedestrian_caption_list if sentence]
115 |
116 |             pedestrian_res = [[], [], [], []]
117 |             pedestrian_response = data['vehicle_response']
118 |             pedestrian_response_split = pedestrian_response.split('\n')
119 |             for pedestrian_response_data in pedestrian_response_split:
120 |                 pedestrian_num, pedestrian_class = pedestrian_response_data.split('.')
121 |                 pedestrian_num = int(pedestrian_num) - 1
122 |                 pedestrian_class = pedestrian_class.lower()
123 |                 if pedestrian_num >= len(pedestrian_caption_list):
124 |                     # print(pedestrian_response)
125 |                     # print(pedestrian_caption_list)
126 |                     continue
127 |                 if 'a' in pedestrian_class:
128 |                     pedestrian_res[0].append(pedestrian_caption_list[pedestrian_num])
129 |                 elif 'b' in pedestrian_class:
130 |                     pedestrian_res[1].append(pedestrian_caption_list[pedestrian_num])
131 |
132 |             for i in range(len(pedestrian_res)):
133 |                 if len(pedestrian_res[i]) == 0:
134 |                     continue
135 |                 conversations_new = []
136 |                 pedestrian_res[i] = ' '.join(pedestrian_res[i])
137 |                 conversations_new.append({
138 |                     "from": "human",
139 |                     "value": "<image>\n" + question_list[i]
140 |                 })
141 |                 conversations_new.append({
142 |                     "from": "gpt",
143 |                     "value": pedestrian_res[i]
144 |                 })
145 |                 data['conversations'] = conversations_new
146 |                 data['full_flag'] = 0
147 |                 data_new = copy.deepcopy(data)
148 |                 result.append(data_new)
149 |
150 |             # print(conversations_new)
151 |             conversations_new = []
152 |             conversations_new.append({
153 |                 "from": "human",
154 |                 "value": "<image>\nThis picture shows the relationship between the vehicle in the blue box and the pedestrian in the green box. Describe the vehicle in the blue box or the vehicle closest to the pedestrian based on the relative position to the pedestrian, driving status, weather conditions and road environment. And describe the age, height, clothing of the pedestrian."
155 | }) 156 | conversations_new.append({ 157 | "from": "gpt", 158 | "value": vehicle_caption_ori 159 | }) 160 | data['conversations'] = conversations_new 161 | data['full_flag'] = 1 162 | data_new = copy.deepcopy(data) 163 | result.append(data_new) 164 | print("len of shortQA_merge_vehicle data is ", len(result)) 165 | with open(save_file, "w", encoding="utf-8") as f: 166 | json.dump(result, f, ensure_ascii=False, indent=4) 167 | 168 | def shortQA_merge(): 169 | input_file_list = ["./processed_anno/caption_split/caption_split_pedestrian_merge.json", 170 | "./processed_anno/caption_split/caption_split_vehicle_merge.json"] 171 | save_file = "./processed_anno/caption_split/caption_split_merge.json" 172 | input_json_merge = [] 173 | for input_file in input_file_list: 174 | with open(input_file, "r", encoding="utf-8") as f: 175 | input_json = json.load(f) 176 | input_json_merge.extend(input_json) 177 | with open(save_file, "w", encoding="utf-8") as f: 178 | json.dump(input_json_merge, f, ensure_ascii=False, indent=4) 179 | 180 | 181 | 182 | # Split short QA and long QA data, filter and sample for short QA 183 | def data_filter(): 184 | input_file = "./processed_anno/caption_split/caption_split_merge.json" 185 | save_file = './processed_anno/llava_format/wts_bdd_llava_qa_train.json' 186 | 187 | result_single_question = [] 188 | result_full_question = [] 189 | with open(input_file, "r", encoding="utf-8") as f: 190 | input_json = json.load(f) 191 | for data in input_json: 192 | if data["full_flag"] == 1: 193 | result_full_question.append(data) 194 | continue 195 | else: 196 | # if there is no pedestrian detection box in the video, filter the QA description of pedestrian 197 | if "P" not in data["tag"]: 198 | question = data["conversations"][0]["value"] 199 | if "pedestrian" in question: 200 | continue 201 | # For monitoring perspectives, if there is no vehicle detection box in the video, filter the QA description of the vehicle 202 | if 'BDD_PC_5k' not in data['image'] and 'vehicle_view' not in data['image']: 203 | if "V" not in data["tag"]: 204 | question = data["conversations"][0]["value"] 205 | if "vehicle" in question: 206 | continue 207 | result_single_question.append(data) 208 | print("len of result_single_question is ", len(result_single_question)) 209 | print("len of result_full_question is ", len(result_full_question)) 210 | 211 | import random 212 | result_single_question = random.sample(result_single_question, len(result_full_question)) 213 | 214 | result_full_question.extend(result_single_question) 215 | print("len of final_question is ", len(result_full_question)) 216 | 217 | with open(save_file, "w", encoding="utf-8") as f: 218 | json.dump(result_full_question, f, ensure_ascii=False, indent=4) 219 | 220 | # not used 221 | # save_file = input_file.replace('.json', '_shortQA.json') 222 | # with open(save_file, "w", encoding="utf-8") as f: 223 | # json.dump(result_single_question, f, ensure_ascii=False, indent=4) 224 | # 225 | # save_file = input_file.replace('.json', '_fullQA.json') 226 | # with open(save_file, "w", encoding="utf-8") as f: 227 | # json.dump(result_full_question, f, ensure_ascii=False, indent=4) 228 | 229 | 230 | if __name__ == '__main__': 231 | shortQA_merge_pedestrian() 232 | shortQA_merge_vehicle() 233 | shortQA_merge() 234 | data_filter() -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 
import torch.nn as nn 4 | 5 | from torch.utils.data import Sampler 6 | 7 | from transformers import Trainer 8 | from transformers.trainer import ( 9 | is_sagemaker_mp_enabled, 10 | get_parameter_names, 11 | has_length, 12 | ALL_LAYERNORM_LAYERS, 13 | logger, 14 | ) 15 | from typing import List, Optional 16 | 17 | 18 | def maybe_zero_3(param, ignore_status=False, name=None): 19 | from deepspeed import zero 20 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 21 | if hasattr(param, "ds_id"): 22 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 23 | if not ignore_status: 24 | print(name, 'no ignore status') 25 | with zero.GatheredParameters([param]): 26 | param = param.data.detach().cpu().clone() 27 | else: 28 | param = param.detach().cpu().clone() 29 | return param 30 | 31 | 32 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 33 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 34 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 35 | return to_return 36 | 37 | 38 | def split_to_even_chunks(indices, lengths, num_chunks): 39 | """ 40 | Split a list of indices into `chunks` chunks of roughly equal lengths. 41 | """ 42 | 43 | if len(indices) % num_chunks != 0: 44 | return [indices[i::num_chunks] for i in range(num_chunks)] 45 | 46 | num_indices_per_chunk = len(indices) // num_chunks 47 | 48 | chunks = [[] for _ in range(num_chunks)] 49 | chunks_lengths = [0 for _ in range(num_chunks)] 50 | for index in indices: 51 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 52 | chunks[shortest_chunk].append(index) 53 | chunks_lengths[shortest_chunk] += lengths[index] 54 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 55 | chunks_lengths[shortest_chunk] = float("inf") 56 | 57 | return chunks 58 | 59 | 60 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): 61 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 62 | assert all(l != 0 for l in lengths), "Should not have zero length." 
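    # The caller encodes modality in the sign of each length: positive values
    # mark multimodal samples and negative values mark text-only samples, e.g.
    # lengths = [120, -80, 95, -60] holds two image-text samples (120 and 95
    # tokens) and two text-only samples (80 and 60 tokens). Each modality is
    # shuffled and chunked into megabatches of world_size * batch_size indices
    # separately, so only the final leftover megabatch can mix modalities.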
63 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 64 | # all samples are in the same modality 65 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 66 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) 67 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) 68 | 69 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] 70 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] 71 | megabatch_size = world_size * batch_size 72 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] 73 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] 74 | 75 | last_mm = mm_megabatches[-1] 76 | last_lang = lang_megabatches[-1] 77 | additional_batch = last_mm + last_lang 78 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 79 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 80 | megabatches = [megabatches[i] for i in megabatch_indices] 81 | 82 | if len(additional_batch) > 0: 83 | megabatches.append(sorted(additional_batch)) 84 | 85 | return [i for megabatch in megabatches for i in megabatch] 86 | 87 | 88 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 89 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 90 | indices = torch.randperm(len(lengths), generator=generator) 91 | megabatch_size = world_size * batch_size 92 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 93 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 94 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 95 | 96 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 97 | 98 | 99 | class LengthGroupedSampler(Sampler): 100 | r""" 101 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 102 | keeping a bit of randomness. 
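
    When ``group_by_modality`` is True, the sign of each length encodes the
    modality (positive for multimodal samples, negative for text-only ones, as
    in ``get_modality_length_grouped_indices`` above) and the two modalities
    are kept in separate megabatches. ``LLaVATrainer._get_train_sampler`` below
    selects this mode when ``group_by_modality_length`` is enabled.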
103 | """ 104 | 105 | def __init__( 106 | self, 107 | batch_size: int, 108 | world_size: int, 109 | lengths: Optional[List[int]] = None, 110 | generator=None, 111 | group_by_modality: bool = False, 112 | ): 113 | if lengths is None: 114 | raise ValueError("Lengths must be provided.") 115 | 116 | self.batch_size = batch_size 117 | self.world_size = world_size 118 | self.lengths = lengths 119 | self.generator = generator 120 | self.group_by_modality = group_by_modality 121 | 122 | def __len__(self): 123 | return len(self.lengths) 124 | 125 | def __iter__(self): 126 | if self.group_by_modality: 127 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 128 | else: 129 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 130 | return iter(indices) 131 | 132 | 133 | class LLaVATrainer(Trainer): 134 | 135 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 136 | if self.train_dataset is None or not has_length(self.train_dataset): 137 | return None 138 | 139 | if self.args.group_by_modality_length: 140 | lengths = self.train_dataset.modality_lengths 141 | return LengthGroupedSampler( 142 | self.args.train_batch_size, 143 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 144 | lengths=lengths, 145 | group_by_modality=True, 146 | ) 147 | else: 148 | return super()._get_train_sampler() 149 | 150 | def create_optimizer(self): 151 | """ 152 | Setup the optimizer. 153 | 154 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 155 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 156 | """ 157 | if is_sagemaker_mp_enabled(): 158 | return super().create_optimizer() 159 | 160 | opt_model = self.model 161 | 162 | if self.optimizer is None: 163 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 164 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 165 | if self.args.mm_projector_lr is not None: 166 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] 167 | optimizer_grouped_parameters = [ 168 | { 169 | "params": [ 170 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) 171 | ], 172 | "weight_decay": self.args.weight_decay, 173 | }, 174 | { 175 | "params": [ 176 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) 177 | ], 178 | "weight_decay": 0.0, 179 | }, 180 | { 181 | "params": [ 182 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) 183 | ], 184 | "weight_decay": self.args.weight_decay, 185 | "lr": self.args.mm_projector_lr, 186 | }, 187 | { 188 | "params": [ 189 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) 190 | ], 191 | "weight_decay": 0.0, 192 | "lr": self.args.mm_projector_lr, 193 | }, 194 | ] 195 | else: 196 | optimizer_grouped_parameters = [ 197 | { 198 | "params": [ 199 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 200 | ], 201 | "weight_decay": self.args.weight_decay, 202 | }, 203 | { 204 | "params": [ 205 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and 
p.requires_grad) 206 | ], 207 | "weight_decay": 0.0, 208 | }, 209 | ] 210 | 211 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 212 | 213 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 214 | if optimizer_cls.__name__ == "Adam8bit": 215 | import bitsandbytes 216 | 217 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 218 | 219 | skipped = 0 220 | for module in opt_model.modules(): 221 | if isinstance(module, nn.Embedding): 222 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 223 | logger.info(f"skipped {module}: {skipped/2**20}M params") 224 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 225 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 226 | logger.info(f"skipped: {skipped/2**20}M params") 227 | 228 | return self.optimizer 229 | 230 | def _save_checkpoint(self, model, trial, metrics=None): 231 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 232 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 233 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 234 | 235 | run_dir = self._get_output_dir(trial=trial) 236 | output_dir = os.path.join(run_dir, checkpoint_folder) 237 | 238 | # Only save Adapter 239 | keys_to_match = ['mm_projector', 'vision_resampler'] 240 | if getattr(self.args, "use_im_start_end", False): 241 | keys_to_match.extend(['embed_tokens', 'embed_in']) 242 | 243 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 244 | 245 | if self.args.local_rank == 0 or self.args.local_rank == -1: 246 | self.model.config.save_pretrained(output_dir) 247 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 248 | else: 249 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 250 | 251 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 252 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 253 | pass 254 | else: 255 | super(LLaVATrainer, self)._save(output_dir, state_dict) 256 | --------------------------------------------------------------------------------
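Usage sketch (not part of the repository): the snippet below shows how the "bigsmall" preprocessing and the image-token prompt handling in llava/mm_utils.py fit together at inference time. The frame path, crop box, tokenizer checkpoint and the SimpleNamespace stand-in for the model config are illustrative assumptions; in the real pipeline the processor, tokenizer and config come from the finetuned checkpoint loaded in llava/model/builder.py, and the local crop is derived from the bounding-box annotations prepared under data_preprocess.

# Illustrative sketch -- file names, crop box and model IDs are assumptions.
from types import SimpleNamespace

from PIL import Image
from transformers import AutoTokenizer, CLIPImageProcessor

from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import process_images, tokenizer_image_token

# Global camera frame plus a local crop around the annotated target box.
frame = Image.open("example_frame.png").convert("RGB")   # hypothetical frame
local = frame.crop((400, 200, 800, 600))                 # hypothetical target box

# Minimal stand-in for the model config; a real run reads it from the checkpoint.
cfg = SimpleNamespace(image_aspect_ratio="bigsmall")

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
images = process_images([frame], processor, cfg, local_image=[local])
print(images.shape)  # torch.Size([1, 2, 3, 336, 336]): local view stacked before global view

# Any LLaMA-family tokenizer works for the prompt side; vicuna-7b-v1.5 is only an example.
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
prompt = "<image>\nDescribe the pedestrian in the green box."
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)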