├── README.md ├── assets └── framework.png ├── dataset ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── base_dataset.cpython-310.pyc │ ├── it_dataset.cpython-310.pyc │ ├── utils.cpython-310.pyc │ └── video_utils.cpython-310.pyc ├── base_dataset.py ├── it_dataset.py ├── utils.py └── video_utils.py ├── evaluate_egoschema_result.py ├── example ├── 1917.mov ├── 1917.mp4 ├── bear.jpg ├── cooking.mp4 ├── dog.png ├── jesse_dance.mp4 ├── working.mp4 └── yoga.mp4 ├── models ├── __init__.py ├── __pycache__ │ └── __init__.cpython-310.pyc └── pllava │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── configuration_pllava.cpython-310.pyc │ ├── elastic_cache.cpython-310.pyc │ ├── llama.cpython-310.pyc │ ├── modeling_clip.cpython-310.pyc │ ├── modeling_flash_attention_utils.cpython-310.pyc │ ├── modeling_pllava.cpython-310.pyc │ ├── modeling_pllava_SF.cpython-310.pyc │ ├── modeling_pllava_flow.cpython-310.pyc │ ├── modify_llama.cpython-310.pyc │ ├── processing_pllava.cpython-310.pyc │ └── v433_modeling_llama.cpython-310.pyc │ ├── configuration_pllava.py │ ├── convert_pllava_weights_to_hf.py │ ├── elastic_cache.py │ ├── llama.py │ ├── llama_outlook.py │ ├── llava_arch.py │ ├── modeling_clip.py │ ├── modeling_flash_attention_utils.py │ ├── modeling_pllava.py │ ├── modeling_pllava_SF.py │ ├── modeling_pllava_flow.py │ ├── modify_llama.py │ ├── pllava_prumerge.py │ ├── processing_pllava.py │ └── v433_modeling_llama.py ├── requirements.no_torch.txt ├── requirements.torch.txt ├── requirements.txt ├── scripts ├── accel_config_deepspeed_zero2.yaml ├── accel_config_deepspeed_zero3_offload.yaml ├── accel_config_deepspeed_zero3_offload_multinode.yaml ├── accel_config_deepspeed_zero3_offload_multinode_1.yaml ├── accel_config_deepspeed_zero3_offload_multinode_2.yaml ├── accel_config_deepspeed_zero3_offload_singlegpu.yaml ├── accel_config_multigpu.yaml ├── accel_config_multinode.yaml ├── accel_config_singlegpu.yaml └── eval.sh ├── tasks ├── eval │ ├── __pycache__ │ │ ├── eval_utils.cpython-310.pyc │ │ ├── eval_utils.cpython-39.pyc │ │ └── model_utils.cpython-310.pyc │ ├── demo │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── pllava_demo.cpython-310.pyc │ │ ├── pllava_demo.py │ │ ├── show_compare.py │ │ └── show_gallery.py │ ├── egoshcema │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── pllava_eval_egoschema.cpython-310.pyc │ │ └── pllava_eval_egoschema.py │ ├── eval_utils.py │ ├── model_utils.py │ ├── mvbench │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── llava_next_video_mvbench.cpython-310.pyc │ │ │ ├── pllava_eval_mvbench.cpython-310.pyc │ │ │ └── tarsier_eval_mvbench.cpython-310.pyc │ │ └── pllava_eval_mvbench.py │ ├── recaption │ │ ├── __init__.py │ │ ├── pllava_recaption.py │ │ └── show_recaption.py │ ├── vcgbench │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── pllava_eval_vcgbench.cpython-310.pyc │ │ ├── pllava_eval_vcgbench.py │ │ └── show_vcg.py │ ├── videomme │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ └── pllava_eval_videomme.cpython-310.pyc │ │ └── pllava_eval_videomme.py │ ├── videoqabench │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── pllava_eval_videoqabench.cpython-310.pyc │ │ └── pllava_eval_videoqabench.py │ └── visualizer.py ├── shared_utils.py └── train │ ├── __pycache__ │ 
└── instruction_data.cpython-310.pyc │ ├── clever_process.py │ ├── config_pllava_nframe.py │ ├── config_pllava_nframe_yiprompt.py │ ├── ego_process.py │ ├── ffmpeg_tgif.py │ ├── instruction_data.py │ ├── k710_print.py │ ├── k710_process.py │ ├── missing_files.txt │ ├── mk_710.py │ ├── not_have.txt │ ├── output.mp4 │ ├── print_all_files.py │ ├── tgif_corrupt.txt │ ├── tgif_mp4.py │ ├── tgif_used.txt │ ├── train_pllava_flow_nframe_accel.py │ ├── train_pllava_nframe_accel.py │ ├── train_tarsier_flow_nframe_accel.py │ ├── train_tarsier_nframe_accel.py │ ├── vcg_process.py │ ├── vcg_read.py │ └── webvid_process.py └── utils ├── __pycache__ ├── basic_utils.cpython-310.pyc ├── config.cpython-310.pyc ├── config_utils.cpython-310.pyc ├── distributed.cpython-310.pyc ├── easydict.cpython-310.pyc ├── logger.cpython-310.pyc ├── optimizer.cpython-310.pyc └── scheduler.cpython-310.pyc ├── basic_utils.py ├── config.py ├── config_utils.py ├── distributed.py ├── easydict.py ├── logger.py ├── optimizer.py └── scheduler.py /README.md: -------------------------------------------------------------------------------- 1 | # ACL 2025 | PruneVid 2 | The official repository for the paper "PruneVid: Visual Token Pruning for Efficient Video Large Language Models". 3 | 4 | Xiaohu Huang, Hao Zhou, Kai Han 5 | 6 | [`Webpage`](https://visual-ai.github.io/prunevid/) | [`Paper`](https://arxiv.org/abs/2412.16117v1) 7 | 8 | # Introduction 9 | ![Framework](assets/framework.png) 10 | 11 | We present PruneVid, a training-free visual token pruning method that enhances efficiency in multi-modal video understanding. By merging spatial-temporal tokens to reduce video redundancy and leveraging attention mechanisms within LLMs to retain only the visual tokens relevant to questions, PruneVid ensures high performance while reducing computational overhead. 12 | 13 | # Todo: 14 | - [x] Code release of PruneVid with PLLaVA. 15 | - [ ] Code release of PruneVid with LLaVA-OneVision. 16 | - [ ] Code release of PruneVid with ST-LLM. 17 | 18 | # License 19 | 20 | PruneVid is released under the [`CC BY-NC-SA 4.0 license`](https://creativecommons.org/licenses/by-nc-sa/4.0/). 21 | 22 | # Performance 23 | 24 | We conduct experiments on three video LLMs (PLLaVA, ST-LLM, and LLaVA-OneVision) on four benchmarks: MVBench, VideoMME, EgoSchema, and VideoChatGPT-Bench (VCG-Bench).
25 | 26 | | Method | Retained Ratio | FLOPs (×) | MVBench | VideoMME | EgoSchema Subset / Fullset | TU | CU | CO | DO | CI | Avg | 27 | |---------------------------------|----------------|-----------|---------|----------|---------------------------|------|------|------|------|------|------| 28 | | PLLaVA | 100.0% | 1.00× | 46.6 | 44.4 | 47.8 / 42.6 | 2.33 | 3.62 | 2.93 | 2.86 | 3.21 | 2.99 | 29 | | PLLaVA w/ FastV | 30.0% | 0.33× | 46.1 | 43.6 | 46.2 / 41.0 | 2.38 | 3.49 | 2.89 | 2.76 | 3.14 | 2.93 | 30 | | PLLaVA w/ Prumerge | 55.7% | 0.53× | 45.6 | 43.8 | 45.2 / 40.4 | 2.34 | 3.52 | 2.90 | 2.76 | 3.15 | 2.93 | 31 | | PLLaVA w/ Look-M | 20.0% | 1.00× | 46.6 | 44.3 | 47.0 / 42.3 | 2.28 | 3.41 | 2.75 | 2.65 | 3.00 | 2.82 | 32 | | **PLLaVA w/ Ours** | **16.2%** | **0.23×** | **47.6**| **45.3** | **49.0 / 42.6** | **2.44** | **3.51** | **2.99** | **2.78** | **3.20** | **2.98** | 33 | | | | | | | | | | | | | | 34 | | ST-LLM | 100.0% | 1.00× | 54.9 | 42.0 | 56.2 / 45.6 | 2.46 | 3.46 | 2.66 | 2.63 | 3.08 | 2.86 | 35 | | ST-LLM w/ FastV | 30.0% | 0.37× | 42.9 | 34.5 | 48.0 / 38.5 | 2.01 | 2.23 | 1.55 | 1.94 | 1.69 | 1.88 | 36 | | ST-LLM w/ Look-M | 20.0% | 1.00× | 54.0 | 40.6 | 54.0 / 44.5 | 2.35 | 3.41 | 2.60 | 2.51 | 3.01 | 2.78 | 37 | | **ST-LLM w/ Ours** | **15.1%** | **0.26×** | **54.3**| **41.4** | **54.6 / 44.7** | **2.40** | **3.43** | **2.63** | **2.60** | **3.04** | **2.82** | 38 | | | | | | | | | | | | | | 39 | | LLaVA-OneVision | 100.0% | 1.00× | 58.0 | 58.2 | 62.0 / 60.0 | 2.75 | 3.70 | 3.39 | 2.97 | 3.50 | 3.26 | 40 | | LLaVA-OneVision w/ FastV | 30.0% | 0.30× | 57.2 | 57.6 | 62.6 / 60.0 | 2.65 | 3.61 | 3.28 | 2.85 | 3.39 | 3.16 | 41 | | LLaVA-OneVision w/ Prumerge | 55.2% | 0.49× | 52.9 | 56.7 | 62.2 / 60.0 | 2.72 | 3.64 | 3.32 | 2.94 | 3.44 | 3.21 | 42 | | LLaVA-OneVision w/ Look-M | 20.0% | 1.00× | 57.0 | 58.0 | 62.0 / **59.8** | 2.71 | 3.70 | 3.29 | 2.89 | 3.44 | 3.21 | 43 | | **LLaVA-OneVision w/ Ours** | **17.0%** | **0.20×** | **57.5**| **58.6** | **62.6 / 59.5** | **2.73** | **3.72** | **3.28** | **2.94** | **3.51** | **3.24** | 44 | 45 | # Data Preparation 46 | 47 | All four used benchmarks can be downloaded from huggingface website: [`MVBench`](https://huggingface.co/datasets/OpenGVLab/MVBench), [`VideoMME`](https://huggingface.co/datasets/lmms-lab/Video-MME), [`Egoschema`](https://huggingface.co/datasets/lmms-lab/egoschema), and [`VideoChatGPT-Bench`](https://huggingface.co/datasets/lmms-lab/VideoChatGPT). 48 | 49 | After downloading the datasets, please put them into the `DATAS` folder and sort out the source videos and annotations in the following formats: 50 | ``` 51 | DATAS/ 52 | ├── ego_schema/ 53 | │ ├── json/ 54 | │ └── videos/ 55 | ├── MVBench/ 56 | │ ├── json/ 57 | │ └── video/ 58 | ├── VCGBench/ 59 | │ ├── Videos/ 60 | │ ├── Zero_Shot_QA/ 61 | └── Video-MME/ 62 | ├── data/ 63 | └── json/ 64 | ``` 65 | 66 | # Pretrained Model 67 | 68 | The pretrained model can be found in their respective repositories: [`PLLaVA`](https://github.com/magic-research/PLLaVA?tab=readme-ov-file), [`ST-LLM`](https://github.com/TencentARC/ST-LLM/tree/main/stllm), and [`LLaVA-OneVision`](https://huggingface.co/lmms-lab). 69 | 70 | After downloading the models please put them into the MODELS folder: 71 | 72 | ``` 73 | MODELS/ 74 | ├── pllava-7b/ 75 | ``` 76 | 77 | # Environment Install 78 | 79 | We follow the environment installation guideline of [`PLLaVA`](https://github.com/magic-research/PLLaVA?tab=readme-ov-file). 80 | 81 | 1. 
Above all, the following environment setup is for Python 3.10. If you choose to use conda for environment setup, we recommend creating the virtual environment with: 82 | ```bash 83 | conda create -n pllava python=3.10 84 | ``` 85 | 86 | 1. Next, install [pytorch](https://pytorch.org/) from the official website. The code runs on torch 2.2.1 with cu118 or cu122. Select the version that matches your driver version. 87 | 88 | ``` 89 | torch 2.2.1+cu118 90 | torchaudio 2.2.1+cu118 91 | torchvision 0.17.1+cu118 92 | ``` 93 | 94 | If your driver version is higher than cu121, you can try installing everything with the following command: 95 | ```bash 96 | pip install -r requirements.txt 97 | ``` 98 | 99 | Otherwise, you need to install a torch build suited to your server first, then install the other packages: 100 | ```bash 101 | pip install -r requirements.torch.txt # pick the requirements that match your setup (this file targets cu11), or install torch directly following the official website. 102 | pip install -r requirements.no_torch.txt # install the remaining packages 103 | ``` 104 | 105 | # Evaluation 106 | 107 | As PruneVid is a training-free method, it can be applied directly to the pre-trained models. 108 | 109 | The script for evaluating model performance is provided in `scripts/eval.sh`. Below is the script for evaluating performance on MVBench, where you can edit the hyper-parameters as you like; the default setting is the one used in our paper. A minimal sketch of how these hyper-parameters map onto the pruning steps is given after the Acknowledgement section. 110 | 111 | ``` 112 | lora_alpha=14 113 | selected_layers=(10) 114 | alphas=(0.4) 115 | taus=(0.8) 116 | temporal_segment_ratios=(0.25) 117 | cluster_ratios=(0.5) 118 | 119 | for alpha in "${alphas[@]}"; do 120 | for selected_layer in "${selected_layers[@]}"; do 121 | for tau in "${taus[@]}"; do 122 | for temporal_segment_ratio in "${temporal_segment_ratios[@]}"; do 123 | for cluster_ratio in "${cluster_ratios[@]}"; do 124 | # run the evaluation command 125 | SAVE_DIR=test_results/pllava-7b-lora${lora_alpha}-threshold${tau}-layer${selected_layer}-alpha${alpha}-temporal-segment-ratio-${temporal_segment_ratio}-cluster-ratio-${cluster_ratio} 126 | mkdir -p "${SAVE_DIR}" 127 | conv_mode=eval_mvbench 128 | python -m tasks.eval.mvbench.pllava_eval_mvbench \ 129 | --pretrained_model_name_or_path ${model_dir} \ 130 | --save_path ${SAVE_DIR}/mvbench \ 131 | --num_frames ${num_frames} \ 132 | --use_lora \ 133 | --lora_alpha ${lora_alpha} \ 134 | --top_p 1.0 \ 135 | --temperature 1.0 \ 136 | --weight_dir ${weight_dir} \ 137 | --pooling_shape 16-12-12 \ 138 | --conv_mode ${conv_mode} \ 139 | --selected_layer ${selected_layer} \ 140 | --alpha ${alpha} \ 141 | --tau ${tau} \ 142 | --temporal_segment_ratio ${temporal_segment_ratio} \ 143 | --cluster_ratio ${cluster_ratio} 144 | done 145 | done 146 | done 147 | done 148 | done 149 | ``` 150 | 151 | As for EgoSchema, which requires an external service to evaluate model performance, we run `evaluate_egoschema_result.py` for evaluation. Before executing the file, change the `root_dir` variable to your results folder. 152 | ``` 153 | python evaluate_egoschema_result.py 154 | ``` 155 | 156 | # Acknowledgement 157 | 158 | This repository is built upon [`PLLaVA`](https://github.com/magic-research/PLLaVA?tab=readme-ov-file), [`ST-LLM`](https://github.com/TencentARC/ST-LLM/tree/main/stllm), and [`LLaVA-OneVision`](https://huggingface.co/lmms-lab). Thanks for these well-organized codebases.
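For intuition, the snippet below is a minimal, self-contained sketch of the two ideas that the evaluation hyper-parameters above control: merging redundant spatial-temporal tokens within temporal segments (`tau`, `temporal_segment_ratio`) and keeping only the visual tokens that receive the most question-to-vision attention (`alpha`). It is for illustration only, under assumed tensor shapes; the function names and inputs are placeholders rather than APIs from this repository, and `selected_layer` (the LLM layer whose attention is read) and `cluster_ratio` (spatial clustering) are not modelled here.

```python
# Illustrative sketch only, not the repository's implementation.
# Assumed shapes: video_tokens is (T, N, C) = frames x tokens-per-frame x channels;
# question_to_vision_attention is (Q, M) attention from question tokens to visual tokens.
import torch


def merge_spatial_temporal_tokens(video_tokens: torch.Tensor, tau: float = 0.8,
                                  temporal_segment_ratio: float = 0.25) -> torch.Tensor:
    """Stage 1 (sketch): split frames into temporal segments and merge highly similar tokens."""
    T, N, C = video_tokens.shape
    num_segments = max(1, int(T * temporal_segment_ratio))
    segments = torch.chunk(video_tokens, num_segments, dim=0)

    merged = []
    for seg in segments:
        tokens = seg.reshape(-1, C)                   # flatten tokens inside the segment
        anchor = tokens.mean(dim=0, keepdim=True)     # one toy representative per segment
        sim = torch.nn.functional.cosine_similarity(tokens, anchor, dim=-1)
        static = tokens[sim >= tau].mean(dim=0, keepdim=True) if (sim >= tau).any() else anchor
        dynamic = tokens[sim < tau]                   # keep dissimilar (dynamic) tokens as they are
        merged.append(torch.cat([static, dynamic], dim=0))
    return torch.cat(merged, dim=0)                   # (M, C), M <= T * N


def keep_question_relevant_tokens(visual_tokens: torch.Tensor,
                                  question_to_vision_attention: torch.Tensor,
                                  alpha: float = 0.4) -> torch.Tensor:
    """Stage 2 (sketch): retain the top-alpha fraction of the most-attended visual tokens."""
    scores = question_to_vision_attention.mean(dim=0)  # average attention each visual token receives
    keep = max(1, int(alpha * visual_tokens.shape[0]))
    idx = scores.topk(keep).indices.sort().values      # preserve the original token order
    return visual_tokens[idx]


if __name__ == "__main__":
    T, N, C, Q = 16, 144, 1024, 32                     # toy sizes (e.g. 16 frames of 12x12 pooled tokens)
    video_tokens = torch.randn(T, N, C)
    merged = merge_spatial_temporal_tokens(video_tokens, tau=0.8, temporal_segment_ratio=0.25)
    attn = torch.rand(Q, merged.shape[0])              # stand-in for question-to-vision attention
    pruned = keep_question_relevant_tokens(merged, attn, alpha=0.4)
    print(f"{T * N} visual tokens -> {merged.shape[0]} after merging -> {pruned.shape[0]} after selection")
```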
159 | 160 | # Citation 161 | 162 | ```bibtex 163 | @inproceedings{ 164 | huang2024prunevid, 165 | title={PruneVid: Visual Token Pruning for Efficient Video Large Language Models}, 166 | author={Xiaohu Huang and Hao Zhou and Kai Han}, 167 | booktitle={arXiv}, 168 | year={2024} 169 | } 170 | ``` 171 | -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/assets/framework.png -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import ConcatDataset, DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms import InterpolationMode 5 | from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset 6 | 7 | 8 | def get_media_type(dataset_config): 9 | if len(dataset_config) == 3 and dataset_config[2] == "video": 10 | return "video" 11 | elif dataset_config[-1] == "only_video": 12 | return "only_video" 13 | else: 14 | return "image" 15 | 16 | 17 | def create_dataset(dataset_type, config): 18 | if "clip" in config.model.get("vit_model", 'vit'): 19 | mean = (0.485, 0.456, 0.406) 20 | std = (0.229, 0.224, 0.225) 21 | else: 22 | vision_enc_name = config.model.vision_encoder.name 23 | if "swin" in vision_enc_name or "vit" in vision_enc_name: 24 | mean = (0.485, 0.456, 0.406) 25 | std = (0.229, 0.224, 0.225) 26 | elif "beit" in vision_enc_name: 27 | mean = (0.5, 0.5, 0.5) # for all beit model except IN1K finetuning 28 | std = (0.5, 0.5, 0.5) 29 | elif "clip" in vision_enc_name: 30 | mean = (0.48145466, 0.4578275, 0.40821073) 31 | std = (0.26862954, 0.26130258, 0.27577711) 32 | else: 33 | raise ValueError 34 | 35 | normalize = transforms.Normalize(mean, std) 36 | 37 | # loaded images and videos are torch.Tensor of torch.uint8 format, 38 | # ordered as (T, 1 or 3, H, W) where T=1 for image 39 | type_transform = transforms.Lambda(lambda x: x.float().div(255.0)) 40 | 41 | if config.inputs.video_input.random_aug: 42 | aug_transform = transforms.RandAugment() 43 | else: 44 | aug_transform = transforms.Lambda(lambda x: x) 45 | 46 | train_transform = transforms.Compose( 47 | [ 48 | aug_transform, 49 | # transforms.RandomResizedCrop( 50 | # config.inputs.image_res, 51 | # scale=(0.5, 1.0), 52 | # interpolation=InterpolationMode.BICUBIC, 53 | # ), 54 | # transforms.RandomHorizontalFlip(), 55 | transforms.Resize( 56 | (config.inputs.image_res, config.inputs.image_res), 57 | interpolation=InterpolationMode.BICUBIC, 58 | ), 59 | type_transform, 60 | normalize, 61 | ] 62 | ) 63 | test_transform = transforms.Compose( 64 | [ 65 | transforms.Resize( 66 | (config.inputs.image_res, config.inputs.image_res), 67 | interpolation=InterpolationMode.BICUBIC, 68 | ), 69 | type_transform, 70 | normalize, 71 | ] 72 | ) 73 | 74 | video_reader_type = config.inputs.video_input.get("video_reader_type", "decord") 75 | video_only_dataset_kwargs_train = dict( 76 | video_reader_type=video_reader_type, 77 | sample_type=config.inputs.video_input.sample_type, 78 | num_frames=config.inputs.video_input.num_frames, 79 | num_tries=3, # false tolerance 80 | ) 81 | 82 | if dataset_type == "pt_train": 83 | raise ValueError("NOT PRETRAINING YET") 84 | elif dataset_type in ["it_train"]: 85 | # convert to list of lists 86 | 
train_files = ( 87 | [config.train_file] if isinstance(config.train_file[0], str) else config.train_file 88 | ) 89 | train_media_types = sorted(list({get_media_type(e) for e in train_files})) 90 | 91 | train_datasets = [] 92 | for m in train_media_types: 93 | dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset 94 | # dataset of the same media_type will be mixed in a single Dataset object 95 | _train_files = [e for e in train_files if get_media_type(e) == m] 96 | 97 | datasets = [] 98 | for train_file in _train_files: 99 | dataset_kwargs = dict( 100 | ann_file=train_file, 101 | transform=train_transform, 102 | mm_alone=config.preprocess.get("mm_alone", True), 103 | add_second_msg=config.preprocess.get("add_second_msg", True), 104 | skip_short_sample=config.preprocess.get("skip_short_sample", False), 105 | clip_transform=config.preprocess.get("clip_transform", False), 106 | random_shuffle=config.preprocess.get("random_shuffle", True), 107 | system=config.preprocess.get("system", ""), 108 | role=config.preprocess.get('roles', ("Human", "Assistant")), 109 | end_signal=config.preprocess.get('end_signal', "###"), 110 | begin_signal=config.preprocess.get('begin_signal', ""), 111 | ) 112 | if m == "video": 113 | video_only_dataset_kwargs_train.update({ 114 | "start_token": config.model.get("start_token", ""), 116 | }) 117 | dataset_kwargs.update(video_only_dataset_kwargs_train) 118 | if "tgif" in train_file[1]: 119 | video_only_dataset_kwargs_train.update({ 120 | "video_reader_type": "gif" 121 | }) 122 | dataset_kwargs.update(video_only_dataset_kwargs_train) 123 | elif "webvid" in train_file[1]: 124 | video_only_dataset_kwargs_train.update({ 125 | # "video_reader_type": "hdfs" 126 | "video_reader_type": "decord" 127 | }) 128 | else: 129 | video_only_dataset_kwargs_train.update({ 130 | "video_reader_type": "decord" 131 | }) 132 | dataset_kwargs.update(video_only_dataset_kwargs_train) 133 | datasets.append(dataset_cls(**dataset_kwargs)) 134 | dataset = ConcatDataset(datasets) 135 | train_datasets.append(dataset) 136 | return train_datasets 137 | 138 | 139 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 140 | loaders = [] 141 | for dataset, sampler, bs, n_worker, is_train, collate_fn in zip( 142 | datasets, samplers, batch_size, num_workers, is_trains, collate_fns 143 | ): 144 | if is_train: 145 | shuffle = sampler is None 146 | drop_last = True 147 | else: 148 | shuffle = False 149 | drop_last = False 150 | loader = DataLoader( 151 | dataset, 152 | batch_size=bs, 153 | num_workers=n_worker, 154 | pin_memory=False, 155 | sampler=sampler, 156 | shuffle=shuffle, 157 | collate_fn=collate_fn, 158 | drop_last=drop_last, 159 | persistent_workers=True if n_worker > 0 else False, 160 | ) 161 | loaders.append(loader) 162 | return loaders 163 | 164 | -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/dataset/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/base_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/dataset/__pycache__/base_dataset.cpython-310.pyc 
-------------------------------------------------------------------------------- /dataset/__pycache__/it_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/dataset/__pycache__/it_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/dataset/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/video_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/dataset/__pycache__/video_utils.cpython-310.pyc -------------------------------------------------------------------------------- /dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import random 5 | from torch.utils.data import Dataset 6 | import time 7 | from dataset.utils import load_image_from_path 8 | from mmflow.apis import init_model 9 | import torch 10 | import string 11 | import numpy as np 12 | import cv2 13 | try: 14 | from petrel_client.client import Client 15 | has_client = True 16 | except ImportError: 17 | has_client = False 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class ImageVideoBaseDataset(Dataset): 23 | """Base class that implements the image and video loading methods""" 24 | 25 | media_type = "video" 26 | 27 | def __init__(self): 28 | assert self.media_type in ["image", "video", "only_video"] 29 | self.data_root = None 30 | self.anno_list = ( 31 | None # list(dict), each dict contains {"image": str, # image or video path} 32 | ) 33 | self.transform = None 34 | self.video_reader = None 35 | self.num_tries = None 36 | # self.optical_model = self.get_optical_flow_model() 37 | 38 | self.client = None 39 | if has_client: 40 | self.client = Client('~/petreloss.conf') 41 | 42 | def generate_random_string(self, length=6): 43 | characters = string.ascii_letters + string.digits 44 | 45 | # 随机选择字符并组成字符串 46 | random_string = ''.join(random.choice(characters) for _ in range(length)) 47 | 48 | return random_string 49 | 50 | def extract_flow_raft(self, frames, model): 51 | # t c h w 52 | t, c, ori_h, ori_w = frames.shape 53 | 54 | assert c == 3 55 | 56 | with torch.no_grad(): 57 | # name = self.generate_random_string() 58 | # frames_npy = frames.detach().cpu().numpy() 59 | # np.save(f'frame/frame_{name}.npy', frames_npy) 60 | 61 | frames = (frames - 127.5) / 127.5 62 | 63 | frames = frames.cuda() 64 | 65 | feat = model.encoder(frames) # t c h w 66 | 67 | cxt_feat = model.context(frames) 68 | 69 | h_feat, cxt_feat = torch.split( 70 | cxt_feat, [model.h_channels, model.cxt_channels], dim=1) 71 | h_feat = torch.tanh(h_feat) 72 | cxt_feat = torch.relu(cxt_feat) 73 | 74 | t, c, h, w = feat.shape 75 | feat = feat.view(t, c, h, w) 76 | pre_feat = feat.clone() 77 | next_feat = torch.cat([feat[1:], feat[-2].unsqueeze(0)], dim=0) # t c h w 78 | next_feat = next_feat.contiguous().view(-1, c, h, w) 79 | 80 | flow = torch.zeros((next_feat.shape[0], 2, h, w), device=next_feat.device) 81 | 82 
| upflow_preds = model.decoder(pre_feat, next_feat, flow, h_feat, cxt_feat) 83 | 84 | flow_result = upflow_preds[-1].cpu() 85 | 86 | print(flow_result.shape) 87 | 88 | # flow_result_npy = flow_result.detach().cpu().numpy() 89 | # np.save(f'flow/flow_{name}.npy', flow_result_npy) 90 | 91 | return flow_result 92 | 93 | def __getitem__(self, index): 94 | raise NotImplementedError 95 | 96 | def __len__(self): 97 | raise NotImplementedError 98 | 99 | def get_anno(self, index): 100 | """obtain the annotation for one media (video or image) 101 | 102 | Args: 103 | index (int): The media index. 104 | 105 | Returns: dict. 106 | - "image": the filename, video also use "image". 107 | - "caption": The caption for this file. 108 | 109 | """ 110 | anno = self.anno_list[index] 111 | if self.data_root is not None: 112 | anno["image"] = os.path.join(self.data_root, anno["image"]) 113 | return anno 114 | 115 | def load_and_transform_media_data(self, index, data_path): 116 | if self.media_type == "image": 117 | return self.load_and_transform_media_data_image(index, data_path, clip_transform=self.clip_transform) 118 | else: 119 | return self.load_and_transform_media_data_video(index, data_path, clip_transform=self.clip_transform) 120 | 121 | def load_and_transform_media_data_image(self, index, data_path, clip_transform=False): 122 | image = load_image_from_path(data_path, client=self.client) 123 | if not clip_transform: 124 | image = self.transform(image) 125 | return image, index 126 | 127 | def get_optical_flow_model(device): 128 | # config = '~/.cache/mim/pwcnet_ft_4x1_300k_sintel_final_384x768.py' 129 | # checkpoint = '~/.cache/mim/pwcnet_ft_4x1_300k_sintel_final_384x768.pth' 130 | config = '~/.cache/mim/raft_8x2_100k_mixed_368x768.py' 131 | checkpoint = '~/.cache/mim/raft_8x2_100k_mixed_368x768.pth' 132 | model = init_model(config, checkpoint) 133 | return model 134 | 135 | def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None, clip_transform=False): 136 | for _ in range(self.num_tries): 137 | flow = None 138 | try: 139 | max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1 140 | if "webvid" in data_path: 141 | # hdfs_dir="hdfs://harunava/home/byte_ailab_us_cvg/user/weimin.wang/videogen_data/webvid_data/10M_full_train" 142 | # video_name = os.path.basename(data_path) 143 | # video_id, extension = os.path.splitext(video_name) 144 | # ind_file = os.path.join(hdfs_dir, self.keys_indexfile[video_id]) 145 | # frames, frame_indices, fps = self.video_reader(ind_file, video_id, self.num_frames, self.sample_type, 146 | # max_num_frames=max_num_frames, client=self.client, clip=clip) 147 | frames, frame_indices, fps = self.video_reader( 148 | data_path, self.num_frames, self.sample_type, 149 | max_num_frames=max_num_frames, client=self.client, clip=clip 150 | ) 151 | else: 152 | frames, frame_indices, fps = self.video_reader( 153 | data_path, self.num_frames, self.sample_type, 154 | max_num_frames=max_num_frames, client=self.client, clip=clip 155 | ) 156 | 157 | # flow = self.extract_flow_raft(frames, self.optical_model) 158 | 159 | except Exception as e: 160 | logger.warning( 161 | f"Caught exception {e} when loading video {data_path}, " 162 | f"randomly sample a new video as replacement" 163 | ) 164 | index = random.randint(0, len(self) - 1) 165 | ann = self.get_anno(index) 166 | data_path = ann["image"] 167 | continue 168 | # shared aug for video frames 169 | if not clip_transform: 170 | frames = self.transform(frames) 171 | # if flow is not None 
and not clip_transform: 172 | # frames = torch.cat([frames, flow], dim=0) 173 | if return_fps: 174 | sec = [str(round(f / fps, 1)) for f in frame_indices] 175 | return frames, index, sec 176 | else: 177 | return frames, index 178 | else: 179 | raise RuntimeError( 180 | f"Failed to fetch video after {self.num_tries} tries. " 181 | f"This might indicate that you have many corrupted videos." 182 | ) 183 | -------------------------------------------------------------------------------- /dataset/it_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import sqlite3 5 | import random 6 | from os.path import basename 7 | 8 | import numpy as np 9 | import datetime 10 | 11 | from dataset.base_dataset import ImageVideoBaseDataset 12 | from dataset.video_utils import VIDEO_READER_FUNCS 13 | 14 | logger = logging.getLogger(__name__) 15 | IMAGE_TOKEN="" 16 | 17 | class ITImgTrainDataset(ImageVideoBaseDataset): 18 | media_type = "image" 19 | 20 | def __init__( 21 | self, ann_file, transform, 22 | system="", role=("Human", "Assistant"), 23 | mm_alone=True, 24 | add_second_msg=True, 25 | start_token="", end_token="", 26 | random_shuffle=True, # if True, shuffle the QA list ##xl:????? why need random shuffle 27 | begin_signal=None, 28 | end_signal=None, 29 | clip_transform=False, 30 | skip_short_sample=False, 31 | ): 32 | super().__init__() 33 | self.mm_alone = mm_alone 34 | self.clip_transform = clip_transform 35 | if len(ann_file) == 3 and ann_file[2] == "video": 36 | self.media_type = "video" 37 | else: 38 | self.media_type = "image" 39 | self.label_file, self.data_root = ann_file[:2] 40 | 41 | logger.info('Load json file') 42 | with open(self.label_file, 'r') as f: 43 | self.anno = json.load(f) 44 | self.num_examples = len(self.anno) 45 | self.transform = transform 46 | annos = [] 47 | from tqdm import tqdm 48 | for ann in self.anno: 49 | filename = ann['video'] if 'video' in ann else ann['image'] 50 | if self.media_type =='video' and "webvid" in self.data_root: 51 | # video_id, extension = os.path.splitext(os.path.basename(filename)) 52 | # if video_id not in self.keys_indexfile: 53 | # pass 54 | # else: 55 | # annos.append(ann) 56 | if filename is None or filename=="None": 57 | pass 58 | else: 59 | annos.append(ann) 60 | # if os.path.exists(os.path.join(self.data_root, filename)): 61 | # annos.append(ann) 62 | # else: 63 | # ... 64 | else: 65 | 66 | if filename is None or filename=="None": 67 | pass 68 | else: 69 | if os.path.exists(os.path.join(self.data_root, filename)): 70 | annos.append(ann) 71 | else: 72 | ... 73 | print('examples:', len(annos), len(self.anno), ann_file) 74 | if len(annos) / len(self.anno) < 0.9: 75 | raise ValueError(f"{len(annos)}/{len(self.anno)}") 76 | 77 | self.anno = annos 78 | self.num_examples = len(self.anno) 79 | 80 | 81 | # prompt parameters 82 | if system: 83 | assert system[-1] == " ", "' ' should be add in the end of system, thus '###' will be tokenized into one token." 
84 | # currently not support add start_token and end_token in the system, since the msg should be added properly 85 | self.begin_signal = [begin_signal for _ in role] if isinstance(begin_signal, str) else begin_signal 86 | self.end_signal = [end_signal for _ in role] if isinstance(end_signal, str) else end_signal 87 | self.start_token = start_token 88 | self.end_token = end_token 89 | self.system = system 90 | self.role = role 91 | self.random_shuffle = random_shuffle 92 | # instruction location and number 93 | logger.info(f"Random shuffle: {self.random_shuffle}") 94 | 95 | def get_anno(self, index): 96 | filename = self.anno[index][self.media_type] 97 | qa = self.anno[index]["QA"] 98 | 99 | if "start" in self.anno[index] and "end" in self.anno[index]: 100 | anno = { 101 | "image": os.path.join(self.data_root, filename), "qa": qa, 102 | "start": self.anno[index]["start"], "end": self.anno[index]["end"], 103 | } 104 | else: 105 | anno = {"image": os.path.join(self.data_root, filename), "qa": qa} 106 | return anno 107 | 108 | def __len__(self): 109 | return self.num_examples 110 | 111 | def process_qa(self, qa, msg=""): 112 | cur_instruction = "" 113 | # randomly shuffle qa for conversation 114 | if self.random_shuffle and len(qa) > 1: 115 | random.shuffle(qa) 116 | if "i" in qa[0].keys() and qa[0]["i"] != "": 117 | cur_instruction = qa[0]["i"] + self.end_signal[0] 118 | 119 | conversation = self.system 120 | # add instruction as system message 121 | if cur_instruction: 122 | conversation += cur_instruction 123 | 124 | # rstrip() for the extra " " in msg 125 | if self.mm_alone: 126 | conversation += ( 127 | self.begin_signal[0] + self.role[0] + 128 | self.start_token + self.end_token + msg.rstrip() + self.end_signal[0] 129 | ) 130 | 131 | for i, sentence in enumerate(qa): 132 | q = self.start_token + self.end_token+"\n"+ qa[0]["q"] if (not self.mm_alone) and (i == 0) else sentence["q"] 133 | a = sentence["a"] 134 | if q != "": 135 | conversation += (self.begin_signal[0] + self.role[0] + q + self.end_signal[1]) 136 | else: 137 | # no question, often in caption dataset 138 | pass 139 | conversation += (self.begin_signal[0] + self.role[1] + a + self.end_signal[1]) 140 | 141 | 142 | if cur_instruction: 143 | cur_instruction += qa[0]["q"] 144 | return conversation, cur_instruction.strip() 145 | 146 | def __getitem__(self, index): 147 | try: 148 | ann = self.get_anno(index) 149 | image, index = self.load_and_transform_media_data_image(index, ann["image"], clip_transform=self.clip_transform) 150 | conversation, instruction = self.process_qa(ann["qa"]) 151 | return image, conversation, instruction, index 152 | except Exception as e: 153 | logger.warning(f"Caught exception {e} when loading image {ann['image']}") 154 | index = np.random.randint(0, len(self)) 155 | return self.__getitem__(index) 156 | 157 | 158 | class ITVidTrainDataset(ITImgTrainDataset): 159 | media_type = "video" 160 | 161 | def __init__( 162 | self, ann_file, transform, 163 | num_frames=4, video_reader_type="decord", sample_type="rand", num_tries=3, 164 | mm_alone=True, 165 | system="", role=("Human", "Assistant"), 166 | start_token="", 167 | add_second_msg=True, 168 | random_shuffle=True, 169 | begin_signal=None, 170 | end_signal=None, 171 | clip_transform=False, 172 | skip_short_sample=False, 173 | 174 | ): 175 | # "id index file for webvid" 176 | # if "webvid" in ann_file[1]: 177 | # with open("/mnt/bn/dq-storage-ckpt/xulin/datasets/videos/webvid_10m/keys_indexfile.json") as f: 178 | # self.keys_indexfile = json.load(f) # the 
correponding index file for each webvid id 179 | 180 | super().__init__( 181 | ann_file, transform, 182 | system=system, role=role, 183 | mm_alone=mm_alone, 184 | start_token=start_token, end_token=end_token, 185 | random_shuffle=random_shuffle, 186 | begin_signal=begin_signal, 187 | end_signal=end_signal, 188 | clip_transform=clip_transform, 189 | skip_short_sample=skip_short_sample, 190 | ) 191 | self.num_frames = num_frames 192 | self.video_reader_type = video_reader_type 193 | self.video_reader = VIDEO_READER_FUNCS[video_reader_type] 194 | self.sample_type = sample_type 195 | self.num_tries = num_tries 196 | self.add_second_msg = add_second_msg 197 | 198 | logger.info(f"Use {video_reader_type} for data in {ann_file}") 199 | if add_second_msg: 200 | logger.info(f"Add second message: The video contains X frames sampled at T seconds.") 201 | 202 | def __getitem__(self, index): 203 | try: 204 | ann = self.get_anno(index) 205 | 206 | msg = "" 207 | clip = None 208 | if "start" in ann and "end" in ann: 209 | clip = [ann["start"], ann["end"]] 210 | video, index, sec = self.load_and_transform_media_data_video(index, ann["image"], return_fps=True, clip=clip, clip_transform=self.clip_transform) 211 | if self.add_second_msg: 212 | # " " should be added in the start and end 213 | msg = f" The video contains {len(sec)} frames sampled at {', '.join(sec)} seconds. " 214 | conversation, instruction = self.process_qa(ann["qa"], msg) 215 | return video, conversation, instruction, index 216 | except Exception as e: 217 | logger.warning(f"Caught exception {e} when loading video {ann['image']}") 218 | index = np.random.randint(0, len(self)) 219 | return self.__getitem__(index) -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | from utils.distributed import is_main_process, get_rank, get_world_size 2 | import io 3 | import json 4 | import re 5 | import numpy as np 6 | from os.path import join 7 | from tqdm import trange 8 | from PIL import Image 9 | from PIL import ImageFile 10 | from torchvision.transforms import PILToTensor 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | Image.MAX_IMAGE_PIXELS = None 13 | 14 | 15 | def load_image_from_path(image_path, client): 16 | if image_path.startswith('s3') or image_path.startswith('p2'): 17 | value = client.Get(image_path) 18 | img_bytes = np.frombuffer(value, dtype=np.uint8) 19 | buff = io.BytesIO(img_bytes) 20 | image = Image.open(buff).convert('RGB') 21 | else: 22 | image = Image.open(image_path).convert('RGB') # PIL Image 23 | image = PILToTensor()(image).unsqueeze(0) # (1, C, H, W), torch.uint8 24 | return image 25 | 26 | def pre_text(text, max_l=None, pre_text=True): 27 | if pre_text: 28 | text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower()) 29 | text = text.replace('-', ' ').replace('/', ' ').replace('', 'person') 30 | 31 | text = re.sub(r"\s{2,}", ' ', text) 32 | text = text.rstrip('\n').strip(' ') 33 | 34 | if max_l: # truncate 35 | words = text.split(' ') 36 | if len(words) > max_l: 37 | text = ' '.join(words[:max_l]) 38 | else: 39 | pass 40 | return text 41 | 42 | -------------------------------------------------------------------------------- /dataset/video_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py 3 | """ 4 | import random 5 | import io 6 
| import os 7 | import av 8 | import cv2 9 | import decord 10 | import imageio 11 | from decord import VideoReader 12 | 13 | # from dataloader import KVReader 14 | import torch 15 | import numpy as np 16 | import math 17 | # import tensorflow as tf 18 | decord.bridge.set_bridge("torch") 19 | 20 | import logging 21 | logger = logging.getLogger(__name__) 22 | 23 | def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float: 24 | """ 25 | Converts a present time with the given time base and start_pts offset to seconds. 26 | 27 | Returns: 28 | time_in_seconds (float): The corresponding time in seconds. 29 | 30 | https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64 31 | """ 32 | if pts == math.inf: 33 | return math.inf 34 | 35 | return int(pts - start_pts) * time_base 36 | 37 | 38 | def get_pyav_video_duration(video_reader): 39 | video_stream = video_reader.streams.video[0] 40 | video_duration = pts_to_secs( 41 | video_stream.duration, 42 | video_stream.time_base, 43 | video_stream.start_time 44 | ) 45 | return float(video_duration) 46 | 47 | 48 | def get_frame_indices_by_fps(): 49 | pass 50 | 51 | 52 | def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): 53 | if sample in ["rand", "middle"]: # uniform sampling 54 | acc_samples = min(num_frames, vlen) 55 | # split the video into `acc_samples` intervals, and sample from each interval. 56 | intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) 57 | ranges = [] 58 | for idx, interv in enumerate(intervals[:-1]): 59 | ranges.append((interv, intervals[idx + 1] - 1)) 60 | if sample == 'rand': 61 | try: 62 | frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] 63 | except: 64 | frame_indices = np.random.permutation(vlen)[:acc_samples] 65 | frame_indices.sort() 66 | frame_indices = list(frame_indices) 67 | elif fix_start is not None: 68 | frame_indices = [x[0] + fix_start for x in ranges] 69 | elif sample == 'middle': 70 | frame_indices = [(x[0] + x[1]) // 2 for x in ranges] 71 | else: 72 | raise NotImplementedError 73 | 74 | if len(frame_indices) < num_frames: # padded with last frame 75 | padded_frame_indices = [frame_indices[-1]] * num_frames 76 | padded_frame_indices[:len(frame_indices)] = frame_indices 77 | frame_indices = padded_frame_indices 78 | elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps 79 | output_fps = float(sample[3:]) 80 | duration = float(vlen) / input_fps 81 | delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents 82 | frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) 83 | frame_indices = np.around(frame_seconds * input_fps).astype(int) 84 | frame_indices = [e for e in frame_indices if e < vlen] 85 | if max_num_frames > 0 and len(frame_indices) > max_num_frames: 86 | frame_indices = frame_indices[:max_num_frames] 87 | # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames) 88 | else: 89 | raise ValueError 90 | return frame_indices 91 | 92 | 93 | def read_frames_av( 94 | video_path, num_frames, sample='rand', fix_start=None, 95 | max_num_frames=-1, client=None, clip=None, 96 | ): 97 | reader = av.open(video_path) 98 | frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)] 99 | vlen = len(frames) 100 | duration = get_pyav_video_duration(reader) 101 | fps = vlen / float(duration) 102 | frame_indices = get_frame_indices( 103 | num_frames, vlen, 
sample=sample, fix_start=fix_start, 104 | input_fps=fps, max_num_frames=max_num_frames 105 | ) 106 | frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8 107 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 108 | return frames, frame_indices, fps 109 | 110 | 111 | def read_frames_gif( 112 | video_path, num_frames, sample='rand', fix_start=None, 113 | max_num_frames=-1, client=None, clip=None, 114 | ): 115 | if video_path.startswith('s3') or video_path.startswith('p2'): 116 | video_bytes = client.get(video_path) 117 | gif = imageio.get_reader(io.BytesIO(video_bytes)) 118 | else: 119 | gif = imageio.get_reader(video_path) 120 | vlen = len(gif) 121 | frame_indices = get_frame_indices( 122 | num_frames, vlen, sample=sample, fix_start=fix_start, 123 | max_num_frames=max_num_frames 124 | ) 125 | frames = [] 126 | for index, frame in enumerate(gif): 127 | # for index in frame_idxs: 128 | if index in frame_indices: 129 | frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) 130 | frame = torch.from_numpy(frame).byte() 131 | # # (H x W x C) to (C x H x W) 132 | frame = frame.permute(2, 0, 1) 133 | frames.append(frame) 134 | frames = torch.stack(frames) # .float() / 255 135 | 136 | return frames, frame_indices, 25. # for tgif 137 | 138 | 139 | def read_frames_hdfs(ind_file, vid, num_frames, sample='rand',fix_start=None, 140 | max_num_frames=-1, client=None, clip=None): 141 | _context_features = {'title': tf.io.FixedLenFeature([], dtype=tf.string)} 142 | _sequence_features = {'data': tf.io.FixedLenSequenceFeature([], dtype=tf.string)} 143 | num_parallel_reader = 1 144 | filename, extension = os.path.splitext(ind_file) 145 | reader = KVReader(filename, num_parallel_reader) 146 | key = vid 147 | values = reader.read_many([key]) 148 | item = values[0] 149 | contexts, sequences = tf.io.parse_single_sequence_example( 150 | serialized=item, 151 | context_features=_context_features, 152 | sequence_features=_sequence_features) 153 | 154 | # text = contexts['title'].numpy().decode("utf-8") 155 | rawframes = sequences['data'] 156 | vlen = len(rawframes) 157 | sample="rand" 158 | 159 | frame_indices = get_frame_indices(num_frames, vlen, sample=sample, 160 | fix_start=fix_start, 161 | max_num_frames=max_num_frames) 162 | def read_image(raw_data): 163 | return tf.image.decode_jpeg(raw_data, channels=3, dct_method='INTEGER_ACCURATE').numpy() 164 | 165 | frames = [] 166 | for index, frame in enumerate(rawframes): 167 | if index in frame_indices: 168 | frame = read_image(frame) 169 | frame = torch.as_tensor(frame) 170 | frames.append(frame) 171 | 172 | frames = torch.stack(frames) 173 | # print("in hdfs========>",frames[0]) 174 | frames = frames.permute(0, 3, 1, 2) 175 | return frames, frame_indices, 25 # don't know the fps for index 176 | 177 | 178 | def read_frames_decord( 179 | video_path, num_frames, sample='rand', fix_start=None, 180 | max_num_frames=-1, client=None, clip=None 181 | ): 182 | if video_path.startswith('s3') or video_path.startswith('p2'): 183 | video_bytes = client.get(video_path) 184 | video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1) 185 | else: 186 | video_reader = VideoReader(video_path, num_threads=1) 187 | vlen = len(video_reader) 188 | fps = video_reader.get_avg_fps() 189 | duration = vlen / float(fps) 190 | 191 | if clip: 192 | start, end = clip 193 | duration = end - start 194 | vlen = int(duration * fps) 195 | start_index = int(start * fps) 196 | 197 | frame_indices = get_frame_indices( 198 | num_frames, vlen, 
sample=sample, fix_start=fix_start, 199 | input_fps=fps, max_num_frames=max_num_frames 200 | ) 201 | if clip: 202 | frame_indices = [f + start_index for f in frame_indices] 203 | 204 | frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 205 | frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 206 | return frames, frame_indices, float(fps) 207 | 208 | 209 | VIDEO_READER_FUNCS = { 210 | 'av': read_frames_av, 211 | 'decord': read_frames_decord, 212 | 'gif': read_frames_gif, 213 | 'hdfs': read_frames_hdfs, 214 | } 215 | -------------------------------------------------------------------------------- /evaluate_egoschema_result.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | import argparse 3 | import requests 4 | 5 | root_dir = 'test_results/pllava-7b-lora14-threshold0.8-layer10-alpha0.4-temporal-segment-ratio-0.25-cluster-ratio-0.5/egoschema' 6 | 7 | def extract_and_convert(label_string): 8 | # map option letters to answer indices 9 | mapping = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4} 10 | 11 | # extract the option letter from the prediction string 12 | first_char = label_string[1] 13 | 14 | # make sure the character is a valid option letter 15 | if first_char in mapping: 16 | return mapping[first_char] 17 | else: 18 | raise ValueError("Input string does not start with a valid label (A-E).") 19 | 20 | def send_post_request(data): 21 | """ 22 | Sends a POST request to the evaluation server with the given predictions. 23 | 24 | Parameters: 25 | - data (dict): Mapping from video id to predicted option index, sent as the JSON request body. 26 | 27 | Returns: 28 | - Response object containing server's response. 29 | """ 30 | 31 | url = "https://validation-server.onrender.com/api/upload/" 32 | headers = { 33 | "Content-Type": "application/json" 34 | } 35 | 36 | response = requests.post(url, headers=headers, json=data) 37 | 38 | return response 39 | 40 | prediction_jsonls = [f for f in os.listdir(root_dir) if 'all_results' in f] 41 | 42 | result_dict = {} 43 | 44 | for pred_jsonl in prediction_jsonls: 45 | data_list = json.load(open(os.path.join(root_dir, pred_jsonl), 'r'))['result_list'] 46 | for data in data_list: 47 | pred = data['pred'] 48 | pred = extract_and_convert(pred) 49 | vid = data['video_path'].split('/')[-1].split('.')[0] 50 | result_dict[vid] = pred 51 | # with open(os.path.join(root_dir, pred_jsonl), 'r') as f: 52 | # lines = f.readlines() 53 | # for line in lines: 54 | # data = json.loads(line) 55 | # result_dict[data['vid']] = extract_and_convert(data['text']['prediction']) 56 | print(result_dict) 57 | response = send_post_request(result_dict) 58 | print(f"Response Status Code: {response.status_code}") 59 | print(f"Response Content:\n{response.text}") -------------------------------------------------------------------------------- /example/1917.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/1917.mov -------------------------------------------------------------------------------- /example/1917.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/1917.mp4 -------------------------------------------------------------------------------- /example/bear.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/bear.jpg
-------------------------------------------------------------------------------- /example/cooking.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/cooking.mp4 -------------------------------------------------------------------------------- /example/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/dog.png -------------------------------------------------------------------------------- /example/jesse_dance.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/jesse_dance.mp4 -------------------------------------------------------------------------------- /example/working.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/working.mp4 -------------------------------------------------------------------------------- /example/yoga.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/example/yoga.mp4 -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available 17 | 18 | # from .modeling_pllava_flow import PllavaFlowForConditionalGeneration 19 | from .modeling_pllava import PllavaForConditionalGeneration 20 | from .modeling_pllava_SF import PllavaSFForConditionalGeneration 21 | from .processing_pllava import PllavaProcessor 22 | from .configuration_pllava import PllavaConfig 23 | 24 | # _import_structure = {"configuration_pllava": ["PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "PllavaConfig"]} 25 | 26 | # try: 27 | # if not is_torch_available(): 28 | # raise OptionalDependencyNotAvailable() 29 | # except OptionalDependencyNotAvailable: 30 | # pass 31 | # else: 32 | # _import_structure["modeling_pllava"] = [ 33 | # "PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", 34 | # "PllavaForConditionalGeneration", 35 | # "PllavaPreTrainedModel", 36 | # ] 37 | # _import_structure["processing_pllava"] = ["PllavaProcessor"] 38 | 39 | 40 | # if TYPE_CHECKING: 41 | # from .configuration_pllava import PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, PllavaConfig 42 | 43 | # try: 44 | # if not is_torch_available(): 45 | # raise OptionalDependencyNotAvailable() 46 | # except OptionalDependencyNotAvailable: 47 | # pass 48 | # else: 49 | # from .modeling_pllava import ( 50 | # PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, 51 | # PllavaForConditionalGeneration, 52 | # PllavaPreTrainedModel, 53 | # ) 54 | # from .processing_pllava import PllavaProcessor 55 | 56 | 57 | # else: 58 | # import sys 59 | 60 | # sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) 61 | -------------------------------------------------------------------------------- /models/pllava/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/configuration_pllava.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/configuration_pllava.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/elastic_cache.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/elastic_cache.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/llama.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/llama.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/modeling_clip.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_clip.cpython-310.pyc 
-------------------------------------------------------------------------------- /models/pllava/__pycache__/modeling_flash_attention_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_flash_attention_utils.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/modeling_pllava.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_pllava.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/modeling_pllava_SF.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_pllava_SF.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/modeling_pllava_flow.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modeling_pllava_flow.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/modify_llama.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/modify_llama.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/processing_pllava.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/processing_pllava.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/__pycache__/v433_modeling_llama.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/models/pllava/__pycache__/v433_modeling_llama.cpython-310.pyc -------------------------------------------------------------------------------- /models/pllava/configuration_pllava.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ Llava model configuration""" 15 | 16 | from transformers.configuration_utils import PretrainedConfig 17 | from transformers.utils import logging 18 | from transformers.models.auto import CONFIG_MAPPING 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json", 25 | } 26 | 27 | 28 | class PllavaConfig(PretrainedConfig): 29 | r""" 30 | This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an 31 | Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration 32 | with the defaults will yield a similar configuration to that of the Llava-9B. 33 | 34 | e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b) 35 | 36 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 37 | documentation from [`PretrainedConfig`] for more information. 38 | 39 | Args: 40 | vision_config (`LlavaVisionConfig`, *optional*): 41 | Custom vision config or dict 42 | text_config (`Union[AutoConfig, dict]`, *optional*): 43 | The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. 44 | ignore_index (`int`, *optional*, defaults to -100): 45 | The ignore index for the loss function. 46 | image_token_index (`int`, *optional*, defaults to 32000): 47 | The image token index to encode the image prompt. 48 | projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): 49 | The activation function used by the multimodal projector. 50 | vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): 51 | The feature selection strategy used to select the vision feature from the CLIP backbone. 52 | vision_feature_layer (`int`, *optional*, defaults to -2): 53 | The index of the layer to select the vision feature. 54 | vocab_size (`int`, *optional*, defaults to 32000): 55 | Vocabulary size of the Llava model. 
Defines the number of different tokens that can be represented by the 56 | `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`] 57 | 58 | Example: 59 | 60 | ```python 61 | >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig 62 | 63 | >>> # Initializing a CLIP-vision config 64 | >>> vision_config = CLIPVisionConfig() 65 | 66 | >>> # Initializing a Llama config 67 | >>> text_config = LlamaConfig() 68 | 69 | >>> # Initializing a Llava llava-1.5-7b style configuration 70 | >>> configuration = LlavaConfig(vision_config, text_config) 71 | 72 | >>> # Initializing a model from the llava-1.5-7b style configuration 73 | >>> model = LlavaForConditionalGeneration(configuration) 74 | 75 | >>> # Accessing the model configuration 76 | >>> configuration = model.config 77 | ```""" 78 | 79 | model_type = "llava" 80 | is_composition = False 81 | 82 | def __init__( 83 | self, 84 | vision_config=None, 85 | text_config=None, 86 | ignore_index=-100, 87 | image_token_index=32000, 88 | projector_hidden_act="gelu", 89 | vision_feature_select_strategy="default", 90 | vision_feature_layer=-2, 91 | vocab_size=32000, 92 | pooling_method='avg', 93 | pooling_shape=(8, 16, 16), 94 | frame_shape=(24, 24), # llava 1.5 pretrained frame shape 95 | num_frames=1, # llava 1.5 pretrained frame shape 96 | use_pooling=True, 97 | gradient_checkpointing=False, 98 | selected_layer=10, 99 | alpha=0.1, 100 | head=0, 101 | softmax=1.0, 102 | tau=1.0, 103 | cluster_ratio=1.0, 104 | temporal_segment_ratio=1.0, 105 | **kwargs, 106 | ): 107 | self.ignore_index = ignore_index 108 | self.image_token_index = image_token_index 109 | self.projector_hidden_act = projector_hidden_act 110 | self.vision_feature_select_strategy = vision_feature_select_strategy 111 | self.vision_feature_layer = vision_feature_layer 112 | self.vocab_size = vocab_size 113 | self.use_pooling = use_pooling 114 | self.gradient_checkpointing = gradient_checkpointing 115 | self.selected_layer = selected_layer 116 | self.alpha = alpha 117 | self.head = head 118 | self.softmax = softmax 119 | self.tau = tau 120 | self.cluster_ratio = cluster_ratio 121 | self.temporal_segment_ratio = temporal_segment_ratio 122 | 123 | self.vision_config = vision_config 124 | 125 | self.pooling_method = pooling_method # should be in 'max', 'avg' 126 | self.pooling_shape = pooling_shape # 127 | self.frame_shape = frame_shape # 128 | self.num_frames = num_frames 129 | if isinstance(self.vision_config, dict): 130 | vision_config["model_type"] = ( 131 | vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" 132 | ) 133 | self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) 134 | elif vision_config is None: 135 | self.vision_config = CONFIG_MAPPING["clip_vision_model"]( 136 | intermediate_size=4096, 137 | hidden_size=1024, 138 | patch_size=14, 139 | image_size=336, 140 | num_hidden_layers=24, 141 | num_attention_heads=16, 142 | vocab_size=32000, 143 | projection_dim=768, 144 | ) 145 | self.vocab_size = self.vocab_size 146 | 147 | self.text_config = text_config 148 | 149 | if isinstance(self.text_config, dict): 150 | text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" 151 | self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) 152 | self.vocab_size = self.text_config.vocab_size 153 | self.text_config.gradient_checkpointing = self.gradient_checkpointing 154 | 155 | elif text_config is None: 156 | tmp_config = 
{"_attn_implementation":"flash_attention_2", 157 | "gradient_checkpointing": self.gradient_checkpointing} 158 | self.text_config = CONFIG_MAPPING["llama"](**tmp_config) 159 | self.text_config.gradient_checkpointing = self.gradient_checkpointing 160 | # self.text_config["_attn_implementation"]="flash_attention_2" # xl: temporal hard code 161 | 162 | 163 | super().__init__(**kwargs) 164 | -------------------------------------------------------------------------------- /models/pllava/convert_pllava_weights_to_hf.py: -------------------------------------------------------------------------------- 1 | # Not yet -------------------------------------------------------------------------------- /models/pllava/modeling_flash_attention_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import inspect 17 | import os 18 | from typing import Optional, Tuple 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | from transformers.utils import is_flash_attn_2_available 24 | 25 | if is_flash_attn_2_available(): 26 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 27 | from flash_attn import flash_attn_func, flash_attn_varlen_func 28 | 29 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) 30 | 31 | 32 | def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: 33 | """ 34 | Retrieves indexing data required to repad unpadded (ragged) tensors. 35 | 36 | Arguments: 37 | attention_mask (`torch.Tensor`): 38 | Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. 39 | 40 | Return: 41 | indices (`torch.Tensor): 42 | The indices of non-masked tokens from the flattened input sequence. 43 | cu_seqlens (`torch.Tensor`): 44 | The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). 45 | max_seqlen_in_batch (`int`): 46 | Maximum sequence length in batch. 47 | """ 48 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 49 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 50 | max_seqlen_in_batch = seqlens_in_batch.max().item() 51 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 52 | return ( 53 | indices, 54 | cu_seqlens, 55 | max_seqlen_in_batch, 56 | ) 57 | 58 | 59 | def _upad_input( 60 | query_layer: torch.Tensor, 61 | key_layer: torch.Tensor, 62 | value_layer: torch.Tensor, 63 | attention_mask: torch.Tensor, 64 | query_length: int, 65 | ): 66 | """ 67 | Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. 
68 | 69 | This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary 70 | tensors for query, key, value tensors. 71 | 72 | Arguments: 73 | query_layer (`torch.Tensor`): 74 | Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). 75 | key_layer (`torch.Tensor`): 76 | Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). 77 | value_layer (`torch.Tensor`): 78 | Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). 79 | attention_mask (`torch.Tensor`): 80 | Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. 81 | query_length (`int`): 82 | Target length. 83 | 84 | Return: 85 | query_layer (`torch.Tensor): 86 | Query state without padding. Shape: (total_target_length, num_heads, head_dim). 87 | key_layer (`torch.Tensor`): 88 | Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). 89 | value_layer (`torch.Tensor`): 90 | Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). 91 | indices_q (`torch.Tensor`): 92 | The indices of non-masked tokens from the flattened input target sequence. 93 | (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): 94 | The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). 95 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): 96 | Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 97 | """ 98 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 99 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 100 | 101 | key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) 102 | value_layer = index_first_axis( 103 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 104 | ) 105 | if query_length == kv_seq_len: 106 | query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) 107 | cu_seqlens_q = cu_seqlens_k 108 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 109 | indices_q = indices_k 110 | elif query_length == 1: 111 | max_seqlen_in_batch_q = 1 112 | cu_seqlens_q = torch.arange( 113 | batch_size + 1, dtype=torch.int32, device=query_layer.device 114 | ) # There is a memcpy here, that is very bad. 115 | indices_q = cu_seqlens_q[:-1] 116 | query_layer = query_layer.squeeze(1) 117 | else: 118 | # The -q_len: slice assumes left padding. 
119 | attention_mask = attention_mask[:, -query_length:] 120 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 121 | 122 | return ( 123 | query_layer, 124 | key_layer, 125 | value_layer, 126 | indices_q, 127 | (cu_seqlens_q, cu_seqlens_k), 128 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 129 | ) 130 | 131 | 132 | def _flash_attention_forward( 133 | query_states: torch.Tensor, 134 | key_states: torch.Tensor, 135 | value_states: torch.Tensor, 136 | attention_mask: torch.Tensor, 137 | query_length: int, 138 | is_causal: bool, 139 | dropout: float = 0.0, 140 | softmax_scale: Optional[float] = None, 141 | sliding_window: Optional[int] = None, 142 | use_top_left_mask: bool = False, 143 | softcap: Optional[float] = None, 144 | deterministic: bool = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1", 145 | ): 146 | """ 147 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 148 | first unpad the input, then computes the attention scores and pad the final attention scores. 149 | 150 | Args: 151 | query_states (`torch.Tensor`): 152 | Input query states to be passed to Flash Attention API 153 | key_states (`torch.Tensor`): 154 | Input key states to be passed to Flash Attention API 155 | value_states (`torch.Tensor`): 156 | Input value states to be passed to Flash Attention API 157 | attention_mask (`torch.Tensor`): 158 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 159 | position of padding tokens and 1 for the position of non-padding tokens. 160 | dropout (`float`): 161 | Attention dropout 162 | softmax_scale (`float`, *optional*): 163 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 164 | use_top_left_mask (`bool`, defaults to `False`): 165 | flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. 166 | softcap (`float`, *optional*): 167 | Softcap for the attention logits, used e.g. in gemma2. 168 | deterministic (`bool`, *optional*): 169 | Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. 170 | """ 171 | if not use_top_left_mask: 172 | causal = is_causal 173 | else: 174 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. 175 | causal = is_causal and query_length != 1 176 | 177 | # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). 
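    # Note on the kwargs assembled just below: sliding-window (local) attention is only
    # requested when the installed flash-attn build exposes a `window_size` argument and
    # the key/value length exceeds the configured window; `deterministic` is forwarded
    # only for flash-attn >= 2.4.1, and `softcap` only when a value is provided.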
178 | use_sliding_windows = ( 179 | _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window 180 | ) 181 | flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} 182 | 183 | if is_flash_attn_greater_or_equal("2.4.1"): 184 | flash_kwargs["deterministic"] = deterministic 185 | 186 | if softcap is not None: 187 | flash_kwargs["softcap"] = softcap 188 | 189 | # Contains at least one padding token in the sequence 190 | if attention_mask is not None: 191 | batch_size = query_states.shape[0] 192 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( 193 | query_states, key_states, value_states, attention_mask, query_length 194 | ) 195 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 196 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 197 | 198 | attn_output_unpad = flash_attn_varlen_func( 199 | query_states, 200 | key_states, 201 | value_states, 202 | cu_seqlens_q=cu_seqlens_q, 203 | cu_seqlens_k=cu_seqlens_k, 204 | max_seqlen_q=max_seqlen_in_batch_q, 205 | max_seqlen_k=max_seqlen_in_batch_k, 206 | dropout_p=dropout, 207 | softmax_scale=softmax_scale, 208 | causal=causal, 209 | **flash_kwargs, 210 | ) 211 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 212 | else: 213 | attn_output = flash_attn_func( 214 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs 215 | ) 216 | 217 | return attn_output -------------------------------------------------------------------------------- /requirements.no_torch.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.26.1 3 | addict==2.4.0 4 | aiofiles==23.2.1 5 | aliyun-python-sdk-core==2.15.0 6 | aliyun-python-sdk-kms==2.16.2 7 | altair==5.2.0 8 | annotated-types==0.6.0 9 | antlr4-python3-runtime==4.9.3 10 | anyio==4.3.0 11 | anykeystore==0.2 12 | apex==0.9.10.dev0 13 | appdirs==1.4.4 14 | argcomplete==3.2.3 15 | attrs==23.2.0 16 | av==10.0.0 17 | beautifulsoup4==4.12.3 18 | blessed==1.20.0 19 | blessings==1.7 20 | boto3==1.34.63 21 | botocore==1.34.63 22 | Brotli==1.1.0 23 | cachetools==5.3.3 24 | certifi==2024.2.2 25 | cffi==1.16.0 26 | charset-normalizer==3.3.2 27 | click==8.1.7 28 | colorama==0.4.6 29 | contourpy==1.2.0 30 | crcmod==1.7 31 | cryptacular==1.6.2 32 | cryptography==42.0.5 33 | cycler==0.12.1 34 | dacite==1.7.0 35 | decorator==4.4.2 36 | decord==0.6.0 37 | deepspeed==0.14.0 38 | defusedxml==0.7.1 39 | Deprecated==1.2.14 40 | dill==0.3.8 41 | distro==1.9.0 42 | dnspython==2.6.1 43 | docker-pycreds==0.4.0 44 | einops==0.6.1 45 | exceptiongroup==1.2.0 46 | fastapi==0.110.0 47 | ffmpeg==1.4 48 | ffmpy==0.3.2 49 | fiftyone==0.23.6 50 | fiftyone-brain==0.16.1 51 | fiftyone_db==1.1.2 52 | filelock==3.9.0 53 | flash-attn==2.5.6 54 | fonttools==4.49.0 55 | fsspec==2024.2.0 56 | ftfy==6.1.3 57 | future==1.0.0 58 | fvcore==0.1.5.post20221221 59 | gdown==5.1.0 60 | gitdb==4.0.11 61 | GitPython==3.1.42 62 | glob2==0.7 63 | google-auth==2.28.2 64 | google-auth-oauthlib==1.2.0 65 | gpustat==1.1.1 66 | gradio==4.21.0 67 | gradio_client==0.12.0 68 | graphql-core==3.2.3 69 | greenlet==3.0.3 70 | grpcio==1.62.1 71 | h11==0.14.0 72 | h2==4.1.0 73 | hjson==3.1.0 74 | hpack==4.0.0 75 | httpcore==1.0.4 76 | httpx==0.27.0 77 | huggingface-hub==0.21.4 78 | humanize==4.9.0 79 | hupper==1.12.1 80 | Hypercorn==0.16.0 81 | hyperframe==6.0.1 82 | idna==3.6 83 | idscheck==2.3.0 84 | 
imageio==2.27.0 85 | imageio-ffmpeg==0.4.9 86 | importlib_metadata==7.0.2 87 | importlib_resources==6.3.0 88 | inflate64==1.0.0 89 | iopath==0.1.10 90 | Jinja2==3.1.2 91 | jmespath==0.10.0 92 | joblib==1.3.2 93 | jsonlines==4.0.0 94 | jsonschema==4.21.1 95 | jsonschema-specifications==2023.12.1 96 | kaleido==0.2.1 97 | kiwisolver==1.4.5 98 | lazy_loader==0.3 99 | Markdown==3.6 100 | markdown-it-py==3.0.0 101 | MarkupSafe==2.1.3 102 | matplotlib==3.8.3 103 | mdurl==0.1.2 104 | mmcv-full==1.7.2 105 | model-index==0.1.11 106 | mongoengine==0.24.2 107 | motor==3.3.2 108 | moviepy==1.0.3 109 | mpmath==1.3.0 110 | multivolumefile==0.2.3 111 | networkx==3.2.1 112 | ninja==1.11.1.1 113 | numpy 114 | oauthlib==3.2.2 115 | omegaconf==2.3.0 116 | openai==1.14.0 117 | opencv-python==4.9.0.80 118 | opencv-python-headless==4.9.0.80 119 | opendatalab==0.0.10 120 | openmim==0.3.9 121 | openxlab==0.0.36 122 | ordered-set==4.1.0 123 | orjson==3.9.15 124 | oss2==2.17.0 125 | packaging==24.0 126 | pandas==1.5.3 127 | PasteDeploy==3.1.0 128 | pathtools==0.1.2 129 | pbkdf2==1.3 130 | peft==0.10.0 131 | pillow==10.2.0 132 | plaster==1.1.2 133 | plaster-pastedeploy==1.0.1 134 | platformdirs==4.2.0 135 | plotly==5.20.0 136 | portalocker==2.8.2 137 | pprintpp==0.4.0 138 | priority==2.0.0 139 | proglog==0.1.10 140 | protobuf==4.23.4 141 | psutil==5.9.4 142 | py-cpuinfo==9.0.0 143 | py7zr==0.21.0 144 | pyasn1==0.5.1 145 | pyasn1-modules==0.3.0 146 | pybcj==1.0.2 147 | pycparser==2.21 148 | pycryptodome==3.20.0 149 | pycryptodomex==3.20.0 150 | pydantic==2.6.4 151 | pydantic_core==2.16.3 152 | pydub==0.25.1 153 | Pygments==2.17.2 154 | pymongo==4.6.2 155 | pynvml==11.5.0 156 | pyparsing==3.1.2 157 | pyppmd==1.1.0 158 | pyramid==2.0.2 159 | pyramid-mailer==0.15.1 160 | PySocks==1.7.1 161 | python-dateutil==2.9.0.post0 162 | python-multipart==0.0.9 163 | python3-openid==3.2.0 164 | pytz==2023.4 165 | PyYAML==6.0 166 | pyzstd==0.15.9 167 | rarfile==4.1 168 | referencing==0.33.0 169 | regex==2023.12.25 170 | repoze.sendmail==4.4.1 171 | requests==2.28.2 172 | requests-oauthlib==1.4.0 173 | retrying==1.3.4 174 | rich==13.4.2 175 | rpds-py==0.18.0 176 | rsa==4.9 177 | ruff==0.3.2 178 | s3transfer==0.10.1 179 | safetensors==0.4.2 180 | scikit-image==0.22.0 181 | scikit-learn==1.4.1.post1 182 | scipy==1.10.1 183 | semantic-version==2.10.0 184 | sentencepiece==0.2.0 185 | sentry-sdk==1.42.0 186 | setproctitle==1.3.3 187 | shellingham==1.5.4 188 | six==1.16.0 189 | smmap==5.0.1 190 | sniffio==1.3.1 191 | sortedcontainers==2.4.0 192 | soupsieve==2.5 193 | SQLAlchemy==2.0.28 194 | sse-starlette==0.10.3 195 | sseclient-py==1.8.0 196 | starlette==0.36.3 197 | strawberry-graphql==0.138.1 198 | sympy==1.12 199 | tabulate==0.9.0 200 | taskgroup==0.0.0a4 201 | tenacity==8.2.3 202 | tensorboard==2.15.1 203 | tensorboard-data-server==0.7.2 204 | tensorboardX==2.6.2.2 205 | termcolor==2.3.0 206 | texttable==1.7.0 207 | threadpoolctl==3.3.0 208 | tifffile==2024.2.12 209 | timm==0.6.12 210 | tokenizers==0.15.2 211 | tomli==2.0.1 212 | tomlkit==0.12.0 213 | toolz==0.12.1 214 | tqdm==4.65.2 215 | transaction==4.0 216 | transformers==4.37.1 217 | translationstring==1.4 218 | triton==2.2.0 219 | typer==0.9.0 220 | typing_extensions==4.8.0 221 | tzdata==2024.1 222 | tzlocal==5.2 223 | universal-analytics-python3==1.1.1 224 | urllib3==1.26.18 225 | uvicorn==0.28.0 226 | velruse==1.1.1 227 | venusian==3.1.0 228 | voxel51-eta==0.12.6 229 | wandb==0.14.0 230 | wcwidth==0.2.13 231 | WebOb==1.8.7 232 | websockets==11.0.3 233 | Werkzeug==3.0.1 234 | 
wrapt==1.16.0 235 | wsproto==1.2.0 236 | WTForms==3.1.2 237 | wtforms-recaptcha==0.3.2 238 | xmltodict==0.13.0 239 | yacs==0.1.8 240 | yapf==0.40.2 241 | zipp==3.18.1 242 | zope.deprecation==5.0 243 | zope.interface==6.2 244 | zope.sqlalchemy==3.1 245 | -------------------------------------------------------------------------------- /requirements.torch.txt: -------------------------------------------------------------------------------- 1 | --index-url https://download.pytorch.org/whl/cu118 2 | torch==2.2.1 3 | torchaudio==2.2.1 4 | torchvision==0.17.1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.26.1 3 | addict==2.4.0 4 | aiofiles==23.2.1 5 | aliyun-python-sdk-core==2.15.0 6 | aliyun-python-sdk-kms==2.16.2 7 | altair==5.2.0 8 | annotated-types==0.6.0 9 | antlr4-python3-runtime==4.9.3 10 | anyio==4.3.0 11 | anykeystore==0.2 12 | apex==0.9.10.dev0 13 | appdirs==1.4.4 14 | argcomplete==3.2.3 15 | attrs==23.2.0 16 | av==10.0.0 17 | beautifulsoup4==4.12.3 18 | blessed==1.20.0 19 | blessings==1.7 20 | boto3==1.34.63 21 | botocore==1.34.63 22 | Brotli==1.1.0 23 | cachetools==5.3.3 24 | certifi==2024.2.2 25 | cffi==1.16.0 26 | charset-normalizer==3.3.2 27 | click==8.1.7 28 | colorama==0.4.6 29 | contourpy==1.2.0 30 | crcmod==1.7 31 | cryptacular==1.6.2 32 | cryptography==42.0.5 33 | cycler==0.12.1 34 | dacite==1.7.0 35 | decorator==4.4.2 36 | decord==0.6.0 37 | deepspeed==0.14.0 38 | defusedxml==0.7.1 39 | Deprecated==1.2.14 40 | dill==0.3.8 41 | distro==1.9.0 42 | dnspython==2.6.1 43 | docker-pycreds==0.4.0 44 | einops==0.6.1 45 | exceptiongroup==1.2.0 46 | fastapi==0.110.0 47 | ffmpeg==1.4 48 | ffmpy==0.3.2 49 | fiftyone==0.23.6 50 | fiftyone-brain==0.16.1 51 | fiftyone_db==1.1.2 52 | filelock==3.9.0 53 | https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.6/flash_attn-2.5.6+cu118torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 54 | fonttools==4.49.0 55 | fsspec==2024.2.0 56 | ftfy==6.1.3 57 | future==1.0.0 58 | fvcore==0.1.5.post20221221 59 | gdown==5.1.0 60 | gitdb==4.0.11 61 | GitPython==3.1.42 62 | glob2==0.7 63 | google-auth==2.28.2 64 | google-auth-oauthlib==1.2.0 65 | gpustat==1.1.1 66 | gradio==4.21.0 67 | gradio_client==0.12.0 68 | graphql-core==3.2.3 69 | greenlet==3.0.3 70 | grpcio==1.62.1 71 | h11==0.14.0 72 | h2==4.1.0 73 | hjson==3.1.0 74 | hpack==4.0.0 75 | httpcore==1.0.4 76 | httpx==0.27.0 77 | huggingface-hub==0.21.4 78 | humanize==4.9.0 79 | hupper==1.12.1 80 | Hypercorn==0.16.0 81 | hyperframe==6.0.1 82 | idna==3.6 83 | idscheck==2.3.0 84 | imageio==2.27.0 85 | imageio-ffmpeg==0.4.9 86 | importlib_metadata==7.0.2 87 | importlib_resources==6.3.0 88 | inflate64==1.0.0 89 | iopath==0.1.10 90 | Jinja2==3.1.2 91 | jmespath==0.10.0 92 | joblib==1.3.2 93 | jsonlines==4.0.0 94 | jsonschema==4.21.1 95 | jsonschema-specifications==2023.12.1 96 | kaleido==0.2.1 97 | kiwisolver==1.4.5 98 | lazy_loader==0.3 99 | Markdown==3.6 100 | markdown-it-py==3.0.0 101 | MarkupSafe==2.1.3 102 | matplotlib==3.8.3 103 | mdurl==0.1.2 104 | mmcv-full==1.7.2 105 | model-index==0.1.11 106 | mongoengine==0.24.2 107 | motor==3.3.2 108 | moviepy==1.0.3 109 | mpmath==1.3.0 110 | multivolumefile==0.2.3 111 | networkx==3.2.1 112 | ninja==1.11.1.1 113 | numpy==1.23.5 114 | nvidia-cublas-cu11==11.11.3.6 115 | nvidia-cuda-cupti-cu11==11.8.87 116 | nvidia-cuda-nvrtc-cu11==11.8.89 117 | nvidia-cuda-runtime-cu11==11.8.89 118 | 
nvidia-cudnn-cu11==8.7.0.84 119 | nvidia-cufft-cu11==10.9.0.58 120 | nvidia-curand-cu11==10.3.0.86 121 | nvidia-cusolver-cu11==11.4.1.48 122 | nvidia-cusparse-cu11==11.7.5.86 123 | nvidia-ml-py==12.535.133 124 | nvidia-ml-py3==7.352.0 125 | nvidia-nccl-cu11==2.19.3 126 | nvidia-nvtx-cu11==11.8.86 127 | oauthlib==3.2.2 128 | omegaconf==2.3.0 129 | openai==1.14.0 130 | opencv-python==4.9.0.80 131 | opencv-python-headless==4.9.0.80 132 | opendatalab==0.0.10 133 | openmim==0.3.9 134 | openxlab==0.0.36 135 | ordered-set==4.1.0 136 | orjson==3.9.15 137 | oss2==2.17.0 138 | packaging==24.0 139 | pandas==1.5.3 140 | PasteDeploy==3.1.0 141 | pathtools==0.1.2 142 | pbkdf2==1.3 143 | peft==0.10.0 144 | pillow==10.2.0 145 | plaster==1.1.2 146 | plaster-pastedeploy==1.0.1 147 | platformdirs==4.2.0 148 | plotly==5.20.0 149 | portalocker==2.8.2 150 | pprintpp==0.4.0 151 | priority==2.0.0 152 | proglog==0.1.10 153 | protobuf==4.23.4 154 | psutil==5.9.4 155 | py-cpuinfo==9.0.0 156 | py7zr==0.21.0 157 | pyasn1==0.5.1 158 | pyasn1-modules==0.3.0 159 | pybcj==1.0.2 160 | pycparser==2.21 161 | pycryptodome==3.20.0 162 | pycryptodomex==3.20.0 163 | pydantic==2.6.4 164 | pydantic_core==2.16.3 165 | pydub==0.25.1 166 | Pygments==2.17.2 167 | pymongo==4.6.2 168 | pynvml==11.5.0 169 | pyparsing==3.1.2 170 | pyppmd==1.1.0 171 | pyramid==2.0.2 172 | pyramid-mailer==0.15.1 173 | PySocks==1.7.1 174 | python-dateutil==2.9.0.post0 175 | python-multipart==0.0.9 176 | python3-openid==3.2.0 177 | pytz==2023.4 178 | PyYAML==6.0 179 | pyzstd==0.15.9 180 | rarfile==4.1 181 | referencing==0.33.0 182 | regex==2023.12.25 183 | repoze.sendmail==4.4.1 184 | requests==2.28.2 185 | requests-oauthlib==1.4.0 186 | retrying==1.3.4 187 | rich==13.4.2 188 | rpds-py==0.18.0 189 | rsa==4.9 190 | ruff==0.3.2 191 | s3transfer==0.10.1 192 | safetensors==0.4.2 193 | scikit-image==0.22.0 194 | scikit-learn==1.4.1.post1 195 | scipy==1.10.1 196 | semantic-version==2.10.0 197 | sentencepiece==0.2.0 198 | sentry-sdk==1.42.0 199 | setproctitle==1.3.3 200 | shellingham==1.5.4 201 | six==1.16.0 202 | smmap==5.0.1 203 | sniffio==1.3.1 204 | sortedcontainers==2.4.0 205 | soupsieve==2.5 206 | SQLAlchemy==2.0.28 207 | sse-starlette==0.10.3 208 | sseclient-py==1.8.0 209 | starlette==0.36.3 210 | strawberry-graphql==0.138.1 211 | sympy==1.12 212 | tabulate==0.9.0 213 | taskgroup==0.0.0a4 214 | tenacity==8.2.3 215 | tensorboard==2.15.1 216 | tensorboard-data-server==0.7.2 217 | tensorboardX==2.6.2.2 218 | termcolor==2.3.0 219 | texttable==1.7.0 220 | threadpoolctl==3.3.0 221 | tifffile==2024.2.12 222 | timm==0.6.12 223 | tokenizers==0.15.2 224 | tomli==2.0.1 225 | tomlkit==0.12.0 226 | toolz==0.12.1 227 | torch==2.2.1 228 | torchaudio==2.2.1 229 | torchvision==0.17.1 230 | tqdm==4.65.2 231 | transaction==4.0 232 | transformers==4.37.1 233 | translationstring==1.4 234 | triton==2.2.0 235 | typer==0.9.0 236 | typing_extensions==4.8.0 237 | tzdata==2024.1 238 | tzlocal==5.2 239 | universal-analytics-python3==1.1.1 240 | urllib3==1.26.18 241 | uvicorn==0.28.0 242 | velruse==1.1.1 243 | venusian==3.1.0 244 | voxel51-eta==0.12.6 245 | wandb==0.14.0 246 | wcwidth==0.2.13 247 | WebOb==1.8.7 248 | websockets==11.0.3 249 | Werkzeug==3.0.1 250 | wrapt==1.16.0 251 | wsproto==1.2.0 252 | WTForms==3.1.2 253 | wtforms-recaptcha==0.3.2 254 | xmltodict==0.13.0 255 | yacs==0.1.8 256 | yapf==0.40.2 257 | zipp==3.18.1 258 | zope.deprecation==5.0 259 | zope.interface==6.2 260 | zope.sqlalchemy==3.1 261 | 
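With the pinned environment above in place, the PruneVid-specific arguments that configuration_pllava.py adds on top of the stock Llava config (selected_layer, alpha, tau, temporal_segment_ratio, cluster_ratio, plus the pooling controls) are set when constructing a PllavaConfig. The snippet below is an illustrative sketch only; it assumes the repository root is on PYTHONPATH with the requirements installed, and the numeric values simply mirror those passed on the command line in scripts/eval.sh further down.

from models.pllava import PllavaConfig

# The values below mirror the defaults swept in scripts/eval.sh.
config = PllavaConfig(
    num_frames=16,
    pooling_shape=(16, 12, 12),
    selected_layer=10,
    alpha=0.4,
    tau=0.8,
    temporal_segment_ratio=0.25,
    cluster_ratio=0.5,
)

# With no explicit sub-configs, __init__ falls back to a default CLIP vision config
# and a default llama text config (with flash_attention_2 requested).
print(config.vision_config.image_size)  # 336
print(config.text_config.model_type)    # "llama"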
-------------------------------------------------------------------------------- /scripts/accel_config_deepspeed_zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 8 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: bf16 14 | num_machines: 1 15 | num_processes: 4 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /scripts/accel_config_deepspeed_zero3_offload.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 2 5 | offload_optimizer_device: cpu 6 | offload_param_device: cpu 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /scripts/accel_config_deepspeed_zero3_offload_multinode.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 2 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | machine_rank: 0 14 | main_process_ip: fdbd:dc61:18:8::20 15 | main_process_port: 6876 16 | main_training_function: main 17 | mixed_precision: bf16 18 | num_machines: 2 19 | num_processes: 16 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false 26 | -------------------------------------------------------------------------------- /scripts/accel_config_deepspeed_zero3_offload_multinode_1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 2 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | machine_rank: 0 14 | main_process_ip: fdbd:dc61:18:8::20 15 | main_process_port: 6876 16 | main_training_function: main 17 | mixed_precision: bf16 18 | num_machines: 2 19 | num_processes: 16 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false 26 | -------------------------------------------------------------------------------- /scripts/accel_config_deepspeed_zero3_offload_multinode_2.yaml: 
-------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 2 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | machine_rank: 1 14 | main_process_ip: fdbd:dc61:18:8::20 15 | main_process_port: 6876 16 | main_training_function: main 17 | mixed_precision: bf16 18 | num_machines: 2 19 | num_processes: 16 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false 26 | -------------------------------------------------------------------------------- /scripts/accel_config_deepspeed_zero3_offload_singlegpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 16 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: cpu 7 | offload_param_device: cpu 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 1 17 | num_processes: 1 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /scripts/accel_config_multigpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: 0,1,2,3,4,5,6,7 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /scripts/accel_config_multinode.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 1 7 | main_process_ip: 10.193.16.150 8 | main_process_port: 6784 9 | main_training_function: main 10 | mixed_precision: bf16 11 | num_machines: 2 12 | num_processes: 16 13 | rdzv_backend: static 14 | same_network: true 15 | tpu_env: [] 16 | tpu_use_cluster: false 17 | tpu_use_sudo: false 18 | use_cpu: false 19 | -------------------------------------------------------------------------------- /scripts/accel_config_singlegpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | gpu_ids: '0' 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 1 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- 
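One reading aid for the accelerate/DeepSpeed YAML files above before moving on to the evaluation script: the effective global batch size they imply is the per-device batch size (set by the training script, not by these files) multiplied by gradient_accumulation_steps and num_processes. A minimal illustrative calculation using values from two of the configs:

# Illustrative only: the per-device batch size is a training-script argument, not a YAML field.
def effective_batch_size(per_device: int, grad_accum_steps: int, num_processes: int) -> int:
    return per_device * grad_accum_steps * num_processes

# accel_config_deepspeed_zero2.yaml: gradient_accumulation_steps=8, num_processes=4
print(effective_batch_size(1, 8, 4))  # 32
# accel_config_deepspeed_zero3_offload.yaml: gradient_accumulation_steps=2, num_processes=8
print(effective_batch_size(1, 2, 8))  # 16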
/scripts/eval.sh: -------------------------------------------------------------------------------- 1 | export OPENAI_API_KEY=YOUR_API_KEY 2 | num_frames=16 3 | test_ratio=1 4 | 5 | model_dir=MODELS/pllava-7b 6 | weight_dir=MODELS/pllava-7b 7 | 8 | lora_alpha=14 9 | selected_layers=(10) 10 | alphas=(0.4) 11 | taus=(0.8) 12 | temporal_segment_ratios=(0.25) 13 | cluster_ratios=(0.5) 14 | 15 | for alpha in "${alphas[@]}"; do 16 | for selected_layer in "${selected_layers[@]}"; do 17 | for tau in "${taus[@]}"; do 18 | for temporal_segment_ratio in "${temporal_segment_ratios[@]}"; do 19 | for cluster_ratio in "${cluster_ratios[@]}"; do 20 | # 执行命令 21 | SAVE_DIR=test_results/pllava-7b-lora${lora_alpha}-threshold${tau}-layer${selected_layer}-alpha${alpha}-temporal-segment-ratio-${temporal_segment_ratio}-cluster-ratio-${cluster_ratio} 22 | mkdir -p "${SAVE_DIR}" 23 | conv_mode=eval_mvbench 24 | python -m tasks.eval.mvbench.pllava_eval_mvbench \ 25 | --pretrained_model_name_or_path ${model_dir} \ 26 | --save_path ${SAVE_DIR}/mvbench \ 27 | --num_frames ${num_frames} \ 28 | --use_lora \ 29 | --lora_alpha ${lora_alpha} \ 30 | --top_p 1.0 \ 31 | --temperature 1.0 \ 32 | --weight_dir ${weight_dir} \ 33 | --pooling_shape 16-12-12 \ 34 | --conv_mode ${conv_mode} \ 35 | --selected_layer ${selected_layer} \ 36 | --alpha ${alpha} \ 37 | --tau ${tau} \ 38 | --temporal_segment_ratio ${temporal_segment_ratio} \ 39 | --cluster_ratio ${cluster_ratio} 40 | done 41 | done 42 | done 43 | done 44 | done 45 | 46 | lora_alpha=14 47 | selected_layers=(10) 48 | alphas=(0.4) 49 | taus=(0.8) 50 | temporal_segment_ratios=(0.25) 51 | cluster_ratios=(0.5) 52 | 53 | for alpha in "${alphas[@]}"; do 54 | for selected_layer in "${selected_layers[@]}"; do 55 | for tau in "${taus[@]}"; do 56 | for temporal_segment_ratio in "${temporal_segment_ratios[@]}"; do 57 | for cluster_ratio in "${cluster_ratios[@]}"; do 58 | # 执行命令 59 | SAVE_DIR=test_results/pllava-7b-lora${lora_alpha}-threshold${tau}-layer${selected_layer}-alpha${alpha}-temporal-segment-ratio-${temporal_segment_ratio}-cluster-ratio-${cluster_ratio} 60 | mkdir -p "${SAVE_DIR}" 61 | conv_mode=eval_videomme 62 | python -m tasks.eval.videomme.pllava_eval_videomme \ 63 | --pretrained_model_name_or_path ${model_dir} \ 64 | --save_path ${SAVE_DIR}/videomme \ 65 | --num_frames ${num_frames} \ 66 | --use_lora \ 67 | --lora_alpha ${lora_alpha} \ 68 | --top_p 1.0 \ 69 | --temperature 1.0 \ 70 | --weight_dir ${weight_dir} \ 71 | --pooling_shape 16-12-12 \ 72 | --conv_mode ${conv_mode} \ 73 | --selected_layer ${selected_layer} \ 74 | --alpha ${alpha} \ 75 | --tau ${tau} \ 76 | --temporal_segment_ratio ${temporal_segment_ratio} \ 77 | --cluster_ratio ${cluster_ratio} 78 | done 79 | done 80 | done 81 | done 82 | done 83 | 84 | lora_alpha=14 85 | selected_layers=(10) 86 | alphas=(0.4) 87 | taus=(0.8) 88 | temporal_segment_ratios=(0.25) 89 | cluster_ratios=(0.5) 90 | 91 | for alpha in "${alphas[@]}"; do 92 | for selected_layer in "${selected_layers[@]}"; do 93 | for tau in "${taus[@]}"; do 94 | for temporal_segment_ratio in "${temporal_segment_ratios[@]}"; do 95 | for cluster_ratio in "${cluster_ratios[@]}"; do 96 | # 执行命令 97 | SAVE_DIR=test_results/pllava-7b-lora${lora_alpha}-threshold${tau}-layer${selected_layer}-alpha${alpha}-temporal-segment-ratio-${temporal_segment_ratio}-cluster-ratio-${cluster_ratio} 98 | mkdir -p "${SAVE_DIR}" 99 | conv_mode=eval_mvbench 100 | python -m tasks.eval.egoshcema.pllava_eval_egoschema \ 101 | --pretrained_model_name_or_path ${model_dir} \ 102 | --save_path 
${SAVE_DIR}/egoschema \ 103 | --num_frames ${num_frames} \ 104 | --use_lora \ 105 | --lora_alpha ${lora_alpha} \ 106 | --top_p 1.0 \ 107 | --temperature 1.0 \ 108 | --weight_dir ${weight_dir} \ 109 | --pooling_shape 16-12-12 \ 110 | --conv_mode ${conv_mode} \ 111 | --selected_layer ${selected_layer} \ 112 | --alpha ${alpha} \ 113 | --tau ${tau} \ 114 | --temporal_segment_ratio ${temporal_segment_ratio} \ 115 | --cluster_ratio ${cluster_ratio} 116 | done 117 | done 118 | done 119 | done 120 | done 121 | 122 | 123 | lora_alpha=4 124 | selected_layers=(5) 125 | alphas=(0.4) 126 | taus=(0.8) 127 | temporal_segment_ratios=(0.25) 128 | cluster_ratios=(0.5) 129 | 130 | for alpha in "${alphas[@]}"; do 131 | for selected_layer in "${selected_layers[@]}"; do 132 | for tau in "${taus[@]}"; do 133 | for temporal_segment_ratio in "${temporal_segment_ratios[@]}"; do 134 | for cluster_ratio in "${cluster_ratios[@]}"; do 135 | # 执行命令 136 | SAVE_DIR=test_results/pllava-7b-lora${lora_alpha}-threshold${tau}-layer${selected_layer}-alpha${alpha}-temporal-segment-ratio-${temporal_segment_ratio}-cluster-ratio-${cluster_ratio} 137 | mkdir -p "${SAVE_DIR}" 138 | conv_mode=eval_vcgbench 139 | python -m tasks.eval.vcgbench.pllava_eval_vcgbench \ 140 | --pretrained_model_name_or_path ${model_dir} \ 141 | --save_path ${SAVE_DIR}/vcgbench \ 142 | --num_frames ${num_frames} \ 143 | --weight_dir ${weight_dir} \ 144 | --pooling_shape 16-12-12 \ 145 | --test_ratio ${test_ratio} \ 146 | --use_lora \ 147 | --lora_alpha ${lora_alpha} \ 148 | --selected_layer ${selected_layer} \ 149 | --alpha ${alpha} \ 150 | --tau ${tau} \ 151 | --temporal_segment_ratio ${temporal_segment_ratio} \ 152 | --cluster_ratio ${cluster_ratio} 153 | done 154 | done 155 | done 156 | done 157 | done -------------------------------------------------------------------------------- /tasks/eval/__pycache__/eval_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/__pycache__/eval_utils.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/__pycache__/eval_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/__pycache__/eval_utils.cpython-39.pyc -------------------------------------------------------------------------------- /tasks/eval/__pycache__/model_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/__pycache__/model_utils.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/demo/__init__.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from gradio.themes.utils import colors, fonts, sizes 3 | 4 | 5 | pllava_theme = gr.themes.Monochrome( 6 | text_size="sm", 7 | spacing_size="sm", 8 | primary_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"), 9 | secondary_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", 
c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"), 10 | neutral_hue=gr.themes.Color(c100="#f5f5f5", c200="#e5e5e5", c300="#d4d4d4", c400="#a3a3a3", c50="#fafafa", c500="#737373", c600="#525252", c700="#404040", c800="#262626", c900="#171717", c950="#000000"), 11 | ).set( 12 | background_fill_primary_dark='*primary_950', 13 | background_fill_secondary_dark='*neutral_950' 14 | ) 15 | 16 | -------------------------------------------------------------------------------- /tasks/eval/demo/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/demo/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/demo/__pycache__/pllava_demo.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/demo/__pycache__/pllava_demo.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/demo/pllava_demo.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import copy 3 | import gradio as gr 4 | from gradio.themes.utils import colors, fonts, sizes 5 | 6 | from utils.easydict import EasyDict 7 | from tasks.eval.model_utils import load_pllava 8 | from tasks.eval.eval_utils import ( 9 | ChatPllava, 10 | conv_plain_v1, 11 | Conversation, 12 | conv_templates 13 | ) 14 | from tasks.eval.demo import pllava_theme 15 | 16 | SYSTEM="""You are Pllava, a large vision-language assistant. 17 | You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language. 18 | Follow the instructions carefully and explain your answers in detail based on the provided video. 
19 | """ 20 | INIT_CONVERSATION: Conversation = conv_plain_v1.copy() 21 | 22 | 23 | # ======================================== 24 | # Model Initialization 25 | # ======================================== 26 | def init_model(args): 27 | 28 | print('Initializing PLLaVA') 29 | model, processor = load_pllava( 30 | args.pretrained_model_name_or_path, args.num_frames, 31 | use_lora=args.use_lora, 32 | weight_dir=args.weight_dir, 33 | lora_alpha=args.lora_alpha, 34 | use_multi_gpus=args.use_multi_gpus) 35 | if not args.use_multi_gpus: 36 | model = model.to('cuda') 37 | chat = ChatPllava(model, processor) 38 | return chat 39 | 40 | 41 | # ======================================== 42 | # Gradio Setting 43 | # ======================================== 44 | def gradio_reset(chat_state, img_list): 45 | if chat_state is not None: 46 | chat_state = INIT_CONVERSATION.copy() 47 | if img_list is not None: 48 | img_list = [] 49 | return ( 50 | None, 51 | gr.update(value=None, interactive=True), 52 | gr.update(value=None, interactive=True), 53 | gr.update(placeholder='Please upload your video first', interactive=False), 54 | gr.update(value="Upload & Start Chat", interactive=True), 55 | chat_state, 56 | img_list 57 | ) 58 | 59 | 60 | def upload_img(gr_img, gr_video, chat_state=None, num_segments=None, img_list=None): 61 | print(gr_img, gr_video) 62 | chat_state = INIT_CONVERSATION.copy() if chat_state is None else chat_state 63 | img_list = [] if img_list is None else img_list 64 | 65 | if gr_img is None and gr_video is None: 66 | return None, None, gr.update(interactive=True),gr.update(interactive=True, placeholder='Please upload video/image first!'), chat_state, None 67 | if gr_video: 68 | llm_message, img_list, chat_state = chat.upload_video(gr_video, chat_state, img_list, num_segments) 69 | return ( 70 | gr.update(interactive=True), 71 | gr.update(interactive=True), 72 | gr.update(interactive=True, placeholder='Type and press Enter'), 73 | gr.update(value="Start Chatting", interactive=False), 74 | chat_state, 75 | img_list, 76 | ) 77 | if gr_img: 78 | llm_message, img_list,chat_state = chat.upload_img(gr_img, chat_state, img_list) 79 | return ( 80 | gr.update(interactive=True), 81 | gr.update(interactive=True), 82 | gr.update(interactive=True, placeholder='Type and press Enter'), 83 | gr.update(value="Start Chatting", interactive=False), 84 | chat_state, 85 | img_list 86 | ) 87 | 88 | 89 | def gradio_ask(user_message, chatbot, chat_state, system): 90 | if len(user_message) == 0: 91 | return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state 92 | chat_state = chat.ask(user_message, chat_state, system) 93 | chatbot = chatbot + [[user_message, None]] 94 | return '', chatbot, chat_state 95 | 96 | 97 | def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): 98 | llm_message, llm_message_token, chat_state = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=200, num_beams=num_beams, temperature=temperature) 99 | llm_message = llm_message.replace("", "") # handle 100 | chatbot[-1][1] = llm_message 101 | print(chat_state) 102 | print(f"Answer: {llm_message}") 103 | return chatbot, chat_state, img_list 104 | 105 | 106 | def parse_args(): 107 | parser = ArgumentParser() 108 | parser.add_argument( 109 | "--pretrained_model_name_or_path", 110 | type=str, 111 | required=True, 112 | default='llava-hf/llava-1.5-7b-hf' 113 | ) 114 | parser.add_argument( 115 | "--num_frames", 116 | type=int, 117 | required=True, 118 | default=4, 119 | ) 120 | 
parser.add_argument( 121 | "--use_lora", 122 | action='store_true' 123 | ) 124 | parser.add_argument( 125 | "--use_multi_gpus", 126 | action='store_true' 127 | ) 128 | parser.add_argument( 129 | "--weight_dir", 130 | type=str, 131 | required=False, 132 | default=None, 133 | ) 134 | parser.add_argument( 135 | "--conv_mode", 136 | type=str, 137 | required=False, 138 | default=None, 139 | ) 140 | parser.add_argument( 141 | "--lora_alpha", 142 | type=int, 143 | required=False, 144 | default=None, 145 | ) 146 | parser.add_argument( 147 | "--server_port", 148 | type=int, 149 | required=False, 150 | default=7868, 151 | ) 152 | args = parser.parse_args() 153 | return args 154 | 155 | 156 | title = """
PLLAVA
""" 157 | description = ( 158 | """
159 | # PLLAVA! 160 |
161 | - Upload A Video 162 | - Press Upload 163 | - Start Chatting 164 | """ 165 | ) 166 | 167 | args = parse_args() 168 | 169 | model_description = f""" 170 | # MODEL INFO 171 | - pretrained_model_name_or_path:{args.pretrained_model_name_or_path} 172 | - use_lora:{args.use_lora} 173 | - weight_dir:{args.weight_dir} 174 | """ 175 | 176 | # with gr.Blocks(title="InternVideo-VideoChat!",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo: 177 | with gr.Blocks(title="PLLaVA", 178 | theme=pllava_theme, 179 | css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo: 180 | gr.Markdown(title) 181 | gr.Markdown(description) 182 | gr.Markdown(model_description) 183 | with gr.Row(): 184 | with gr.Column(scale=0.5, visible=True) as video_upload: 185 | # with gr.Column(elem_id="image", scale=0.5) as img_part: 186 | with gr.Tab("Video", elem_id='video_tab'): 187 | up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360) 188 | with gr.Tab("Image", elem_id='image_tab'): 189 | up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload", height=360) 190 | upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") 191 | clear = gr.Button("Restart") 192 | 193 | # num_segments = gr.Slider( 194 | # minimum=8, 195 | # maximum=64, 196 | # value=8, 197 | # step=1, 198 | # interactive=True, 199 | # label="Video Segments", 200 | # ) 201 | 202 | with gr.Column(visible=True) as input_raws: 203 | system_string = gr.Textbox(SYSTEM, interactive=True, label='system') 204 | num_beams = gr.Slider( 205 | minimum=1, 206 | maximum=5, 207 | value=1, 208 | step=1, 209 | interactive=True, 210 | label="beam search numbers", 211 | ) 212 | temperature = gr.Slider( 213 | minimum=0.1, 214 | maximum=2.0, 215 | value=1.0, 216 | step=0.1, 217 | interactive=True, 218 | label="Temperature", 219 | ) 220 | 221 | chat_state = gr.State() 222 | img_list = gr.State() 223 | chatbot = gr.Chatbot(elem_id="chatbot",label='Conversation') 224 | with gr.Row(): 225 | with gr.Column(scale=0.7): 226 | text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False) 227 | with gr.Column(scale=0.15, min_width=0): 228 | run = gr.Button("💭Send") 229 | with gr.Column(scale=0.15, min_width=0): 230 | clear = gr.Button("🔄Clear") 231 | 232 | with gr.Row(): 233 | examples = gr.Examples( 234 | examples=[ 235 | ['example/jesse_dance.mp4', 'What is the man doing?'], 236 | ['example/yoga.mp4', 'What is the woman doing?'], 237 | ['example/cooking.mp4', 'Describe the background, characters and the actions in the provided video.'], 238 | # ['example/cooking.mp4', 'What is happening in the video?'], 239 | ['example/working.mp4', 'Describe the background, characters and the actions in the provided video.'], 240 | ['example/1917.mov', 'Describe the background, characters and the actions in the provided video.'], 241 | ], 242 | inputs=[up_video, text_input] 243 | ) 244 | 245 | 246 | chat = init_model(args) 247 | INIT_CONVERSATION = conv_templates[args.conv_mode] 248 | upload_button.click(upload_img, [up_image, up_video, chat_state], [up_image, up_video, text_input, upload_button, chat_state, img_list]) 249 | 250 | text_input.submit(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then( 251 | gradio_answer, [chatbot, chat_state, 
img_list, num_beams, temperature], [chatbot, chat_state, img_list] 252 | ) 253 | run.click(gradio_ask, [text_input, chatbot, chat_state, system_string], [text_input, chatbot, chat_state]).then( 254 | gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list] 255 | ) 256 | run.click(lambda: "", None, text_input) 257 | clear.click(gradio_reset, [chat_state, img_list], [chatbot, up_image, up_video, text_input, upload_button, chat_state, img_list], queue=False) 258 | 259 | # demo.queue(max_size=5) 260 | demo.launch(share=True,server_port=args.server_port) 261 | # demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True) 262 | -------------------------------------------------------------------------------- /tasks/eval/demo/show_compare.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import json 5 | import os 6 | import os.path as osp 7 | import gradio as gr 8 | import numpy as np 9 | 10 | from tasks.eval.recaption import load_results as load_results_recaption 11 | from tasks.eval.mvbench import load_results as load_results_mvbench 12 | from tasks.eval.vcgbench import load_results as load_results_vcgbench 13 | from tasks.eval.videoqabench import load_results as load_results_videoqabench 14 | from tasks.eval.demo import pllava_theme 15 | 16 | 17 | load_results_funcs = [ 18 | load_results_recaption, 19 | load_results_mvbench, 20 | load_results_vcgbench, 21 | load_results_videoqabench, 22 | ] 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | '--root_dir', 29 | required=True, 30 | ) 31 | args = parser.parse_args() 32 | return args 33 | 34 | args = parse_args() 35 | root_dir = args.root_dir 36 | 37 | def show(result_list_first, result_list_second, result_index): 38 | sample2index_second = {} 39 | 40 | for i, result in enumerate(result_list_second): 41 | if 'video_path' not in result: 42 | continue 43 | 44 | question = result['question'] if 'question' in result else '' 45 | video_path = result['video_path'] 46 | samplehash = question + '--' +video_path 47 | sample2index_second[samplehash] = i 48 | 49 | info = result_list_first[result_index] 50 | info_str_first = json.dumps(info, indent=4, ensure_ascii=False) 51 | video_path = info['video_path'] 52 | question = info['question'] if 'question' in info else '' 53 | samplehash = question + '--' +video_path 54 | if samplehash in sample2index_second: 55 | info = result_list_second[sample2index_second[samplehash]] 56 | info_str_second = json.dumps(info, indent=4, ensure_ascii=False) 57 | else: 58 | info_str_second = f"NO {video_path} IN THE SECOND RESULT DIR" 59 | return video_path, info_str_first, info_str_second 60 | 61 | def reload_results_dirs(): 62 | result_dirs = [] 63 | # load result dir paths 64 | for dirpath, dirnames, filenames in os.walk(args.root_dir): 65 | if len(dirnames) == 0 and len(filenames) != 0: 66 | result_dirs.append(dirpath) 67 | return gr.Dropdown(result_dirs, value=result_dirs[0]) 68 | 69 | def reload_results(result_dir): 70 | # if isinstance(result_dir, list): 71 | # result_dir = result_dir[0] 72 | 73 | if result_dir is None or not osp.exists(result_dir): 74 | return None 75 | 76 | for fn in load_results_funcs: 77 | result_list = fn(result_dir) 78 | if result_list is not None: 79 | np.random.shuffle(result_list) 80 | break 81 | result_index = gr.Slider(0, len(result_list), step=1) 82 | 83 | return result_list, result_index 84 | 85 | 86 | 87 | with 
gr.Blocks(title="PLLAVA RESULTS", theme=pllava_theme) as demo: 88 | result_list_first = gr.State() 89 | result_list_second = gr.State() 90 | 91 | with gr.Row(): 92 | with gr.Column(): 93 | gr.Markdown("# Showing off Model's Outputs.") 94 | gr.Markdown( 95 | "You can find all our results, including:\n" 96 | "1. results of Captioned Inter4k\n" 97 | "2. results of Different Benchmark inference outputs.\n" 98 | "Choose a directory to see the different output variant.\n" 99 | "You can also choose secondary directory (as long as they are from the same dataset.) to compare on the results.\n" 100 | ) 101 | 102 | with gr.Row(): 103 | with gr.Column(): 104 | show_video = gr.Video(interactive=False) 105 | 106 | with gr.Column(): 107 | button_reload = gr.Button(value='Reload From The Evaluation/Inference Root Directory') 108 | result_index = gr.Slider(0, 0, step=1, label="Index") 109 | 110 | result_dir_first = gr.Dropdown(label='Test Result Path') 111 | info_first = gr.Text(interactive=False, label='Detailed Output Information') 112 | result_dir_second = gr.Dropdown(label='Test Result Path') 113 | info_second = gr.Text(interactive=False, label='Detailed Output Information') 114 | 115 | 116 | button_reload.click(reload_results_dirs, [], [result_dir_first]) 117 | button_reload.click(reload_results_dirs, [], [result_dir_second]) 118 | result_dir_first.change(reload_results, [result_dir_first], [result_list_first, result_index]) 119 | result_dir_second.change(reload_results, [result_dir_second], [result_list_second, result_index]) 120 | result_index.change(show, [result_list_first, result_list_second, result_index], [show_video, info_first, info_second]) 121 | demo.load(reload_results_dirs, [], [result_dir_first]) 122 | demo.load(reload_results_dirs, [], [result_dir_second]) 123 | 124 | demo.launch(share=True) -------------------------------------------------------------------------------- /tasks/eval/demo/show_gallery.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import json 5 | import os 6 | import os.path as osp 7 | import gradio as gr 8 | 9 | from tasks.eval.recaption import load_results as load_results_recaption 10 | from tasks.eval.mvbench import load_results as load_results_mvbench 11 | from tasks.eval.vcgbench import load_results as load_results_vcgbench 12 | from tasks.eval.videoqabench import load_results as load_results_videoqabench 13 | 14 | load_results_funcs = [ 15 | load_results_recaption, 16 | load_results_mvbench, 17 | load_results_vcgbench, 18 | load_results_videoqabench, 19 | ] 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | '--root_dir', 26 | required=True, 27 | ) 28 | args = parser.parse_args() 29 | return args 30 | 31 | args = parse_args() 32 | root_dir = args.root_dir 33 | 34 | def show(result_list, result_index): 35 | info = result_list[result_index] 36 | video_path = info['video_path'] 37 | info_str = json.dumps(info, indent=4) 38 | return video_path, info_str 39 | 40 | def reload_results_dirs(): 41 | result_dirs = [] 42 | # load result dir paths 43 | for dirpath, dirnames, filenames in os.walk(args.root_dir): 44 | if len(dirnames) == 0 and len(filenames) != 0: 45 | result_dirs.append(dirpath) 46 | return gr.Dropdown(result_dirs, value=result_dirs[0]) 47 | 48 | def reload_results(result_dir): 49 | # if isinstance(result_dir, list): 50 | # result_dir = result_dir[0] 51 | 52 | if result_dir is None or not osp.exists(result_dir): 53 | return None 54 | 55 | 
for fn in load_results_funcs: 56 | result_list = fn(result_dir) 57 | if result_list is not None: 58 | break 59 | 60 | result_index = gr.Slider(0, len(result_list), step=1) 61 | 62 | return result_list, result_index 63 | 64 | with gr.Blocks() as demo: 65 | result_list = gr.State() 66 | 67 | with gr.Row(): 68 | gr.Markdown("# Showing of what has came out.") 69 | 70 | with gr.Row(): 71 | with gr.Column(scale=1): 72 | gr.Markdown(f"### From Saved Results Directory {args.root_dir}") 73 | 74 | with gr.Column(scale=2): 75 | result_dir = gr.Dropdown(label='Test Result Path') 76 | button_reload = gr.Button(value='Reload From The Evaluation/Inference Root Directory') 77 | 78 | 79 | 80 | with gr.Row(): 81 | with gr.Column(): 82 | show_video = gr.Video(interactive=False) 83 | 84 | with gr.Column(): 85 | result_index = gr.Slider(0, 0, step=1, label="Index") 86 | info = gr.Text(interactive=False, label='Detailed Output Information') 87 | 88 | 89 | button_reload.click(reload_results_dirs, [], [result_dir]) 90 | result_dir.change(reload_results, [result_dir], [result_list, result_index]) 91 | result_index.change(show, [result_list, result_index], [show_video, info]) 92 | demo.load(reload_results_dirs, [], [result_dir]) 93 | 94 | demo.launch(share=True) -------------------------------------------------------------------------------- /tasks/eval/egoshcema/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from tasks.eval.eval_utils import ( 5 | dump_json, 6 | load_json, 7 | EvalDataset, 8 | ) 9 | 10 | def check_ans(pred, gt): 11 | flag = False 12 | pred_list = pred.lower().split(' ') 13 | pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:]) 14 | gt_list = gt.lower().split(' ') 15 | gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:]) 16 | if gt_content[-1] == '.': 17 | gt_content = gt_content[:-1] 18 | 19 | if not any([c in pred_option for c in 'abcdefgABCDEFG']): 20 | print(f"model doesn't follow instructions: {pred}") 21 | elif pred_option.replace('.', '') in gt_option: 22 | flag = True 23 | elif gt_option in pred_option: 24 | flag = True 25 | 26 | return flag 27 | 28 | def save_results(result_list, save_path): 29 | final_res, acc_dict = {}, {} 30 | correct, total = 0, 0 31 | for res in result_list: 32 | task_type = res['task_type'] 33 | if task_type not in acc_dict: 34 | acc_dict[task_type] = [0, 0] # correct, total 35 | acc_dict[task_type][1] += 1 36 | total += 1 37 | pred = res['pred'] 38 | gt = res['gt'] 39 | if check_ans(pred=pred, gt=gt): 40 | acc_dict[task_type][0] += 1 41 | correct += 1 42 | 43 | for k, v in acc_dict.items(): 44 | final_res[k] = v[0] / v[1] * 100 45 | correct += v[0] 46 | total += v[1] 47 | final_res['Avg'] = correct / total * 100 48 | 49 | all_results = { 50 | "acc_dict": acc_dict, 51 | "result_list": result_list 52 | } 53 | dump_json(all_results, save_path, 'all_results.json') 54 | dump_json(final_res, save_path, 'upload_leaderboard.json') 55 | 56 | def load_results(save_path): 57 | all_results = load_json(save_path, 'all_results.json') 58 | if all_results is not None: 59 | result_list = all_results['result_list'] 60 | else: 61 | result_list = None 62 | # json_data = load_json(save_path, 'all_results.json')['result_list'] 63 | return result_list 64 | 65 | class EgoSchemaDataset(EvalDataset): 66 | data_list_info = { 67 | "FullSet": ("egoschema_fullset.json", "DATAS/ego_schema/videos", "video", False), # has start & end 68 | } 69 | data_dir = 
"DATAS/ego_schema/json" 70 | 71 | def __init__(self, *args, **kwargs): 72 | super().__init__(*args, **kwargs) 73 | 74 | data_list_info = self.data_list_info 75 | data_dir = self.data_dir 76 | 77 | self.data_list = [] 78 | for k, v in data_list_info.items(): 79 | with open(os.path.join(data_dir, v[0]), 'r') as f: 80 | json_data = json.load(f) 81 | for data in json_data: 82 | self.data_list.append({ 83 | 'task_type': k, 84 | 'prefix': v[1], 85 | 'data_type': v[2], 86 | 'bound': v[3], 87 | 'data': data 88 | }) 89 | # self.data_list = self.data_list[:100] # for debug 90 | self.decord_method = { 91 | 'video': self.read_video, 92 | 'gif': self.read_gif, 93 | 'frame': self.read_frame, 94 | 'npy': self.read_npy, 95 | } 96 | 97 | # # transform 98 | # crop_size = resolution 99 | # scale_size = resolution 100 | # input_mean = [0.48145466, 0.4578275, 0.40821073] 101 | # input_std = [0.26862954, 0.26130258, 0.27577711] 102 | # self.transform = T.Compose([ 103 | # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC), 104 | # GroupCenterCrop(crop_size), 105 | # Stack(), 106 | # ToTorchFormatTensor(), 107 | # GroupNormalize(input_mean, input_std) 108 | # ]) 109 | 110 | def __getitem__(self, idx): 111 | question, answer = self.qa_template(self.data_list[idx]['data']) 112 | task_type = self.data_list[idx]['task_type'] 113 | decord_method = self.decord_method[self.data_list[idx]['data_type']] 114 | bound = None 115 | if self.data_list[idx]['bound']: 116 | bound = ( 117 | self.data_list[idx]['data']['start'], 118 | self.data_list[idx]['data']['end'], 119 | ) 120 | video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video']) 121 | 122 | 123 | # images_group = decord_method(video_path, bound) 124 | images_group = decord_method(video_path, bound) 125 | # try: # might be problem with decord 126 | # images_group = decord_method(video_path, bound) 127 | # except Exception as e: 128 | # print(f'error decoding {video_path}', e) 129 | # task_type = 'error_reading_video' 130 | # images_group = None 131 | 132 | return { 133 | 'video_path': video_path, 134 | 'video_pils': images_group, # some might use the original pils and do their own transforms 135 | 'question': question, 136 | 'answer': answer, 137 | 'task_type': task_type, 138 | } 139 | 140 | 141 | def qa_template(self, data): 142 | question = f"Question: {data['question']}\n" 143 | question += "Options:\n" 144 | answer = data['answer'] 145 | answer_idx = -1 146 | for idx, c in enumerate(data['candidates']): 147 | question += f"({chr(ord('A') + idx)}) {c}\n" 148 | if c == answer: 149 | answer_idx = idx 150 | question = question.rstrip() 151 | answer = f"({chr(ord('A') + answer_idx)}) {answer}" 152 | return question, answer 153 | 154 | -------------------------------------------------------------------------------- /tasks/eval/egoshcema/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/egoshcema/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/egoshcema/__pycache__/pllava_eval_egoschema.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/egoshcema/__pycache__/pllava_eval_egoschema.cpython-310.pyc 
-------------------------------------------------------------------------------- /tasks/eval/mvbench/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from tasks.eval.eval_utils import ( 5 | dump_json, 6 | load_json, 7 | EvalDataset, 8 | ) 9 | 10 | 11 | def check_ans(pred, gt): 12 | flag = False 13 | 14 | pred_list = pred.lower().split(' ') 15 | pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:]) 16 | gt_list = gt.lower().split(' ') 17 | gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:]) 18 | if gt_content[-1] == '.': 19 | gt_content = gt_content[:-1] 20 | 21 | if not any([c in pred_option for c in 'abcdefgABCDEFG']): 22 | print(f"model doesn't follow instructions: {pred}") 23 | elif pred_option.replace('.', '') in gt_option: 24 | flag = True 25 | elif gt_option in pred_option: 26 | flag = True 27 | 28 | return flag 29 | 30 | def save_results(result_list, save_path): 31 | 32 | final_res, acc_dict = {}, {} 33 | correct, total = 0, 0 34 | for res in result_list: 35 | task_type = res['task_type'] 36 | if task_type not in acc_dict: 37 | acc_dict[task_type] = [0, 0] # correct, total 38 | acc_dict[task_type][1] += 1 39 | total += 1 40 | pred = res['pred'] 41 | gt = res['gt'] 42 | if check_ans(pred=pred, gt=gt): 43 | acc_dict[task_type][0] += 1 44 | correct += 1 45 | 46 | for k, v in acc_dict.items(): 47 | final_res[k] = v[0] / v[1] * 100 48 | correct += v[0] 49 | total += v[1] 50 | final_res['Avg'] = correct / total * 100 51 | 52 | all_results = { 53 | "acc_dict": acc_dict, 54 | "result_list": result_list 55 | } 56 | dump_json(all_results, save_path, 'all_results.json') 57 | dump_json(final_res, save_path, 'upload_leaderboard.json') 58 | 59 | def load_results(save_path): 60 | all_results = load_json(save_path, 'all_results.json') 61 | if all_results is not None: 62 | result_list = all_results['result_list'] 63 | else: 64 | result_list = None 65 | # json_data = load_json(save_path, 'all_results.json')['result_list'] 66 | return result_list 67 | 68 | class MVBenchDataset(EvalDataset): 69 | data_list_info = { 70 | # "task_type (sub task name)": ("json file name", "image/video prefix", "data_type", "bound") 71 | "Action Sequence": ("action_sequence.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end 72 | "Action Prediction": ("action_prediction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end 73 | "Action Antonym": ("action_antonym.json", "DATAS/MVBench/video/ssv2_video/", "video", False), 74 | "Fine-grained Action": ("fine_grained_action.json", "DATAS/MVBench/video/Moments_in_Time_Raw/videos/", "video", False), 75 | "Unexpected Action": ("unexpected_action.json", "DATAS/MVBench/video/FunQA_test/test/", "video", False), 76 | "Object Existence": ("object_existence.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), 77 | "Object Interaction": ("object_interaction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end 78 | "Object Shuffle": ("object_shuffle.json", "DATAS/MVBench/video/perception/videos/", "video", False), 79 | "Moving Direction": ("moving_direction.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), 80 | "Action Localization": ("action_localization.json", "DATAS/MVBench/video/sta/sta_video/", "video", True), # has start & end 81 | "Scene Transition": ("scene_transition.json", "DATAS/MVBench/video/scene_qa/video/", "video", 
False), 82 | "Action Count": ("action_count.json", "DATAS/MVBench/video/perception/videos/", "video", False), 83 | "Moving Count": ("moving_count.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), 84 | "Moving Attribute": ("moving_attribute.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), 85 | "State Change": ("state_change.json", "DATAS/MVBench/video/perception/videos/", "video", False), 86 | "Fine-grained Pose": ("fine_grained_pose.json", "DATAS/MVBench/video/nturgbd/", "video", False), 87 | "Character Order": ("character_order.json", "DATAS/MVBench/video/perception/videos/", "video", False), 88 | "Egocentric Navigation": ("egocentric_navigation.json", "DATAS/MVBench/video/vlnqa/", "video", False), 89 | "Episodic Reasoning": ("episodic_reasoning.json", "DATAS/MVBench/video/tvqa/frames_fps3_hq/", "frame", True), # has start & end, read frame 90 | "Counterfactual Inference": ("counterfactual_inference.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False), 91 | } 92 | data_dir = "DATAS/MVBench/json" 93 | 94 | def __init__(self, *args, **kwargs): 95 | super().__init__(*args, **kwargs) 96 | 97 | data_list_info = self.data_list_info 98 | data_dir = self.data_dir 99 | 100 | self.data_list = [] 101 | for k, v in data_list_info.items(): 102 | with open(os.path.join(data_dir, v[0]), 'r') as f: 103 | json_data = json.load(f) 104 | for data in json_data: 105 | self.data_list.append({ 106 | 'task_type': k, 107 | 'prefix': v[1], 108 | 'data_type': v[2], 109 | 'bound': v[3], 110 | 'data': data 111 | }) 112 | # self.data_list = self.data_list[:100] # for debug 113 | self.decord_method = { 114 | 'video': self.read_video, 115 | 'gif': self.read_gif, 116 | 'frame': self.read_frame, 117 | 'npy': self.read_npy, 118 | } 119 | 120 | # # transform 121 | # crop_size = resolution 122 | # scale_size = resolution 123 | # input_mean = [0.48145466, 0.4578275, 0.40821073] 124 | # input_std = [0.26862954, 0.26130258, 0.27577711] 125 | # self.transform = T.Compose([ 126 | # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC), 127 | # GroupCenterCrop(crop_size), 128 | # Stack(), 129 | # ToTorchFormatTensor(), 130 | # GroupNormalize(input_mean, input_std) 131 | # ]) 132 | 133 | def __getitem__(self, idx): 134 | question, answer = self.qa_template(self.data_list[idx]['data']) 135 | task_type = self.data_list[idx]['task_type'] 136 | decord_method = self.decord_method[self.data_list[idx]['data_type']] 137 | bound = None 138 | if self.data_list[idx]['bound']: 139 | bound = ( 140 | self.data_list[idx]['data']['start'], 141 | self.data_list[idx]['data']['end'], 142 | ) 143 | video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video']) 144 | 145 | 146 | images_group = decord_method(video_path, bound) 147 | 148 | return { 149 | 'video_path': video_path, 150 | 'video_pils': images_group, # some might use the original pils and do their own transforms 151 | 'question': question, 152 | 'answer': answer, 153 | 'task_type': task_type, 154 | } 155 | 156 | 157 | def qa_template(self, data): 158 | question = f"Question: {data['question']}\n" 159 | question += "Options:\n" 160 | answer = data['answer'] 161 | answer_idx = -1 162 | for idx, c in enumerate(data['candidates']): 163 | question += f"({chr(ord('A') + idx)}) {c}\n" 164 | if c == answer: 165 | answer_idx = idx 166 | question = question.rstrip() 167 | answer = f"({chr(ord('A') + answer_idx)}) {answer}" 168 | return question, answer 169 | 170 | 
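Each MVBench evaluation script is expected to append result dicts carrying 'task_type', 'pred' and 'gt' (plus 'video_path'/'question', which the comparison demo above keys on); save_results then buckets them into acc_dict = {task: [correct, total]} before writing all_results.json and upload_leaderboard.json. Note that correct and total are accumulated twice in save_results, once per sample and once more per task, so the reported Avg is unchanged because numerator and denominator double together. The snippet below is a self-contained sketch of that aggregation with toy results and a naive stand-in for check_ans; only the field names are taken from the code above, everything else is illustrative.

# Minimal sketch of the per-task accuracy aggregation done by save_results,
# using toy result dicts; only the field names are taken from the code above.
from collections import defaultdict

def aggregate(result_list, is_correct):
    acc = defaultdict(lambda: [0, 0])                  # task_type -> [correct, total]
    for res in result_list:
        acc[res['task_type']][1] += 1
        if is_correct(res['pred'], res['gt']):
            acc[res['task_type']][0] += 1
    per_task = {k: 100.0 * c / t for k, (c, t) in acc.items()}
    correct = sum(c for c, _ in acc.values())
    total = sum(t for _, t in acc.values())
    per_task['Avg'] = 100.0 * correct / total
    return per_task

toy_results = [
    {'task_type': 'Action Sequence', 'pred': '(A) opens the door', 'gt': '(A) opens the door'},
    {'task_type': 'Action Sequence', 'pred': '(C) sits down',      'gt': '(B) stands up'},
    {'task_type': 'Moving Count',    'pred': '(B) two',            'gt': '(B) two'},
]
naive_match = lambda pred, gt: pred.split(' ')[0] == gt.split(' ')[0]  # stand-in for check_ans
print(aggregate(toy_results, naive_match))
# {'Action Sequence': 50.0, 'Moving Count': 100.0, 'Avg': 66.66...}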
-------------------------------------------------------------------------------- /tasks/eval/mvbench/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/mvbench/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /tasks/eval/mvbench/__pycache__/llava_next_video_mvbench.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/llava_next_video_mvbench.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/mvbench/__pycache__/pllava_eval_mvbench.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/pllava_eval_mvbench.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/mvbench/__pycache__/tarsier_eval_mvbench.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/mvbench/__pycache__/tarsier_eval_mvbench.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/recaption/show_recaption.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import gradio as gr 4 | 5 | from tasks.eval.recaption import load_results 6 | import json 7 | 8 | # example = videogallery().example_inputs() 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | '--save_path', 15 | required=True, 16 | ) 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | args = parse_args() 22 | result_list = load_results(args.save_path) 23 | 24 | 25 | def show(result_index, ): 26 | info = result_list[result_index] 27 | video_path = info['video_path'] 28 | info_str = json.dumps(info, indent=4) 29 | return video_path, info_str 30 | 31 | 32 | 33 | from tasks.eval.recaption import load_results 34 | 35 | with gr.Blocks() as demo: 36 | gr.Markdown("# Showing of what has came out.") 37 | gr.Markdown(f"From Saved Results {args.save_path}") 38 | with gr.Row(): 39 | with gr.Column(1): 40 | show_video = gr.Video(interactive=False) 41 | 42 | with gr.Column(): 43 | result_index = gr.Slider(0, len(result_list), step=1) 44 | info = gr.Text(interactive=False) 45 | 46 | result_index.change(show, [result_index], [show_video, info]) 47 | 48 | 49 | 50 | 51 | 52 | demo.launch(share=True) 53 | -------------------------------------------------------------------------------- /tasks/eval/vcgbench/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/vcgbench/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/vcgbench/__pycache__/pllava_eval_vcgbench.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/vcgbench/__pycache__/pllava_eval_vcgbench.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/vcgbench/show_vcg.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import gradio as gr 4 | 5 | from tasks.eval.vcgbench import load_results 6 | import json 7 | 8 | # example = videogallery().example_inputs() 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | '--save_path', 15 | required=True, 16 | ) 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | args = parse_args() 22 | result_list = load_results(args.save_path) 23 | 24 | 25 | def show(result_index, ): 26 | info = result_list[result_index] 27 | video_path = info['video_path'] 28 | info_str = json.dumps(info, indent=4) 29 | return video_path, info_str 30 | 31 | with gr.Blocks() as demo: 32 | gr.Markdown( 33 | f"# Showing The Results from {args.save_path}" 34 | ) 35 | with gr.Row(): 36 | with gr.Column(): 37 | show_video = gr.Video(interactive=False) 38 | 39 | with gr.Column(): 40 | result_index = gr.Slider(0, len(result_list), step=1) 41 | info = gr.Text(interactive=False) 42 | 43 | result_index.change(show, [result_index], [show_video, info]) 44 | 45 | demo.launch(share=True) 46 | -------------------------------------------------------------------------------- /tasks/eval/videomme/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from tasks.eval.eval_utils import ( 5 | dump_json, 6 | load_json, 7 | EvalDataset, 8 | ) 9 | 10 | 11 | def check_ans(pred, gt): 12 | flag = False 13 | 14 | pred_list = pred.lower().split(' ') 15 | pred_option, pred_content = pred_list[0], ' '.join(pred_list[1:]) 16 | gt_list = gt.lower().split(' ') 17 | gt_option, gt_content = gt_list[0], ' '.join(gt_list[1:]) 18 | if gt_content[-1] == '.': 19 | gt_content = gt_content[:-1] 20 | 21 | if not any([c in pred_option for c in 'abcdefgABCDEFG']): 22 | print(f"model doesn't follow instructions: {pred}") 23 | elif pred_option.replace('.', '') in gt_option: 24 | flag = True 25 | elif gt_option in pred_option: 26 | flag = True 27 | 28 | return flag 29 | 30 | def save_results(result_list, save_path): 31 | 32 | final_res, acc_dict = {}, {} 33 | correct, total = 0, 0 34 | for res in result_list: 35 | task_type = res['task_type'] 36 | if task_type not in acc_dict: 37 | acc_dict[task_type] = [0, 0] # correct, total 38 | acc_dict[task_type][1] += 1 39 | total += 1 40 | pred = res['pred'] 41 | gt = res['gt'] 42 | if check_ans(pred=pred, gt=gt): 43 | acc_dict[task_type][0] += 1 44 | correct += 1 45 | 46 | for k, v in acc_dict.items(): 47 | final_res[k] = v[0] / v[1] * 100 48 | correct += v[0] 49 | total += v[1] 50 | final_res['Avg'] = correct / total * 100 51 | 52 | all_results = { 53 | "acc_dict": acc_dict, 54 | "result_list": result_list 55 | } 56 | dump_json(all_results, save_path, 'all_results.json') 57 | 
dump_json(final_res, save_path, 'upload_leaderboard.json') 58 | 59 | def load_results(save_path): 60 | all_results = load_json(save_path, 'all_results.json') 61 | if all_results is not None: 62 | result_list = all_results['result_list'] 63 | else: 64 | result_list = None 65 | # json_data = load_json(save_path, 'all_results.json')['result_list'] 66 | return result_list 67 | 68 | class VideoMMEDataset(EvalDataset): 69 | data_list_info = { 70 | # "task_type (sub task name)": ("json file name", "image/video prefix", "data_type", "bound") 71 | "Short Video": ("short.json", "DATAS/Video-MME/data", "video", False), # has start & end 72 | "Medium Video": ("medium.json", "DATAS/Video-MME/data", "video", False), # has start & end 73 | "Long Video": ("long.json", "DATAS/Video-MME/data", "video", False), 74 | } 75 | data_dir = "DATAS/Video-MME/json" 76 | 77 | def __init__(self, *args, **kwargs): 78 | super().__init__(*args, **kwargs) 79 | 80 | data_list_info = self.data_list_info 81 | data_dir = self.data_dir 82 | 83 | self.data_list = [] 84 | for k, v in data_list_info.items(): 85 | with open(os.path.join(data_dir, v[0]), 'r') as f: 86 | json_data = json.load(f) 87 | for data in json_data: 88 | self.data_list.append({ 89 | 'task_type': k, 90 | 'prefix': v[1], 91 | 'data_type': v[2], 92 | 'bound': v[3], 93 | 'data': data 94 | }) 95 | # self.data_list = self.data_list[:100] # for debug 96 | self.decord_method = { 97 | 'video': self.read_video, 98 | 'gif': self.read_gif, 99 | 'frame': self.read_frame, 100 | 'npy': self.read_npy, 101 | } 102 | 103 | # # transform 104 | # crop_size = resolution 105 | # scale_size = resolution 106 | # input_mean = [0.48145466, 0.4578275, 0.40821073] 107 | # input_std = [0.26862954, 0.26130258, 0.27577711] 108 | # self.transform = T.Compose([ 109 | # GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC), 110 | # GroupCenterCrop(crop_size), 111 | # Stack(), 112 | # ToTorchFormatTensor(), 113 | # GroupNormalize(input_mean, input_std) 114 | # ]) 115 | 116 | def __getitem__(self, idx): 117 | question, answer = self.qa_template(self.data_list[idx]['data']) 118 | task_type = self.data_list[idx]['task_type'] 119 | decord_method = self.decord_method[self.data_list[idx]['data_type']] 120 | bound = None 121 | if self.data_list[idx]['bound']: 122 | bound = ( 123 | self.data_list[idx]['data']['start'], 124 | self.data_list[idx]['data']['end'], 125 | ) 126 | video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video']) 127 | 128 | 129 | # images_group = decord_method(video_path, bound) 130 | images_group = decord_method(video_path, bound) 131 | # try: # might be problem with decord 132 | # images_group = decord_method(video_path, bound) 133 | # except Exception as e: 134 | # print(f'error decoding {video_path}', e) 135 | # task_type = 'error_reading_video' 136 | # images_group = None 137 | 138 | return { 139 | 'video_path': video_path, 140 | 'video_pils': images_group, # some might use the original pils and do their own transforms 141 | 'question': question, 142 | 'answer': answer, 143 | 'task_type': task_type, 144 | } 145 | 146 | 147 | def qa_template(self, data): 148 | question = f"Question: {data['question']}\n" 149 | question += "Options:\n" 150 | answer = data['answer'] 151 | answer_idx = -1 152 | for idx, c in enumerate(data['candidates']): 153 | question += f"({chr(ord('A') + idx)}) {c}\n" 154 | if c == answer: 155 | answer_idx = idx 156 | question = question.rstrip() 157 | answer = f"({chr(ord('A') + answer_idx)}) {answer}" 158 | 
return question, answer 159 | 160 | -------------------------------------------------------------------------------- /tasks/eval/videomme/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videomme/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/videomme/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videomme/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /tasks/eval/videomme/__pycache__/pllava_eval_videomme.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videomme/__pycache__/pllava_eval_videomme.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/videoqabench/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videoqabench/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/eval/videoqabench/__pycache__/pllava_eval_videoqabench.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/eval/videoqabench/__pycache__/pllava_eval_videoqabench.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/shared_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import os.path as osp 5 | from os.path import join 6 | 7 | import torch 8 | from torch.utils.data import ConcatDataset, DataLoader 9 | 10 | from utils.optimizer import create_optimizer 11 | from utils.scheduler import create_scheduler 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_media_types(datasources): 17 | """get the media types for for all the dataloaders. 18 | 19 | Args: 20 | datasources (List): List of dataloaders or datasets. 21 | 22 | Returns: List. The media_types. 
23 | 24 | """ 25 | if isinstance(datasources[0], DataLoader): 26 | datasets = [dataloader.dataset for dataloader in datasources] 27 | else: 28 | datasets = datasources 29 | media_types = [ 30 | dataset.datasets[0].media_type 31 | if isinstance(dataset, ConcatDataset) 32 | else dataset.media_type 33 | for dataset in datasets 34 | ] 35 | 36 | return media_types 37 | -------------------------------------------------------------------------------- /tasks/train/__pycache__/instruction_data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/train/__pycache__/instruction_data.cpython-310.pyc -------------------------------------------------------------------------------- /tasks/train/clever_process.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | 4 | dataset_path = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/CLEVRER' 5 | dir_list = os.listdir(dataset_path) 6 | 7 | for dir in dir_list: 8 | dir_path = os.path.join(dataset_path, dir) 9 | file_list = os.listdir(dir_path) 10 | for file in file_list: 11 | file_path = os.path.join(dir_path, file) 12 | shutil.move(file_path, dataset_path) -------------------------------------------------------------------------------- /tasks/train/config_pllava_nframe.py: -------------------------------------------------------------------------------- 1 | from tasks.train.instruction_data import * 2 | 3 | # ========================= data ========================== 4 | # train_corpus = "videochat2_instruction" 5 | train_corpus = "videochat2_instruction_full" 6 | 7 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation 8 | test_file = dict() 9 | test_types = [] 10 | num_workers = 8 11 | save_steps=1000 12 | ckpt_steps=1000 13 | stop_key = None 14 | deepspeed=False 15 | # ========================= input ========================== 16 | num_frames = 16 17 | num_frames_test = 1 18 | batch_size = 8 19 | gradient_accumulation_steps=1 20 | max_txt_l = 512 21 | max_train_steps=None 22 | pre_text = False 23 | inputs = dict( 24 | image_res=336, 25 | video_input=dict( 26 | num_frames="${num_frames}", 27 | sample_type="rand", 28 | num_frames_test="${num_frames_test}", 29 | sample_type_test="middle", 30 | random_aug=False, 31 | ), 32 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 33 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 34 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 35 | ) 36 | 37 | # ========================= model ========================== 38 | model = dict( 39 | # repo_id="llava-hf/llava-v1.6-vicuna-7b-hf", 40 | repo_id="MODELS/llava-1.6", 41 | # repo_id="MODELS/llava-1.6-7b-next-video-dpo", 42 | # repo_id="MODELS/tarsier", 43 | pretrained_path=None, 44 | load_from_origin=False, 45 | origin_vision="", 46 | origin_llm="", 47 | vision_encoder=dict( 48 | name="vit_l14", # somehow need this to tell the dataset the mean std of pretrained model 49 | ), 50 | torch_dtype='bfloat16', 51 | freeze_projector=False, 52 | projector_unfreeze_modules = ['all'], 53 | freeze_lm=True, 54 | lm_unfreeze_modules=['all'], 55 | # lm_unfreeze_modules=['layernorm', 'embed_tokens', 'norm', 'lm_head'], 56 | freeze_vision_tower=True, 57 | lora_target_modules=["q_proj", "v_proj"], # for llama/mistral/gemma 58 | use_lora=True, 59 | lora_r=128, 60 | lora_alpha=32, 61 | 
lora_dropout=0.05, 62 | num_frames="${num_frames}", 63 | pooling_method='avg', 64 | use_pooling=True, 65 | frame_shape=(24,24), 66 | pooling_shape=(16,8,8), 67 | ) 68 | 69 | preprocess = dict( 70 | system="", 71 | mm_alone=True, 72 | random_shuffle=True, 73 | add_second_msg=True, 74 | roles=['USER:', 'ASSISTANT:'], 75 | end_signal=(' ', ''), 76 | begin_signal='', 77 | dataset_image_placeholder='', 78 | dataset_video_placeholder='', 79 | image_token_index=32000, 80 | max_txt_l = "${max_txt_l}", 81 | ignore_index=-100, # same as torch softmax ignore index 82 | center_pad=False, 83 | longest_edge=762, 84 | shortest_edge=336, 85 | clip_transform=False, 86 | num_frames="${num_frames}", 87 | ) 88 | 89 | 90 | optimizer = dict( 91 | opt="adamW", 92 | lr=2e-5, 93 | opt_betas=[0.9, 0.999], # default 94 | weight_decay=0.02, 95 | max_grad_norm=-1, # requires a positive float, use -1 to disable 96 | # use a different lr for some modules, e.g., larger lr for new modules 97 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 98 | ) 99 | 100 | # scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) 101 | # scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) 102 | scheduler = dict( 103 | is_videochat2_custom=False, 104 | sched="cosine", 105 | epochs=2, 106 | warmup_ratio=0.2, 107 | min_lr_multi=0.25) 108 | 109 | evaluate = False 110 | deep_fusion = False 111 | evaluation = dict( 112 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 113 | eval_x_only=False, 114 | k_test=128, 115 | eval_offload=True, # offload gpu tensors to cpu to save memory. 116 | ) 117 | 118 | fp16 = True 119 | gradient_checkpointing = True 120 | 121 | # ========================= wandb ========================== 122 | wandb = dict( 123 | enable=False, 124 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 125 | project="DE_LLAVA", # setup in your command line 126 | ) 127 | dist_url = "env://" 128 | device = "cuda" 129 | mode = "it" 130 | 131 | # ========================= others ========================== 132 | output_dir = None # output dir 133 | resume = False # if True, load optimizer and scheduler states as well 134 | debug = False 135 | log_freq = 5 136 | metric_window_size=10 # window size for metric 137 | seed = 42 138 | report_to='tensorboard' 139 | save_latest = True 140 | auto_resume = True 141 | pretrained_path = "" # path to pretrained model weights, for resume only? 
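For context on the pooling settings just above: frame_shape=(24, 24) with pooling_shape=(16, 8, 8) implies that each frame's 24x24 grid of visual tokens (576 per frame, 9216 for 16 frames) is pooled down to an 8x8 grid over 16 temporal slots, i.e. 1024 visual tokens per video, before reaching the language model. The actual pooling lives in the modeling code (e.g. modeling_pllava.py), which is not shown here; the snippet below is only a rough standalone sketch of such a spatio-temporal average pooling step, with placeholder tensor names and a placeholder hidden size.

# Rough sketch of the spatio-temporal pooling implied by
# frame_shape=(24, 24) and pooling_shape=(16, 8, 8); shapes and names are illustrative.
import torch
import torch.nn as nn

num_frames, frame_h, frame_w, hidden = 16, 24, 24, 4096       # hidden size is a placeholder
feats = torch.randn(1, num_frames, frame_h * frame_w, hidden)  # (B, T, H*W, C) visual tokens

# Move channels first so AdaptiveAvgPool3d can pool time and space jointly: (B, C, T, H, W).
x = feats.view(1, num_frames, frame_h, frame_w, hidden).permute(0, 4, 1, 2, 3)
pooled = nn.AdaptiveAvgPool3d((16, 8, 8))(x)                   # pooling_shape from the config
tokens = pooled.permute(0, 2, 3, 4, 1).flatten(1, 3)           # back to (B, T*h*w, C)

print(feats.flatten(1, 2).shape, '->', tokens.shape)           # 9216 tokens -> 1024 tokens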
142 | -------------------------------------------------------------------------------- /tasks/train/config_pllava_nframe_yiprompt.py: -------------------------------------------------------------------------------- 1 | from tasks.train.instruction_data import * 2 | 3 | # ========================= data ========================== 4 | # train_corpus = "videochat2_instruction" 5 | train_corpus = "videochat2_instruction_full" 6 | 7 | train_file = "${available_corpus[${train_corpus}]}" # for lazy evaluation 8 | test_file = dict() 9 | test_types = [] 10 | num_workers = 8 11 | save_steps=10000 12 | ckpt_steps=1000 13 | stop_key = None 14 | deepspeed=False 15 | highres=None 16 | # ========================= input ========================== 17 | num_frames = 16 18 | num_frames_test = 1 19 | batch_size = 1 20 | gradient_accumulation_steps=16 21 | max_txt_l = 512 22 | max_train_steps=None 23 | pre_text = False 24 | inputs = dict( 25 | image_res=336, 26 | video_input=dict( 27 | num_frames="${num_frames}", 28 | sample_type="rand", 29 | num_frames_test="${num_frames_test}", 30 | sample_type_test="middle", 31 | random_aug=False, 32 | ), 33 | max_txt_l=dict(image="${max_txt_l}", video="${max_txt_l}"), 34 | batch_size=dict(image="${batch_size}", video="${batch_size}"), 35 | batch_size_test=dict(image="${batch_size}", video="${batch_size}"), 36 | ) 37 | 38 | model = dict( 39 | repo_id="llava-hf/llava-1.5-7b-hf", 40 | pretrained_path=None, 41 | load_from_origin=False, 42 | origin_vision="", 43 | origin_llm="", 44 | vision_encoder=dict( 45 | name="vit_l14", # somehow need this to tell the dataset the mean std of pretrained model 46 | ), 47 | torch_dtype='bfloat16', 48 | freeze_projector=False, 49 | freeze_lm=True, 50 | freeze_vision_tower=True, 51 | lora_target_modules=["q_proj", "v_proj"], # for llama/mistral/gemma 52 | use_lora=True, 53 | lora_r=128, 54 | lora_alpha=32, 55 | lora_dropout=0.05, 56 | num_frames="${num_frames}", 57 | pooling_method='avg', 58 | use_pooling=True, 59 | frame_shape=(24,24), 60 | pooling_shape=(16,8,8), 61 | ) 62 | preprocess = dict( 63 | system="", 64 | mm_alone=True, 65 | image_token_index=64002, 66 | random_shuffle=True, 67 | add_second_msg=True, 68 | roles=['<|im_start|>user\n', '<|im_start|>assistant\n'], 69 | end_signal=('<|im_end|>\n', '<|im_end|>\n'), 70 | begin_signal='', 71 | dataset_image_placeholder='', 72 | dataset_video_placeholder='', 73 | max_txt_l = "${max_txt_l}", 74 | ignore_index=-100, # same as torch softmax ignore index 75 | center_pad=False, 76 | longest_edge=762, 77 | shortest_edge=336, 78 | clip_transform=False, 79 | num_frames="${num_frames}", 80 | ) 81 | 82 | 83 | optimizer = dict( 84 | opt="adamW", 85 | lr=2e-5, 86 | opt_betas=[0.9, 0.999], # default 87 | weight_decay=0.02, 88 | max_grad_norm=-1, # requires a positive float, use -1 to disable 89 | # use a different lr for some modules, e.g., larger lr for new modules 90 | different_lr=dict(enable=False, module_names=[], lr=1e-3), 91 | ) 92 | 93 | # scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) 94 | # scheduler = dict(sched="cosine", epochs=3, min_lr_multi=0.25, warmup_epochs=0.6) 95 | scheduler = dict( 96 | is_videochat2_custom=False, 97 | sched="cosine", 98 | epochs=2, 99 | warmup_ratio=0.2, 100 | min_lr_multi=0.25) 101 | 102 | evaluate = False 103 | deep_fusion = False 104 | evaluation = dict( 105 | eval_frame_ensemble="concat", # [concat, max, mean, lse] 106 | eval_x_only=False, 107 | k_test=128, 108 | eval_offload=True, # offload gpu tensors to cpu to save memory. 
109 | ) 110 | 111 | fp16 = True 112 | gradient_checkpointing = True 113 | 114 | # ========================= wandb ========================== 115 | wandb = dict( 116 | enable=False, 117 | entity="user", # username or team name to store the runs, see https://docs.wandb.ai/ref/python/init 118 | project="videochat2", # setup in your command line 119 | ) 120 | dist_url = "env://" 121 | device = "cuda" 122 | mode = "it" 123 | 124 | # ========================= others ========================== 125 | output_dir = None # output dir 126 | resume = False # if True, load optimizer and scheduler states as well 127 | debug = False 128 | log_freq = 5 129 | metric_window_size=10 # window size for metric 130 | seed = 42 131 | report_to='tensorboard' 132 | save_latest = True 133 | auto_resume = True 134 | pretrained_path = "" # path to pretrained model weights, for resume only? 135 | -------------------------------------------------------------------------------- /tasks/train/ego_process.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | anno_file = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/vqa/ego_qa/train.json' 3 | video_root_path = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/ego4d_data/split_videos' 4 | 5 | annos = json.load(open(anno_file, 'r')) 6 | for anno in annos: 7 | video_path = anno['video'] 8 | video_path = os.path.join(video_root_path, video_path) 9 | if not os.path.exists(video_path): 10 | print(video_path) -------------------------------------------------------------------------------- /tasks/train/ffmpeg_tgif.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | 4 | # 源文件夹路径 5 | source_folder = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/VideoQA/TGIF_QA/video_gif' 6 | target_folder = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/VideoQA/TGIF_QA/videos_mp4' 7 | # 包含文件名的文本文件路径 8 | file_list_path = 'not_have.txt' 9 | 10 | if not os.path.exists(target_folder): 11 | os.makedirs(target_folder) 12 | 13 | # 读取文件列表并转换 14 | with open(file_list_path, 'r') as file_list: 15 | for line in file_list: 16 | # 获取去除前后空白字符的文件名 17 | gif_filename = line.strip() 18 | # 源文件完整路径 19 | source_path = os.path.join(source_folder, gif_filename) 20 | # 目标文件完整路径,假设输入文件名格式正确,并将后缀替换为.mp4 21 | target_path = os.path.join(target_folder, os.path.splitext(gif_filename)[0] + '.mp4') 22 | 23 | # 构建ffmpeg命令 24 | cmd = ['ffmpeg', '-i', source_path, '-movflags', 'faststart', target_path] 25 | 26 | # 执行命令 27 | try: 28 | subprocess.run(cmd, check=True) 29 | print(f'Successfully converted {gif_filename} to MP4.') 30 | except subprocess.CalledProcessError as e: 31 | print(f'Failed to convert {gif_filename}. 
Error: {e}') 32 | 33 | print('All files have been processed.') -------------------------------------------------------------------------------- /tasks/train/k710_print.py: -------------------------------------------------------------------------------- 1 | dataset_path = { 2 | # 'k400': '/root/paddlejob/workspace/env_run/output/xiaohu/data/k400/train', 3 | # 'k600': '/root/paddlejob/workspace/env_run/output/xiaohu/data/k600/Kinetics600/videos', 4 | # 'k700': '/root/paddlejob/workspace/env_run/data_afs_3/zhouhao14/intern/xiaohu/k700_dir/Kinetics_700/videos/' 5 | 'k710': '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/k710' 6 | } 7 | 8 | import os 9 | from tqdm import tqdm 10 | 11 | f = open('k710_files_filter.txt', 'w') 12 | for dataset, path in dataset_path.items(): 13 | # dir_list = os.listdir(path) 14 | # for dir in tqdm(dir_list): 15 | # dir_path = os.path.join(path, dir) 16 | file_list = os.listdir(path) 17 | for file in file_list: 18 | file_path = os.path.join(path, file) 19 | f.write(file+' '+file_path+'\n') -------------------------------------------------------------------------------- /tasks/train/k710_process.py: -------------------------------------------------------------------------------- 1 | annotation_file = "/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/classification/k710/train_new.json" 2 | annotation_file_new = "/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/classification/k710/train_new_1.json" 3 | file_list = open('k710_files_filter.txt', 'r').readlines() 4 | file_list = [file.strip().split(' ') for file in file_list] 5 | file_dict = {} 6 | for file, path in file_list: 7 | file = file[:11].lower() 8 | file_dict[file] = path 9 | import os 10 | import json 11 | 12 | annotations = json.load(open(annotation_file)) 13 | print('annoation length:', len(annotations)) 14 | annotations_new = [] 15 | count = 0 16 | for anno in annotations: 17 | video_path = anno['video'] 18 | video_path = video_path.split('/')[-1].split('.')[0] 19 | if len(video_path) > 15: 20 | video_path = video_path[:11] 21 | video_path = video_path.lower() 22 | if video_path in file_dict: 23 | # anno['video'] = file_dict[video_path.lower()] 24 | anno['video'] = anno['video'].split('/')[-1] 25 | annotations_new.append(anno) 26 | else: 27 | count += 1 28 | json.dump(annotations_new, open(annotation_file_new, 'w')) 29 | print('miss number:', count) 30 | # for file, file_path in file_list: 31 | # if video_path in file: 32 | # continue 33 | # else: 34 | # print(video_path) -------------------------------------------------------------------------------- /tasks/train/mk_710.py: -------------------------------------------------------------------------------- 1 | annotation_file = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/classification/k710/train_new.json' 2 | dst_path = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/k710' 3 | 4 | import os, json, shutil 5 | from tqdm import tqdm 6 | 7 | f = open(annotation_file) 8 | annotations = json.load(f) 9 | 10 | for anno in tqdm(annotations): 11 | video_path = anno['video'] 12 | video_name = os.path.basename(video_path) 13 | shutil.copyfile(video_path, dst_path + '/' + video_name) -------------------------------------------------------------------------------- /tasks/train/not_have.txt: 
-------------------------------------------------------------------------------- 1 | tumblr_nb2b36uj4V1skspzwo1_250.gif 2 | tumblr_nnwhmq7vx91r76agyo1_250.gif 3 | tumblr_nq92uaE0Ws1u97vumo1_500.gif 4 | tumblr_mx7h7rT5VW1qd80wyo1_400.gif 5 | tumblr_ncfcoj7Cjk1qd80wyo1_400.gif 6 | tumblr_n7kvhniZkO1qd80wyo1_400.gif 7 | tumblr_nk56bs75dd1u7oomho1_400.gif 8 | tumblr_nk8cx74VZI1r88jv8o1_250.gif 9 | tumblr_npkvwxAmBw1ux8xe0o1_250.gif 10 | tumblr_nnwhsh94Cz1r76agyo1_100.gif 11 | tumblr_nqrmt7MmEi1ux8xe0o1_400.gif 12 | tumblr_nauyg863cl1tdjuqvo1_400.gif 13 | tumblr_n9gq572Eil1qd80wyo1_400.gif 14 | tumblr_np7f9w4gb61s4vkvgo1_250.gif 15 | tumblr_naemimnRQ21qj7ohio1_500.gif 16 | tumblr_npu4nvnG8y1ux8xe0o1_250.gif 17 | tumblr_ne372wjN501tmgpxuo1_250.gif 18 | tumblr_nqo7ly0WTQ1sgafh8o1_400.gif 19 | tumblr_nc669vjChQ1s7nakbo1_400.gif 20 | tumblr_njs0cnWbe11tgetb4o1_250.gif 21 | tumblr_nkils8vvnN1tk2dvro1_400.gif 22 | tumblr_n9vlgvJRfR1qd80wyo1_400.gif 23 | tumblr_n068ybhVsN1rkm4f7o1_400.gif 24 | tumblr_n3v3xrtaGc1r8go1ao1_250.gif 25 | tumblr_niibjr3cfO1u8uroco1_250.gif 26 | tumblr_nfa9ofxIIb1sk96t7o1_400.gif 27 | tumblr_nbjg4bukeX1raaknro1_250.gif 28 | tumblr_n8qummpIqS1sfcnmao1_250.gif 29 | tumblr_nnwk6xiIOh1r76agyo1_250.gif 30 | tumblr_n92cixcsvI1r88jv8o1_400.gif 31 | tumblr_nnwhoxZmzr1r76agyo1_100.gif 32 | tumblr_ncfslbyWaf1trjw2xo1_400.gif 33 | tumblr_nofiupp70K1tsywajo1_250.gif 34 | tumblr_nq3gnhRRgi1r09l2vo1_400.gif 35 | tumblr_naohj0KLo81sw0250o1_400.gif 36 | tumblr_nr13cnatIH1ux8xe0o1_250.gif 37 | tumblr_np8vddxRS81uw8t6bo1_400.gif 38 | tumblr_mtg3j27bd61s7nakbo1_400.gif 39 | tumblr_nh9ey9djGZ1s7nakbo1_400.gif 40 | tumblr_ngyd0yHHkR1s6jpovo1_400.gif 41 | tumblr_nk8tjfoZ5v1u7oomho1_400.gif 42 | tumblr_nf7mvh6bsr1qd80wyo1_400.gif 43 | tumblr_marc0jPgsb1qkq2eno1_400.gif 44 | -------------------------------------------------------------------------------- /tasks/train/output.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/tasks/train/output.mp4 -------------------------------------------------------------------------------- /tasks/train/print_all_files.py: -------------------------------------------------------------------------------- 1 | data_root = '/root/paddlejob/workspace/env_run/data_afs_3/zhouhao14/intern/xiaohu/webvid/webvid' 2 | 3 | import os 4 | from tqdm import tqdm 5 | f = open('webvid_list.txt', 'w') 6 | dir_list = os.listdir(data_root) 7 | for dir in tqdm(dir_list): 8 | dir_path = os.path.join(data_root, dir) 9 | file_list = os.listdir(dir_path) 10 | for file in file_list: 11 | file_path = os.path.join(dir_path, file) 12 | f.write(dir +'/'+ file + '\n') -------------------------------------------------------------------------------- /tasks/train/tgif_corrupt.txt: -------------------------------------------------------------------------------- 1 | tumblr_nkps43Kmm31unykvpo1_400.mp4 2 | tumblr_nhsbt0k1TK1rqx3tso1_500.mp4 3 | tumblr_nqidiuBkaX1tpg4boo1_250.mp4 4 | tumblr_nd8ikzEuRs1r4tjm5o1_400.mp4 5 | tumblr_no5qys6Zv91so4o4wo1_400.mp4 6 | tumblr_ne5vc0elB31slj978o1_400.mp4 7 | tumblr_nocmsxthAQ1tsqdy0o1_400.mp4 8 | tumblr_nqrdphoLSG1t95h1uo1_250.mp4 9 | tumblr_n9kbl6J6Tg1tgy7r4o1_400.mp4 10 | tumblr_nl0eirZXIH1tm778fo1_400.mp4 11 | tumblr_npoca1RD9h1s71nvbo1_400.mp4 12 | tumblr_nqkw6wcy7u1tg815ro1_400.mp4 13 | tumblr_nhglxePcKs1u713vko1_250.mp4 14 | tumblr_mlxk44iKda1rlryi1o1_400.mp4 15 | tumblr_nc7ou88Xda1tdmffyo1_250.mp4 16 | 
tumblr_npax48inl61tf42s3o1_400.mp4 17 | tumblr_n8uzopnOQA1tbgcpko1_400.mp4 18 | tumblr_njn1lj3sCB1unrob4o1_500.mp4 19 | tumblr_nozp1udlit1t95h1uo1_250.mp4 20 | tumblr_nmierrf5ZH1tnos68o1_250.mp4 21 | tumblr_nfm47yIKUE1rtequ6o1_400.mp4 22 | tumblr_navaivague1re06l8o1_400.mp4 23 | tumblr_nriiqfSIWQ1uaoehqo1_400.mp4 24 | tumblr_nfnj91Sxrc1tv4d9wo1_250.mp4 25 | tumblr_n8o5meR3HR1te77izo1_400.mp4 26 | tumblr_nq5yq0EmZB1u8gd00o1_250.mp4 27 | tumblr_nb2ol4NniN1tkpzw0o1_250.mp4 28 | tumblr_nh04cdAUje1sm9b1po2_400.mp4 29 | tumblr_nnjavclA4f1utipxro1_250.mp4 30 | tumblr_mv3ld4n2ri1sfprkzo1_400.mp4 31 | tumblr_nn91q1Yoa01uqzp8co1_500.mp4 32 | tumblr_ncl7lvFWpg1u04f66o1_400.mp4 33 | tumblr_nkwlv8Hg0U1qfq2gno1_400.mp4 34 | tumblr_no28osQSix1twfmf3o1_250.mp4 35 | tumblr_ne3958o3xv1qhrx75o1_250.mp4 36 | tumblr_nav9toudpf1re06l8o1_400.mp4 37 | tumblr_nfbuufZpod1u4068wo1_400.mp4 38 | tumblr_nonucoHIFu1tpg4boo1_250.mp4 39 | tumblr_naf1akCsYV1ts0kzio1_400.mp4 40 | tumblr_na6j8jESjb1thpigwo1_400.mp4 41 | tumblr_nkz4w2utsF1sm7eoto1_400.mp4 42 | tumblr_nigio1CXBZ1u8uroco1_400.mp4 43 | tumblr_noiq0clHCg1qzhjh2o1_400.mp4 44 | tumblr_nh74fcyMFu1slj978o1_250.mp4 45 | tumblr_nfyul0NhhS1tzs6b2o1_500.mp4 46 | tumblr_nf2oli6DJm1slw55qo1_400.mp4 47 | tumblr_npmwjgWiIA1uvie7bo1_400.mp4 48 | tumblr_nkq2vhomVm1twnkudo1_r1_400.mp4 49 | tumblr_niubadIAbL1tqviovo1_500.mp4 50 | tumblr_np3dvwowfN1up68h4o1_500.mp4 51 | tumblr_ngssnyMJqn1slj978o1_400.mp4 52 | tumblr_nfyqdyhWDn1sx7xv7o1_500.mp4 53 | tumblr_noe054KxAw1tpg4boo1_400.mp4 54 | tumblr_npj2p3o4361tx8mn0o1_400.mp4 55 | tumblr_ncv7bmm1lG1tf01j4o1_250.mp4 56 | tumblr_nlzaj75aY51s85u2fo1_500.mp4 57 | tumblr_mvz12eJxAB1rbf9bno1_500.mp4 58 | tumblr_n9oltoBCjr1t0ohh1o1_500.mp4 59 | tumblr_nm128eKJ7n1r9yho8o1_540.mp4 60 | tumblr_nejwjzrE2I1spote4o1_500.mp4 61 | tumblr_nrfz25aGMZ1ual9cno1_250.mp4 62 | tumblr_npdasiJEUc1sht3fmo1_400.mp4 63 | tumblr_njvhuklV531u2muk4o1_400.mp4 64 | tumblr_no566xiPbF1ttvor4o1_500.mp4 65 | tumblr_nd3sy96uW81tdvc4qo1_400.mp4 66 | tumblr_ne9hgkuZlD1twfpc5o1_500.mp4 67 | tumblr_m6qbi2ZSEE1qj7lb4o1_r3_500.mp4 68 | tumblr_naysx8YTzn1tzl1owo1_400.mp4 69 | tumblr_nk7hxlH63e1u2b31do1_400.mp4 70 | tumblr_nb6ejhQiUn1sl27r8o1_250.mp4 71 | tumblr_nqih1fBKzz1r0s2r6o1_250.mp4 72 | tumblr_nkb5fbXoNb1syz358o1_400.mp4 73 | tumblr_ngns8uQo831rq9gtvo1_400.mp4 74 | tumblr_nf9tyzpthC1tmddexo1_250.mp4 75 | tumblr_na7nn7XEI41rd6gi7o1_500.mp4 76 | tumblr_nknuza4Res1r3mh0to1_400.mp4 77 | tumblr_nfoxrvvztQ1tk8ub5o1_400.mp4 78 | tumblr_npzl3w7VUC1spbq2fo1_400.mp4 79 | tumblr_nlgaudRZPW1rw95g7o1_400.mp4 80 | tumblr_nnzysjVJR51r59fn4o1_400.mp4 81 | tumblr_nbv7fwPUde1shxl87o1_250.mp4 82 | tumblr_ni34ysM1ri1u3ztwyo1_250.mp4 83 | tumblr_nhq0b5Ukbk1tcof18o1_500.mp4 84 | tumblr_nqmswiqk121r8h6u4o1_500.mp4 85 | tumblr_n8z70012w81tgbvgqo1_500.mp4 86 | tumblr_nn61aaKcNA1qhqb9no1_r1_500.mp4 87 | tumblr_npjuc3Hc9r1qhrx75o1_400.mp4 88 | -------------------------------------------------------------------------------- /tasks/train/tgif_mp4.py: -------------------------------------------------------------------------------- 1 | import os 2 | from moviepy.editor import VideoFileClip 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | from tqdm import tqdm 5 | from decord import VideoReader 6 | from decord import cpu, gpu 7 | 8 | 9 | # 源文件夹和目标文件夹 10 | source_folder = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/VideoQA/TGIF_QA/video_gif' 11 | target_folder = 
'/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/VideoQA/TGIF_QA/videos_mp4' 12 | used_list = 'tgif_used.txt' 13 | corrupt_list = 'tgif_corrupt.txt' 14 | f_corrupt = open(corrupt_list, 'w') 15 | # 读取使用的文件列表 16 | with open(used_list, 'r') as f: 17 | used_files = [line.strip() for line in f.readlines()] 18 | 19 | count = 0 20 | not_count = 0 21 | miss_list = [] 22 | for file in tqdm(used_files): 23 | # file_new = file.replace('.gif', '.mp4') 24 | video_path = os.path.join(target_folder, file) 25 | if os.path.exists(video_path): 26 | try: 27 | vr = VideoReader(video_path, ctx=cpu(0)) 28 | except: 29 | count += 1 30 | miss_list.append(file) 31 | print('Error processing {}'.format(file)) 32 | f_corrupt.write(file + '\n') 33 | else: 34 | not_count += 1 35 | print(count, not_count) 36 | # print(os.path.join(source_folder, file)) 37 | # print(count, len(used_files)) 38 | # miss_list = [] 39 | 40 | # def convert_gif_to_mp4(file): 41 | # if file.endswith('.gif'): 42 | # source_path = os.path.join(source_folder, file) 43 | # target_path = os.path.join(target_folder, file.replace('.gif', '.mp4')) 44 | 45 | # if os.path.exists(target_path): 46 | # return f'{file} already converted. Skipping...' 47 | # try: 48 | # clip = VideoFileClip(source_path) 49 | # clip.write_videofile(target_path, codec="libx264", fps=24) 50 | # clip.close() 51 | 52 | # return f'Saved {file.replace(".gif", ".mp4")} to {target_folder}' 53 | # except Exception as e: 54 | # print('Error processing {}'.format(file), e, sep='\n') 55 | # # raise e 56 | # return f'Error processing {file}: {e}' 57 | # else: 58 | # return f'{file} is not a GIF. Skipping...' 59 | 60 | # # 设置线程池的最大线程数 61 | # max_threads = 8 62 | 63 | # with ThreadPoolExecutor(max_workers=max_threads) as executor: 64 | # # 使用executor.map来并行处理任务 65 | # # 注意:如果你想在任务执行时保持进度条更新,可能需要使用executor.submit和as_completed 66 | # futures = [executor.submit(convert_gif_to_mp4, file) for file in used_files] 67 | 68 | # # 为了展示进度条,我们使用as_completed来获取已完成的future 69 | # for future in tqdm(as_completed(futures), total=len(futures)): 70 | # print(future.result()) 71 | 72 | # print("所有视频处理完成!") -------------------------------------------------------------------------------- /tasks/train/vcg_process.py: -------------------------------------------------------------------------------- 1 | anno_file = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/conversation/videochatgpt/train_new.json' 2 | anno_new_file = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons/video/conversation/videochatgpt/train_new_1.json' 3 | 4 | data_root = '//root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/panda' 5 | import os, json 6 | from decord import VideoReader 7 | from decord import cpu, gpu 8 | from tqdm import tqdm 9 | 10 | 11 | miss_count = 0 12 | count = 0 13 | # annos = json.load(open(anno_file)) 14 | # annos_new = [] 15 | # for anno in tqdm(annos): 16 | # video_path = os.path.join(data_root, anno['video']) 17 | # if not os.path.exists(video_path): 18 | # continue 19 | # try: 20 | # vr = VideoReader(video_path, ctx=cpu(0)) 21 | # annos_new.append(anno) 22 | # except: 23 | # count += 1 24 | # json.dump(annos_new, open(anno_new_file, 'w')) 25 | # print(count) 26 | 27 | files = os.listdir(data_root) 28 | for file in tqdm(files): 29 | video_path = os.path.join(data_root, file) 30 | try: 31 | count += 1 32 | vr = VideoReader(video_path, ctx=cpu(0)) 33 | except: 34 | 
miss_count += 1 35 | print(video_path) 36 | print(miss_count, count) -------------------------------------------------------------------------------- /tasks/train/vcg_read.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | # 视频文件路径 4 | video_path = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/Video_ChatGPT/v_Z0eBz6QsI-c.mp4' 5 | 6 | # 打开视频文件 7 | cap = cv2.VideoCapture(video_path) 8 | 9 | while cap.isOpened(): 10 | # 读取一帧 11 | ret, frame = cap.read() 12 | 13 | # 如果正确读取帧,ret为True 14 | if not ret: 15 | print("Can't receive frame (stream end?). Exiting ...") 16 | break 17 | 18 | # 显示当前帧 19 | # cv2.imshow('frame', frame) 20 | 21 | # 按 'q' 退出 22 | if cv2.waitKey(1) == ord('q'): 23 | break 24 | 25 | # 释放Capture对象 26 | cap.release() 27 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /tasks/train/webvid_process.py: -------------------------------------------------------------------------------- 1 | file_list = 'webvid_list.txt' 2 | anno_root_it = '/root/paddlejob/workspace/env_run/output/xiaohu/data/video_vlm/PLLaVA/DATAS/TRAIN_TEST/magic_jsons' 3 | train_list_files = [f"{anno_root_it}/video/caption/videochat/train.json", f"{anno_root_it}/video/caption/webvid/train.json", f"{anno_root_it}/video/conversation/videochat1/train.json", f"{anno_root_it}/video/vqa/webvid_qa/train.json"] 4 | 5 | import os 6 | import json 7 | from tqdm import tqdm 8 | 9 | files = open(file_list).readlines() 10 | files = [file.strip() for file in files] 11 | f_missed = open('missing_files.txt', 'w') 12 | for file in train_list_files: 13 | f = open(file, 'r') 14 | item_list = [] 15 | data = json.load(f) 16 | for item in tqdm(data): 17 | video_id = item['video'] 18 | if video_id not in files: 19 | f_missed.write(video_id + '\n') 20 | else: 21 | item_list.append(item) 22 | new_file = file.replace('train', 'train_new') 23 | with open(new_file, 'w') as f_new: 24 | json.dump(item_list, f_new) -------------------------------------------------------------------------------- /utils/__pycache__/basic_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/basic_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/config_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/distributed.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/distributed.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/easydict.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/easydict.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/logger.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/optimizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/optimizer.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/scheduler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Visual-AI/PruneVid/b12600c6176c606ee4b6b7d9e26f4aa46d0dfd42/utils/__pycache__/scheduler.cpython-310.pyc -------------------------------------------------------------------------------- /utils/basic_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | import os 4 | import json 5 | import logging 6 | import random 7 | import time 8 | from collections import defaultdict, deque 9 | import datetime 10 | from pathlib import Path 11 | from typing import List, Union 12 | 13 | import torch 14 | import torch.distributed as dist 15 | from .distributed import is_dist_avail_and_initialized 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class SmoothedValue(object): 22 | """Track a series of values and provide access to smoothed values over a 23 | window or the global series average. 24 | """ 25 | 26 | def __init__(self, window=20, fmt=None): 27 | if fmt is None: 28 | fmt = "{median:.4f} ({global_avg:.4f})" 29 | self.deque = deque(maxlen=window) 30 | self.total = 0.0 31 | self.count = 0 32 | self.fmt = fmt 33 | 34 | def update(self, value, n=1): 35 | self.deque.append(value) 36 | self.count += n 37 | self.total += value * n 38 | 39 | def synchronize_between_processes(self): 40 | """ 41 | Warning: does not synchronize the deque! 
42 | """ 43 | if not is_dist_avail_and_initialized(): 44 | return 45 | t = torch.tensor([self.count, self.total], 46 | dtype=torch.float64, device='cuda') 47 | dist.barrier() 48 | dist.all_reduce(t) 49 | t = t.tolist() 50 | self.count = int(t[0]) 51 | self.total = t[1] 52 | 53 | @property 54 | def median(self): 55 | d = torch.tensor(list(self.deque)) 56 | return d.median().item() 57 | 58 | @property 59 | def avg(self): 60 | d = torch.tensor(list(self.deque), dtype=torch.float32) 61 | return d.mean().item() 62 | 63 | @property 64 | def global_avg(self): 65 | return self.total / self.count 66 | 67 | @property 68 | def max(self): 69 | return max(self.deque) 70 | 71 | @property 72 | def value(self): 73 | return self.deque[-1] 74 | 75 | def __str__(self): 76 | return self.fmt.format( 77 | median=self.median, 78 | avg=self.avg, 79 | global_avg=self.global_avg, 80 | max=self.max, 81 | value=self.value) 82 | 83 | 84 | class MetricLogger(object): 85 | def __init__(self, delimiter="\t"): 86 | self.meters = defaultdict(SmoothedValue) 87 | self.delimiter = delimiter 88 | 89 | def update(self, **kwargs): 90 | for k, v in kwargs.items(): 91 | if isinstance(v, torch.Tensor): 92 | v = v.item() 93 | assert isinstance(v, (float, int)) 94 | self.meters[k].update(v) 95 | 96 | def __getattr__(self, attr): 97 | if attr in self.meters: 98 | return self.meters[attr] 99 | if attr in self.__dict__: 100 | return self.__dict__[attr] 101 | raise AttributeError("'{}' object has no attribute '{}'".format( 102 | type(self).__name__, attr)) 103 | 104 | def __str__(self): 105 | loss_str = [] 106 | for name, meter in self.meters.items(): 107 | if meter.count == 0: # skip empty meter 108 | loss_str.append( 109 | "{}: {}".format(name, "No data") 110 | ) 111 | else: 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def global_avg(self): 118 | loss_str = [] 119 | for name, meter in self.meters.items(): 120 | if meter.count == 0: 121 | loss_str.append( 122 | "{}: {}".format(name, "No data") 123 | ) 124 | else: 125 | loss_str.append( 126 | "{}: {:.4f}".format(name, meter.global_avg) 127 | ) 128 | return self.delimiter.join(loss_str) 129 | 130 | def get_global_avg_dict(self, prefix=""): 131 | """include a separator (e.g., `/`, or "_") at the end of `prefix`""" 132 | d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. 
for k, m in self.meters.items()} 133 | return d 134 | 135 | def synchronize_between_processes(self): 136 | for meter in self.meters.values(): 137 | meter.synchronize_between_processes() 138 | 139 | def add_meter(self, name, meter): 140 | self.meters[name] = meter 141 | 142 | def log_every(self, iterable, log_freq, header=None): 143 | i = 0 144 | if not header: 145 | header = '' 146 | start_time = time.time() 147 | end = time.time() 148 | iter_time = SmoothedValue(fmt='{avg:.4f}') 149 | data_time = SmoothedValue(fmt='{avg:.4f}') 150 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 151 | log_msg = [ 152 | header, 153 | '[{0' + space_fmt + '}/{1}]', 154 | 'eta: {eta}', 155 | '{meters}', 156 | 'time: {time}', 157 | 'data: {data}' 158 | ] 159 | if torch.cuda.is_available(): 160 | log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}') 161 | log_msg = self.delimiter.join(log_msg) 162 | MB = 1024.0 * 1024.0 163 | for obj in iterable: 164 | data_time.update(time.time() - end) 165 | yield obj 166 | iter_time.update(time.time() - end) 167 | if i % log_freq == 0 or i == len(iterable) - 1: 168 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 169 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 170 | if torch.cuda.is_available(): 171 | logger.info(log_msg.format( 172 | i, len(iterable), eta=eta_string, 173 | meters=str(self), 174 | time=str(iter_time), data=str(data_time), 175 | memory=torch.cuda.max_memory_allocated() / MB, 176 | res_mem=torch.cuda.max_memory_reserved() / MB, 177 | )) 178 | else: 179 | logger.info(log_msg.format( 180 | i, len(iterable), eta=eta_string, 181 | meters=str(self), 182 | time=str(iter_time), data=str(data_time))) 183 | i += 1 184 | end = time.time() 185 | total_time = time.time() - start_time 186 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 187 | logger.info('{} Total time: {} ({:.4f} s / it)'.format( 188 | header, total_time_str, total_time / len(iterable))) 189 | 190 | 191 | class AttrDict(dict): 192 | def __init__(self, *args, **kwargs): 193 | super(AttrDict, self).__init__(*args, **kwargs) 194 | self.__dict__ = self 195 | 196 | 197 | def compute_acc(logits, label, reduction='mean'): 198 | ret = (torch.argmax(logits, dim=1) == label).float() 199 | if reduction == 'none': 200 | return ret.detach() 201 | elif reduction == 'mean': 202 | return ret.mean().item() 203 | 204 | 205 | def compute_n_params(model, return_str=True): 206 | tot = 0 207 | for p in model.parameters(): 208 | w = 1 209 | for x in p.shape: 210 | w *= x 211 | tot += w 212 | if return_str: 213 | if tot >= 1e6: 214 | return '{:.1f}M'.format(tot / 1e6) 215 | else: 216 | return '{:.1f}K'.format(tot / 1e3) 217 | else: 218 | return tot 219 | 220 | 221 | def setup_seed(seed): 222 | torch.manual_seed(seed) 223 | np.random.seed(seed) 224 | random.seed(seed) 225 | 226 | 227 | def remove_files_if_exist(file_paths): 228 | for fp in file_paths: 229 | if os.path.isfile(fp): 230 | os.remove(fp) 231 | 232 | 233 | def save_json(data, filename, save_pretty=False, sort_keys=False): 234 | with open(filename, "w") as f: 235 | if save_pretty: 236 | f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) 237 | else: 238 | json.dump(data, f) 239 | 240 | 241 | def load_json(filename): 242 | with open(filename, "r") as f: 243 | return json.load(f) 244 | 245 | 246 | def flat_list_of_lists(l): 247 | """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" 248 | return [item for sublist in l for item in sublist] 249 | 250 | 251 | def find_files_by_suffix_recursively(root: 
str, suffix: Union[str, List[str]]): 252 | """ 253 | Args: 254 | root: path to the directory to start search files 255 | suffix: any str as suffix, or can match multiple such strings 256 | when input is List[str]. 257 | Example 1, e.g., suffix: `.jpg` or [`.jpg`, `.png`] 258 | Example 2, e.g., use a `*` in the `suffix`: `START*.jpg.`. 259 | """ 260 | if isinstance(suffix, str): 261 | suffix = [suffix, ] 262 | filepaths = flat_list_of_lists( 263 | [list(Path(root).rglob(f"*{e}")) for e in suffix]) 264 | return filepaths 265 | 266 | 267 | def match_key_and_shape(state_dict1, state_dict2): 268 | keys1 = set(state_dict1.keys()) 269 | keys2 = set(state_dict2.keys()) 270 | print(f"keys1 - keys2: {keys1 - keys2}") 271 | print(f"keys2 - keys1: {keys2 - keys1}") 272 | 273 | mismatch = 0 274 | for k in list(keys1): 275 | if state_dict1[k].shape != state_dict2[k].shape: 276 | print( 277 | f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}") 278 | mismatch += 1 279 | print(f"mismatch {mismatch}") 280 | 281 | 282 | def merge_dicts(list_dicts): 283 | merged_dict = list_dicts[0].copy() 284 | for i in range(1, len(list_dicts)): 285 | merged_dict.update(list_dicts[i]) 286 | return merged_dict 287 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import ast 5 | import json 6 | import os 7 | import os.path as osp 8 | import re 9 | import shutil 10 | import sys 11 | import tempfile 12 | from copy import deepcopy 13 | from importlib import import_module 14 | 15 | import yaml 16 | 17 | from .easydict import EasyDict 18 | 19 | __all__ = ["Config", "pretty_text"] 20 | 21 | 22 | BASE_KEY = "_base_" 23 | # BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"} 24 | BASE_CONFIG = {} 25 | 26 | cfg = None 27 | 28 | 29 | class Config(object): 30 | """config""" 31 | 32 | @classmethod 33 | def pretty_text(cls, cfg: dict, indent=2) -> str: 34 | """format dict to a string 35 | 36 | Args: 37 | cfg (EasyDict): the params. 38 | 39 | Returns: The string to display. 40 | 41 | """ 42 | msg = "{\n" 43 | for i, (k, v) in enumerate(cfg.items()): 44 | if isinstance(v, dict): 45 | v = cls.pretty_text(v, indent + 4) 46 | spaces = " " * indent 47 | msg += spaces + "{}: {}".format(k, v) 48 | if i == len(cfg) - 1: 49 | msg += " }" 50 | else: 51 | msg += "\n" 52 | return msg 53 | 54 | @classmethod 55 | def dump(cls, cfg, savepath=None): 56 | """dump cfg to `json` file. 57 | 58 | Args: 59 | cfg (dict): The dict to dump. 60 | savepath (str): The filepath to save the dumped dict. 61 | 62 | Returns: TODO 63 | 64 | """ 65 | if savepath is None: 66 | savepath = osp.join(cfg.WORKSPACE, "config.json") 67 | json.dump(cfg, open(savepath, "w"), indent=2) 68 | 69 | @classmethod 70 | def get_config(cls, default_config: dict = None): 71 | """get a `Config` instance. 72 | 73 | Args: 74 | default_config (dict): The default config. `default_config` will be overrided 75 | by config file `--cfg`, `--cfg` will be overrided by commandline args. 76 | 77 | Returns: an EasyDict. 78 | """ 79 | global cfg 80 | if cfg is not None: 81 | return cfg 82 | 83 | # define arg parser. 84 | parser = argparse.ArgumentParser() 85 | # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str) 86 | parser.add_argument( 87 | "config_file", help="the configuration file to load. 
support: .yaml, .json, .py" 88 | ) 89 | parser.add_argument( 90 | "opts", 91 | default=None, 92 | nargs="*", 93 | help="overrided configs. List. Format: 'key1 name1 key2 name2'", 94 | ) 95 | args = parser.parse_args() 96 | 97 | cfg = EasyDict(BASE_CONFIG) 98 | if osp.isfile(args.config_file): 99 | cfg_from_file = cls.from_file(args.config_file) 100 | cfg = merge_a_into_b(cfg_from_file, cfg) 101 | cfg = cls.merge_list(cfg, args.opts) 102 | cfg = eval_dict_leaf(cfg) 103 | 104 | # update some keys to make them show at the last 105 | for k in BASE_CONFIG: 106 | cfg[k] = cfg.pop(k) 107 | return cfg 108 | 109 | @classmethod 110 | def from_file(cls, filepath: str) -> EasyDict: 111 | """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`. 112 | 113 | Args: 114 | filepath (str): The config file path. 115 | 116 | Returns: TODO 117 | 118 | """ 119 | filepath = osp.abspath(osp.expanduser(filepath)) 120 | if not osp.isfile(filepath): 121 | raise IOError(f"File does not exist: {filepath}") 122 | if filepath.endswith(".py"): 123 | with tempfile.TemporaryDirectory() as temp_config_dir: 124 | 125 | shutil.copytree(osp.dirname(filepath), osp.join(temp_config_dir, "tmp_config")) 126 | sys.path.insert(0, temp_config_dir) 127 | mod = import_module("tmp_config." + osp.splitext(osp.basename(filepath))[0]) 128 | # mod = import_module(temp_module_name) 129 | sys.path.pop(0) 130 | cfg_dict = { 131 | name: value 132 | for name, value in mod.__dict__.items() 133 | if not name.startswith("__") 134 | } 135 | for k in list(sys.modules.keys()): 136 | if "tmp_config" in k: 137 | del sys.modules[k] 138 | elif filepath.endswith((".yml", ".yaml")): 139 | cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader) 140 | elif filepath.endswith(".json"): 141 | cfg_dict = json.load(open(filepath, "r")) 142 | else: 143 | raise IOError("Only py/yml/yaml/json type are supported now!") 144 | 145 | cfg_text = filepath + "\n" 146 | with open(filepath, "r") as f: 147 | cfg_text += f.read() 148 | 149 | if BASE_KEY in cfg_dict: # load configs in `BASE_KEY` 150 | cfg_dir = osp.dirname(filepath) 151 | base_filename = cfg_dict.pop(BASE_KEY) 152 | base_filename = ( 153 | base_filename if isinstance(base_filename, list) else [base_filename] 154 | ) 155 | 156 | cfg_dict_list = list() 157 | for f in base_filename: 158 | _cfg_dict = Config.from_file(osp.join(cfg_dir, f)) 159 | cfg_dict_list.append(_cfg_dict) 160 | 161 | base_cfg_dict = dict() 162 | for c in cfg_dict_list: 163 | if len(base_cfg_dict.keys() & c.keys()) > 0: 164 | raise KeyError("Duplicate key is not allowed among bases") 165 | base_cfg_dict.update(c) 166 | 167 | cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict) 168 | 169 | return EasyDict(cfg_dict) 170 | 171 | @classmethod 172 | def merge_list(cls, cfg, opts: list): 173 | """merge commandline opts. 174 | 175 | Args: 176 | cfg: (dict): The config to be merged. 177 | opts (list): The list to merge. Format: [key1, name1, key2, name2,...]. 178 | The keys can be nested. For example, ["a.b", v] will be considered 179 | as `dict(a=dict(b=v))`. 180 | 181 | Returns: dict. 182 | 183 | """ 184 | assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}" 185 | for i in range(0, len(opts), 2): 186 | full_k, v = opts[i], opts[i + 1] 187 | keys = full_k.split(".") 188 | sub_d = cfg 189 | for i, k in enumerate(keys): 190 | if not hasattr(sub_d, k): 191 | raise ValueError(f"The key {k} not exist in the config. 
Full key:{full_k}") 192 | if i != len(keys) - 1: 193 | sub_d = sub_d[k] 194 | else: 195 | sub_d[k] = v 196 | return cfg 197 | 198 | 199 | def merge_a_into_b(a, b, inplace=False): 200 | """The values in a will override values in b. 201 | 202 | Args: 203 | a (dict): source dict. 204 | b (dict): target dict. 205 | 206 | Returns: dict. recursively merge dict a into dict b. 207 | 208 | """ 209 | if not inplace: 210 | b = deepcopy(b) 211 | for key in a: 212 | if key in b: 213 | if isinstance(a[key], dict) and isinstance(b[key], dict): 214 | b[key] = merge_a_into_b(a[key], b[key], inplace=True) 215 | else: 216 | b[key] = a[key] 217 | else: 218 | b[key] = a[key] 219 | return b 220 | 221 | 222 | def eval_dict_leaf(d, orig_dict=None): 223 | """eval values of dict leaf. 224 | 225 | Args: 226 | d (dict): The dict to eval. 227 | 228 | Returns: dict. 229 | 230 | """ 231 | if orig_dict is None: 232 | orig_dict = d 233 | for k, v in d.items(): 234 | if not isinstance(v, dict): 235 | d[k] = eval_string(v, orig_dict) 236 | else: 237 | eval_dict_leaf(v, orig_dict) 238 | return d 239 | 240 | 241 | def eval_string(string, d): 242 | """automatically evaluate string to corresponding types. 243 | 244 | For example: 245 | not a string -> return the original input 246 | '0' -> 0 247 | '0.2' -> 0.2 248 | '[0, 1, 2]' -> [0,1,2] 249 | 'eval(1+2)' -> 3 250 | 'eval(range(5))' -> [0,1,2,3,4] 251 | '${a}' -> d.a 252 | 253 | 254 | 255 | Args: 256 | string (str): The value to evaluate. 257 | d (dict): The 258 | 259 | Returns: the corresponding type 260 | 261 | """ 262 | if not isinstance(string, str): 263 | return string 264 | # if len(string) > 1 and string[0] == "[" and string[-1] == "]": 265 | # return eval(string) 266 | if string[0:5] == "eval(": 267 | return eval(string[5:-1]) 268 | 269 | s0 = string 270 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0) 271 | if s1 != s0: 272 | while s1 != s0: 273 | s0 = s1 274 | s1 = re.sub(r"\${(.*)}", r"d.\1", s0) 275 | return eval(s1) 276 | 277 | try: 278 | v = ast.literal_eval(string) 279 | except: 280 | v = string 281 | return v 282 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from os.path import dirname, join 5 | 6 | from utils.config import Config 7 | from utils.distributed import init_distributed_mode, is_main_process 8 | from utils.logger import setup_logger 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def setup_config(): 14 | """Conbine yaml config and command line config with OmegaConf. 
15 | Also converts types, e.g., `'None'` (str) --> `None` (None) 16 | """ 17 | config = Config.get_config() 18 | if config.debug: 19 | config.wandb.enable = False 20 | return config 21 | 22 | 23 | def setup_evaluate_config(config): 24 | """setup evaluation default settings, e.g., disable wandb""" 25 | assert config.evaluate 26 | config.wandb.enable = False 27 | if config.output_dir is None: 28 | config.output_dir = join(dirname(config.pretrained_path), "eval") 29 | return config 30 | 31 | 32 | def setup_output_dir(output_dir, excludes=["code"]): 33 | """ensure not overwriting an existing/non-empty output dir""" 34 | if not os.path.exists(output_dir): 35 | os.makedirs(output_dir, exist_ok=False) 36 | else: 37 | existing_dirs_files = os.listdir(output_dir) # list 38 | remaining = set(existing_dirs_files) - set(excludes) 39 | remaining = [e for e in remaining if "slurm" not in e] 40 | remaining = [e for e in remaining if ".out" not in e] 41 | # assert len(remaining) == 0, f"remaining dirs or files: {remaining}" 42 | logger.warn(f"remaining dirs or files: {remaining}") 43 | 44 | 45 | def setup_main(): 46 | """ 47 | Setup config, logger, output_dir, etc. 48 | Shared for pretrain and all downstream tasks. 49 | """ 50 | config = setup_config() 51 | if hasattr(config, "evaluate") and config.evaluate: 52 | config = setup_evaluate_config(config) 53 | init_distributed_mode(config) 54 | 55 | if is_main_process(): 56 | setup_output_dir(config.output_dir, excludes=["code"]) 57 | setup_logger(output=config.output_dir, color=True, name="vindlu") 58 | logger.info(f"config: {Config.pretty_text(config)}") 59 | Config.dump(config, os.path.join(config.output_dir, "config.json")) 60 | return config 61 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import logging 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def setup_for_distributed(is_master): 11 | import warnings 12 | 13 | builtin_warn = warnings.warn 14 | 15 | def warn(*args, **kwargs): 16 | force = kwargs.pop("force", False) 17 | if is_master or force: 18 | builtin_warn(*args, **kwargs) 19 | 20 | # Log warnings only once 21 | warnings.warn = warn 22 | warnings.simplefilter("once", UserWarning) 23 | 24 | if not is_master: 25 | logging.disable() 26 | 27 | 28 | def is_dist_avail_and_initialized(): 29 | if not dist.is_available(): 30 | return False 31 | if not dist.is_initialized(): 32 | return False 33 | return True 34 | 35 | 36 | def get_world_size(): 37 | if not is_dist_avail_and_initialized(): 38 | return 1 39 | return dist.get_world_size() 40 | 41 | 42 | def get_rank(): 43 | if not is_dist_avail_and_initialized(): 44 | return 0 45 | return dist.get_rank() 46 | 47 | 48 | def is_main_process(): 49 | return get_rank() == 0 50 | 51 | 52 | def save_on_master(*args, **kwargs): 53 | if is_main_process(): 54 | torch.save(*args, **kwargs) 55 | 56 | 57 | def is_port_in_use(port): 58 | import socket 59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 60 | return s.connect_ex(('localhost', port)) == 0 61 | 62 | 63 | def init_distributed_mode(args): 64 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 65 | # job started by torch.distributed.launch 66 | args.rank = int(os.environ["RANK"]) 67 | args.world_size = int(os.environ['WORLD_SIZE']) 68 | args.gpu = int(os.environ['LOCAL_RANK']) 69 | elif 'SLURM_PROCID' in 
os.environ: 70 | # local rank on the current node / global rank 71 | local_rank = int(os.environ['SLURM_LOCALID']) 72 | global_rank = int(os.environ['SLURM_PROCID']) 73 | # number of processes / GPUs per node 74 | world_size = int(os.environ["SLURM_NNODES"]) * \ 75 | int(os.environ["SLURM_TASKS_PER_NODE"][0]) 76 | 77 | print(world_size) 78 | 79 | args.rank = global_rank 80 | args.gpu = local_rank 81 | args.world_size = world_size 82 | else: 83 | logger.info('Not using distributed mode') 84 | args.distributed = False 85 | return 86 | 87 | args.distributed = True 88 | 89 | torch.cuda.set_device(args.gpu) 90 | args.dist_backend = 'nccl' 91 | 92 | if "tcp" in args.dist_url: # in slurm, multiple program runs in a single node 93 | dist_port = int(args.dist_url.split(":")[-1]) 94 | while is_port_in_use(dist_port): 95 | dist_port += 10 96 | args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)]) 97 | print(args.dist_url) 98 | 99 | logger.info('| distributed init (rank {}): {}'.format( 100 | args.rank, args.dist_url)) 101 | if "SLURM_JOB_ID" in os.environ: 102 | logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}") 103 | torch.distributed.init_process_group( 104 | backend=args.dist_backend, init_method=args.dist_url, 105 | world_size=args.world_size, rank=args.rank) 106 | torch.distributed.barrier() 107 | setup_for_distributed(args.rank == 0) 108 | 109 | 110 | # Copyright (c) Facebook, Inc. and its affiliates. 111 | # copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py 112 | class GatherLayer(torch.autograd.Function): 113 | """ 114 | Gather tensors from all workers with support for backward propagation: 115 | This implementation does not cut the gradients as torch.distributed.all_gather does. 116 | """ 117 | 118 | @staticmethod 119 | def forward(ctx, x): 120 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 121 | dist.all_gather(output, x) 122 | return tuple(output) 123 | 124 | @staticmethod 125 | def backward(ctx, *grads): 126 | all_gradients = torch.stack(grads) 127 | dist.all_reduce(all_gradients) 128 | return all_gradients[dist.get_rank()] 129 | 130 | 131 | # copied from megavlt 132 | def gather_tensor_along_batch_with_backward(tensor, dim=0): 133 | world_size = get_world_size() 134 | 135 | if world_size < 2: 136 | return tensor 137 | 138 | tensor_list = GatherLayer.apply(tensor) 139 | tensor_list = torch.cat(tensor_list, dim=dim) 140 | return tensor_list 141 | 142 | 143 | @torch.no_grad() 144 | def gather_tensor_along_batch(tensor, dim=0): 145 | """ 146 | Performs all_gather operation on the provided tensors. 147 | *** Warning ***: torch.distributed.all_gather has no gradient. 148 | """ 149 | world_size = get_world_size() 150 | 151 | if world_size < 2: 152 | return tensor 153 | 154 | with torch.no_grad(): 155 | tensor_list = [] 156 | 157 | for _ in range(world_size): 158 | tensor_list.append(torch.zeros_like(tensor)) 159 | 160 | dist.all_gather(tensor_list, tensor) 161 | tensor_list = torch.cat(tensor_list, dim=dim) 162 | return tensor_list 163 | -------------------------------------------------------------------------------- /utils/easydict.py: -------------------------------------------------------------------------------- 1 | class EasyDict(dict): 2 | """ 3 | Get attributes 4 | 5 | >>> d = EasyDict({'foo':3}) 6 | >>> d['foo'] 7 | 3 8 | >>> d.foo 9 | 3 10 | >>> d.bar 11 | Traceback (most recent call last): 12 | ... 
13 | AttributeError: 'EasyDict' object has no attribute 'bar' 14 | 15 | Works recursively 16 | 17 | >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) 18 | >>> isinstance(d.bar, dict) 19 | True 20 | >>> d.bar.x 21 | 1 22 | 23 | Bullet-proof 24 | 25 | >>> EasyDict({}) 26 | {} 27 | >>> EasyDict(d={}) 28 | {} 29 | >>> EasyDict(None) 30 | {} 31 | >>> d = {'a': 1} 32 | >>> EasyDict(**d) 33 | {'a': 1} 34 | 35 | Set attributes 36 | 37 | >>> d = EasyDict() 38 | >>> d.foo = 3 39 | >>> d.foo 40 | 3 41 | >>> d.bar = {'prop': 'value'} 42 | >>> d.bar.prop 43 | 'value' 44 | >>> d 45 | {'foo': 3, 'bar': {'prop': 'value'}} 46 | >>> d.bar.prop = 'newer' 47 | >>> d.bar.prop 48 | 'newer' 49 | 50 | 51 | Values extraction 52 | 53 | >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) 54 | >>> isinstance(d.bar, list) 55 | True 56 | >>> from operator import attrgetter 57 | >>> map(attrgetter('x'), d.bar) 58 | [1, 3] 59 | >>> map(attrgetter('y'), d.bar) 60 | [2, 4] 61 | >>> d = EasyDict() 62 | >>> d.keys() 63 | [] 64 | >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) 65 | >>> d.foo 66 | 3 67 | >>> d.bar.x 68 | 1 69 | 70 | Still like a dict though 71 | 72 | >>> o = EasyDict({'clean':True}) 73 | >>> o.items() 74 | [('clean', True)] 75 | 76 | And like a class 77 | 78 | >>> class Flower(EasyDict): 79 | ... power = 1 80 | ... 81 | >>> f = Flower() 82 | >>> f.power 83 | 1 84 | >>> f = Flower({'height': 12}) 85 | >>> f.height 86 | 12 87 | >>> f['power'] 88 | 1 89 | >>> sorted(f.keys()) 90 | ['height', 'power'] 91 | 92 | update and pop items 93 | >>> d = EasyDict(a=1, b='2') 94 | >>> e = EasyDict(c=3.0, a=9.0) 95 | >>> d.update(e) 96 | >>> d.c 97 | 3.0 98 | >>> d['c'] 99 | 3.0 100 | >>> d.get('c') 101 | 3.0 102 | >>> d.update(a=4, b=4) 103 | >>> d.b 104 | 4 105 | >>> d.pop('a') 106 | 4 107 | >>> d.a 108 | Traceback (most recent call last): 109 | ... 110 | AttributeError: 'EasyDict' object has no attribute 'a' 111 | """ 112 | 113 | def __init__(self, d=None, **kwargs): 114 | if d is None: 115 | d = {} 116 | if kwargs: 117 | d.update(**kwargs) 118 | for k, v in d.items(): 119 | setattr(self, k, v) 120 | # Class attributes 121 | for k in self.__class__.__dict__.keys(): 122 | if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"): 123 | setattr(self, k, getattr(self, k)) 124 | 125 | def __setattr__(self, name, value): 126 | if isinstance(value, (list, tuple)): 127 | value = [self.__class__(x) if isinstance(x, dict) else x for x in value] 128 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 129 | value = self.__class__(value) 130 | super(EasyDict, self).__setattr__(name, value) 131 | super(EasyDict, self).__setitem__(name, value) 132 | 133 | __setitem__ = __setattr__ 134 | 135 | def update(self, e=None, **f): 136 | d = e or dict() 137 | d.update(f) 138 | for k in d: 139 | setattr(self, k, d[k]) 140 | 141 | def pop(self, k, d=None): 142 | if hasattr(self, k): 143 | delattr(self, k) 144 | return super(EasyDict, self).pop(k, d) 145 | 146 | 147 | if __name__ == "__main__": 148 | import doctest 149 | 150 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | 4 | import functools 5 | import logging 6 | import os 7 | import sys 8 | import time 9 | import wandb 10 | from typing import Any, Dict, Union 11 | 12 | import torch 13 | from .distributed import get_rank, is_main_process 14 | from termcolor import colored 15 | 16 | 17 | def log_dict_to_wandb(log_dict, step, prefix=""): 18 | """include a separator `/` at the end of `prefix`""" 19 | if not is_main_process(): 20 | return 21 | 22 | log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()} 23 | wandb.log(log_dict, step) 24 | 25 | 26 | def setup_wandb(config): 27 | if not (config.wandb.enable and is_main_process()): 28 | return 29 | 30 | run = wandb.init( 31 | config=config, 32 | project=config.wandb.project, 33 | entity=config.wandb.entity, 34 | name=os.path.basename(config.output_dir), 35 | reinit=True 36 | ) 37 | return run 38 | 39 | 40 | def setup_output_folder(save_dir: str, folder_only: bool = False): 41 | """Sets up and returns the output file where the logs will be placed 42 | based on the configuration passed. Usually "save_dir/logs/log_.txt". 43 | If env.log_dir is passed, logs will be directly saved in this folder. 44 | Args: 45 | folder_only (bool, optional): If folder should be returned and not the file. 46 | Defaults to False. 47 | Returns: 48 | str: folder or file path depending on folder_only flag 49 | """ 50 | log_filename = "train_" 51 | log_filename += time.strftime("%Y_%m_%dT%H_%M_%S") 52 | log_filename += ".log" 53 | 54 | log_folder = os.path.join(save_dir, "logs") 55 | 56 | if not os.path.exists(log_folder): 57 | os.makedirs(log_folder) 58 | 59 | if folder_only: 60 | return log_folder 61 | 62 | log_filename = os.path.join(log_folder, log_filename) 63 | 64 | return log_filename 65 | 66 | 67 | def setup_logger( 68 | output: str = None, 69 | color: bool = True, 70 | name: str = "mmf", 71 | disable: bool = False, 72 | clear_handlers=True, 73 | *args, 74 | **kwargs, 75 | ): 76 | """ 77 | Initialize the MMF logger and set its verbosity level to "INFO". 78 | Outside libraries shouldn't call this in case they have set their 79 | own logging handlers and setup. If they do, and don't want to 80 | clear handlers, pass the clear_handlers option. 81 | The initial version of this function was taken from D2 and adapted 82 | for MMF. 83 | Args: 84 | output (str): a file name or a directory to save log. 85 | If ends with ".txt" or ".log", assumed to be a file name. 86 | Default: Saved to file 87 | color (bool): If false, won't log colored logs. Default: true 88 | name (str): the root module name of this logger. Defaults to "mmf". 89 | disable: do not use 90 | clear_handlers (bool): If false, won't clear existing handlers. 
91 | Returns: 92 | logging.Logger: a logger 93 | """ 94 | if disable: 95 | return None 96 | logger = logging.getLogger(name) 97 | logger.propagate = False 98 | 99 | logging.captureWarnings(True) 100 | warnings_logger = logging.getLogger("py.warnings") 101 | 102 | plain_formatter = logging.Formatter( 103 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s", 104 | datefmt="%Y-%m-%dT%H:%M:%S", 105 | ) 106 | 107 | distributed_rank = get_rank() 108 | handlers = [] 109 | 110 | logging_level = logging.INFO 111 | # logging_level = logging.DEBUG 112 | 113 | if distributed_rank == 0: 114 | logger.setLevel(logging_level) 115 | ch = logging.StreamHandler(stream=sys.stdout) 116 | ch.setLevel(logging_level) 117 | if color: 118 | formatter = ColorfulFormatter( 119 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", 120 | datefmt="%Y-%m-%dT%H:%M:%S", 121 | ) 122 | else: 123 | formatter = plain_formatter 124 | ch.setFormatter(formatter) 125 | logger.addHandler(ch) 126 | warnings_logger.addHandler(ch) 127 | handlers.append(ch) 128 | 129 | # file logging: all workers 130 | if output is None: 131 | output = setup_output_folder() 132 | 133 | if output is not None: 134 | if output.endswith(".txt") or output.endswith(".log"): 135 | filename = output 136 | else: 137 | filename = os.path.join(output, "train.log") 138 | if distributed_rank > 0: 139 | filename = filename + f".rank{distributed_rank}" 140 | os.makedirs(os.path.dirname(filename), exist_ok=True) 141 | 142 | fh = logging.StreamHandler(_cached_log_stream(filename)) 143 | fh.setLevel(logging_level) 144 | fh.setFormatter(plain_formatter) 145 | logger.addHandler(fh) 146 | warnings_logger.addHandler(fh) 147 | handlers.append(fh) 148 | 149 | # Slurm/FB output, only log the main process 150 | # save_dir = get_mmf_env(key="save_dir") 151 | if "train.log" not in filename and distributed_rank == 0: 152 | filename = os.path.join(output, "train.log") 153 | sh = logging.StreamHandler(_cached_log_stream(filename)) 154 | sh.setLevel(logging_level) 155 | sh.setFormatter(plain_formatter) 156 | logger.addHandler(sh) 157 | warnings_logger.addHandler(sh) 158 | handlers.append(sh) 159 | 160 | logger.info(f"Logging to: {filename}") 161 | 162 | # Remove existing handlers to add MMF specific handlers 163 | if clear_handlers: 164 | for handler in logging.root.handlers[:]: 165 | logging.root.removeHandler(handler) 166 | # Now, add our handlers. 167 | logging.basicConfig(level=logging_level, handlers=handlers) 168 | 169 | return logger 170 | 171 | 172 | def setup_very_basic_config(color=True): 173 | plain_formatter = logging.Formatter( 174 | "%(asctime)s | %(levelname)s | %(name)s : %(message)s", 175 | datefmt="%Y-%m-%dT%H:%M:%S", 176 | ) 177 | ch = logging.StreamHandler(stream=sys.stdout) 178 | ch.setLevel(logging.INFO) 179 | if color: 180 | formatter = ColorfulFormatter( 181 | colored("%(asctime)s | %(name)s: ", "green") + "%(message)s", 182 | datefmt="%Y-%m-%dT%H:%M:%S", 183 | ) 184 | else: 185 | formatter = plain_formatter 186 | ch.setFormatter(formatter) 187 | # Setup a minimal configuration for logging in case something tries to 188 | # log a message even before logging is setup by MMF. 189 | logging.basicConfig(level=logging.INFO, handlers=[ch]) 190 | 191 | 192 | # cache the opened file object, so that different calls to `setup_logger` 193 | # with the same file name can safely write to the same file. 
194 | @functools.lru_cache(maxsize=None) 195 | def _cached_log_stream(filename): 196 | return open(filename, "a") 197 | 198 | 199 | # ColorfulFormatter is adopted from Detectron2 and adapted for MMF 200 | class ColorfulFormatter(logging.Formatter): 201 | def __init__(self, *args, **kwargs): 202 | super().__init__(*args, **kwargs) 203 | 204 | def formatMessage(self, record): 205 | log = super().formatMessage(record) 206 | if record.levelno == logging.WARNING: 207 | prefix = colored("WARNING", "red", attrs=["blink"]) 208 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 209 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 210 | else: 211 | return log 212 | return prefix + " " + log 213 | 214 | 215 | class TensorboardLogger: 216 | def __init__(self, log_folder="./logs", iteration=0): 217 | # This would handle warning of missing tensorboard 218 | from torch.utils.tensorboard import SummaryWriter 219 | 220 | self.summary_writer = None 221 | self._is_master = is_main_process() 222 | # self.timer = Timer() 223 | self.log_folder = log_folder 224 | 225 | if self._is_master: 226 | # current_time = self.timer.get_time_hhmmss(None, format=self.time_format) 227 | current_time = time.strftime("%Y-%m-%dT%H:%M:%S") 228 | # self.timer.get_time_hhmmss(None, format=self.time_format) 229 | tensorboard_folder = os.path.join( 230 | self.log_folder, f"tensorboard_{current_time}" 231 | ) 232 | self.summary_writer = SummaryWriter(tensorboard_folder) 233 | 234 | def __del__(self): 235 | if getattr(self, "summary_writer", None) is not None: 236 | self.summary_writer.close() 237 | 238 | def _should_log_tensorboard(self): 239 | if self.summary_writer is None or not self._is_master: 240 | return False 241 | else: 242 | return True 243 | 244 | def add_scalar(self, key, value, iteration): 245 | if not self._should_log_tensorboard(): 246 | return 247 | 248 | self.summary_writer.add_scalar(key, value, iteration) 249 | 250 | def add_scalars(self, scalar_dict, iteration): 251 | if not self._should_log_tensorboard(): 252 | return 253 | 254 | for key, val in scalar_dict.items(): 255 | self.summary_writer.add_scalar(key, val, iteration) 256 | 257 | def add_histogram_for_model(self, model, iteration): 258 | if not self._should_log_tensorboard(): 259 | return 260 | 261 | for name, param in model.named_parameters(): 262 | np_param = param.clone().cpu().data.numpy() 263 | self.summary_writer.add_histogram(name, np_param, iteration) 264 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | """ Optimizer Factory w/ Custom Weight Decay 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import re 5 | import torch 6 | from torch import optim as optim 7 | from utils.distributed import is_main_process 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | try: 11 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 12 | has_apex = True 13 | except ImportError: 14 | has_apex = False 15 | 16 | 17 | def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True): 18 | named_param_tuples = [] 19 | for name, param in model.named_parameters(): 20 | if not param.requires_grad: 21 | continue # frozen weights 22 | if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")): 23 | named_param_tuples.append([name, param, 0]) 24 | elif name in no_decay_list: 25 | 
named_param_tuples.append([name, param, 0]) 26 | else: 27 | named_param_tuples.append([name, param, weight_decay]) 28 | return named_param_tuples 29 | 30 | 31 | def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr): 32 | """use lr=diff_lr for modules named found in diff_lr_names, 33 | otherwise use lr=default_lr 34 | 35 | Args: 36 | named_param_tuples_or_model: List([name, param, weight_decay]), or nn.Module 37 | diff_lr_names: List(str) 38 | diff_lr: float 39 | default_lr: float 40 | Returns: 41 | named_param_tuples_with_lr: List([name, param, weight_decay, lr]) 42 | """ 43 | named_param_tuples_with_lr = [] 44 | logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}") 45 | for name, p, wd in named_param_tuples_or_model: 46 | use_diff_lr = False 47 | for diff_name in diff_lr_names: 48 | # if diff_name in name: 49 | if re.search(diff_name, name) is not None: 50 | logger.info(f"param {name} use different_lr: {diff_lr}") 51 | use_diff_lr = True 52 | break 53 | 54 | named_param_tuples_with_lr.append( 55 | [name, p, wd, diff_lr if use_diff_lr else default_lr] 56 | ) 57 | 58 | if is_main_process(): 59 | for name, _, wd, diff_lr in named_param_tuples_with_lr: 60 | logger.info(f"param {name}: wd: {wd}, lr: {diff_lr}") 61 | 62 | return named_param_tuples_with_lr 63 | 64 | 65 | def create_optimizer_params_group(named_param_tuples_with_lr): 66 | """named_param_tuples_with_lr: List([name, param, weight_decay, lr])""" 67 | group = {} 68 | for name, p, wd, lr in named_param_tuples_with_lr: 69 | if wd not in group: 70 | group[wd] = {} 71 | if lr not in group[wd]: 72 | group[wd][lr] = [] 73 | group[wd][lr].append(p) 74 | 75 | optimizer_params_group = [] 76 | for wd, lr_groups in group.items(): 77 | for lr, p in lr_groups.items(): 78 | optimizer_params_group.append(dict( 79 | params=p, 80 | weight_decay=wd, 81 | lr=lr 82 | )) 83 | logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}") 84 | return optimizer_params_group 85 | 86 | 87 | def create_optimizer(args, model, filter_bias_and_bn=True): 88 | opt_lower = args.opt.lower() 89 | weight_decay = args.weight_decay 90 | # check for modules that requires different lr 91 | if hasattr(args, "different_lr") and args.different_lr.enable: 92 | diff_lr_module_names = args.different_lr.module_names 93 | diff_lr = args.different_lr.lr 94 | else: 95 | diff_lr_module_names = [] 96 | diff_lr = None 97 | 98 | no_decay = {} 99 | if hasattr(model, 'no_weight_decay'): 100 | no_decay = model.no_weight_decay() 101 | named_param_tuples = add_weight_decay( 102 | model, weight_decay, no_decay, filter_bias_and_bn) 103 | named_param_tuples = add_different_lr( 104 | named_param_tuples, diff_lr_module_names, diff_lr, args.lr) 105 | parameters = create_optimizer_params_group(named_param_tuples) 106 | 107 | if 'fused' in opt_lower: 108 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 109 | 110 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 111 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 112 | opt_args['eps'] = args.opt_eps 113 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 114 | opt_args['betas'] = args.opt_betas 115 | if hasattr(args, 'opt_args') and args.opt_args is not None: 116 | opt_args.update(args.opt_args) 117 | 118 | opt_split = opt_lower.split('_') 119 | opt_lower = opt_split[-1] 120 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 121 | opt_args.pop('eps', None) 122 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, 
**opt_args) 123 | elif opt_lower == 'momentum': 124 | opt_args.pop('eps', None) 125 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 126 | elif opt_lower == 'adam': 127 | optimizer = optim.Adam(parameters, **opt_args) 128 | elif opt_lower == 'adamw': 129 | optimizer = optim.AdamW(parameters, **opt_args) 130 | else: 131 | assert False and "Invalid optimizer" 132 | raise ValueError 133 | return optimizer 134 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from torch.optim import Optimizer 5 | import math 6 | from torch.optim.lr_scheduler import LambdaLR 7 | 8 | 9 | def create_scheduler(args, optimizer): 10 | lr_scheduler = None 11 | if args.sched == 'cosine': 12 | lr_scheduler = get_cosine_schedule_with_warmup( 13 | optimizer, 14 | num_warmup_steps=args.num_warmup_steps, 15 | num_training_steps=args.num_training_steps, 16 | num_cycles=0.5, 17 | min_lr_multi=args.min_lr_multi 18 | ) 19 | return lr_scheduler 20 | 21 | 22 | def get_cosine_schedule_with_warmup( 23 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 24 | num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1 25 | ): 26 | """ 27 | Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py 28 | 29 | Create a schedule with a learning rate that decreases following the values of the cosine function between the 30 | initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the 31 | initial lr set in the optimizer. 32 | Args: 33 | optimizer ([`~torch.optim.Optimizer`]): 34 | The optimizer for which to schedule the learning rate. 35 | num_warmup_steps (`int`): 36 | The number of steps for the warmup phase. 37 | num_training_steps (`int`): 38 | The total number of training steps. 39 | num_cycles (`float`, *optional*, defaults to 0.5): 40 | The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 41 | following a half-cosine). 42 | min_lr_multi (`float`, *optional*, defaults to 0): 43 | The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi. 44 | last_epoch (`int`, *optional*, defaults to -1): 45 | The index of the last epoch when resuming training. 46 | Return: 47 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 48 | """ 49 | 50 | def lr_lambda(current_step): 51 | if current_step < num_warmup_steps: 52 | return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps))) 53 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 54 | return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 55 | 56 | return LambdaLR(optimizer, lr_lambda, last_epoch) 57 | --------------------------------------------------------------------------------
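For reference, below is a minimal usage sketch (not part of the repository) of how create_optimizer in utils/optimizer.py and create_scheduler in utils/scheduler.py might be wired together, assuming the repository root is on PYTHONPATH. The SimpleNamespace fields (opt, lr, weight_decay, momentum, sched, num_warmup_steps, num_training_steps, min_lr_multi) are hypothetical stand-ins for values that a training config loaded through utils/config.py would normally supply.

# Hypothetical usage sketch: drive the optimizer/scheduler factories with a toy model.
from types import SimpleNamespace

import torch
from utils.optimizer import create_optimizer
from utils.scheduler import create_scheduler

model = torch.nn.Linear(8, 2)  # stand-in for the actual model
args = SimpleNamespace(
    opt="adamw", lr=2e-5, weight_decay=0.02, momentum=0.9,
    sched="cosine", num_warmup_steps=100, num_training_steps=1000,
    min_lr_multi=0.01,  # floor: the lr multiplier never drops below 0.01
)

optimizer = create_optimizer(args, model)   # builds per-group wd/lr param groups
scheduler = create_scheduler(args, optimizer)  # cosine schedule with linear warmup

for step in range(args.num_training_steps):
    # loss.backward() would go here in a real training loop
    optimizer.step()
    scheduler.step()
    if step in (0, 99, 499, 999):
        print(step, scheduler.get_last_lr()[0])
# Expected shape: roughly linear warmup to 2e-5 over the first 100 steps,
# then a half-cosine decay toward 2e-5 * min_lr_multi by step 1000.

Because both factories only read attributes off the args object, anything exposing these fields works interchangeably: an argparse namespace, an EasyDict from utils/easydict.py, or the Config object produced by utils/config.py.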