├── .gitignore ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING-ARCHIVED.md ├── LICENSE ├── README.md ├── SECURITY.md ├── config_release ├── base_model.json ├── didemo_ret.json ├── msrvtt_qa.json ├── msrvtt_ret.json ├── msvd_qa.json ├── pretrain_alpro.json ├── pretrain_prompter.json ├── timesformer_divst_8x32_224_k600.json └── timesformer_divst_8x32_224_k600_gc.json ├── env ├── install_pkg.sh └── requirements.txt ├── pics └── teaser.jpg ├── run_scripts ├── clear_cuda_cache.sh ├── ft_didemo_ret.sh ├── ft_msrvtt_qa.sh ├── ft_msrvtt_ret.sh ├── ft_msvd_qa.sh ├── inf_didemo_ret.sh ├── inf_msrvtt_qa.sh ├── inf_msrvtt_ret.sh ├── inf_msvd_qa.sh ├── pt_alpro.sh └── pt_prompter.sh └── src ├── __init__.py ├── configs └── config.py ├── datasets ├── data_utils.py ├── dataloader.py ├── dataset_base.py ├── dataset_pretrain_sparse.py ├── dataset_video_qa.py ├── dataset_video_retrieval.py └── randaugment.py ├── modeling ├── alpro_models.py ├── timesformer │ ├── __init__.py │ ├── conv2d_same.py │ ├── features.py │ ├── helpers.py │ ├── linear.py │ ├── operators.py │ ├── vit.py │ └── vit_utils.py ├── transformers.py └── xbert.py ├── optimization ├── adamw.py ├── sched.py └── utils.py ├── pretrain ├── run_pretrain_contrastive_only.py └── run_pretrain_sparse.py ├── tasks ├── run_video_qa.py └── run_video_retrieval.py └── utils ├── basic_utils.py ├── distributed.py ├── grad_ckpt.py ├── load_save.py ├── logger.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | # script 4 | tmp_all/script/ 5 | 6 | # Philly-realted # 7 | pt/ 8 | .ptconfig 9 | 10 | 11 | 12 | # Project-related # 13 | */*results*/ 14 | *results*/ 15 | tmp*/ 16 | cache/* 17 | */cache*/ 18 | tmp*.py 19 | *pickle 20 | 21 | # compiled files # 22 | *.pyc 23 | 24 | # Packages # 25 | ############ 26 | # it's better to unpack these files and commit the raw source 27 | # git has its own built in compression methods 28 | *.7z 29 | *.dmg 30 | *.gz 31 | *.iso 32 | *.jar 33 | *.rar 34 | *.tar 35 | *.zip 36 | 37 | # Logs and databases # 38 | ###################### 39 | *.log 40 | *.sql 41 | *.sqlite 42 | .ipynb_checkpoints/ 43 | *.swp 44 | *.vscode/ 45 | *.idea/ 46 | 47 | # OS generated files # 48 | ###################### 49 | .DS_Store 50 | .DS_Store? 51 | ._* 52 | .Spotlight-V100 53 | .Trashes 54 | ehthumbs.db 55 | Thumbs.db 56 | 57 | # project-specific 58 | img 59 | txt 60 | ext 61 | data 62 | output 63 | src/configs_local 64 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 
8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. 
Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ 106 | 107 | -------------------------------------------------------------------------------- /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained; 4 | We are not accepting contributions or Pull Requests. 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, Salesforce.com, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ALPRO (CVPR 2022) 2 | 3 | ## ALPRO is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS), a one-stop library for language-vision intelligence! 4 | 5 | ## Align and Prompt: Video-and-Language Pre-training with Entity Prompts [[Paper](https://arxiv.org/abs/2112.09583)] 6 | 7 | [Dongxu Li](https://www.linkedin.com/in/dongxu-li-a8a035110/), [Junnan Li](https://sites.google.com/site/junnanlics), [Hongdong Li](http://users.cecs.anu.edu.au/~hongdong/), [Juan Carlos Niebles](http://www.niebles.net/), [Steven C.H. Hoi](https://sites.google.com/view/stevenhoi/home) 8 | 9 | 10 | 11 | Official PyTorch code for ALPRO. This repository supports pre-training as well as finetuning on 12 | - Text-Video Retrieval on MSRVTT and DiDeMo. 13 | - Video Question Answering on MSRVTT and MSVD. 14 | 15 | ## Requirements 16 | Our implementation is tested on Ubuntu 20.04.1 with NVIDIA A100 GPUs. Other platforms and hardware may work, but this is not guaranteed. To install the required packages: 17 | 18 | ```bash 19 | cd env && bash install_pkg.sh 20 | ``` 21 | 22 | ## Data Preparation 23 | 1. Download Annotations and Pre-trained Checkpoints 24 | - [Text annotations](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/data.zip) 25 | - [Checkpoints of pre-trained model and finetuned model](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/output.zip) 26 | - [External resources](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/ext.zip) 27 | - unzip `data.zip`, `output.zip`, `ext.zip` under `ALPRO/`. 28 | 29 | 2. Download raw videos of downstream datasets. 30 | - MSRVTT: 31 | - download train_val_videos.zip and test_videos.zip from e.g. [here](https://www.mediafire.com/folder/h14iarbs62e7p/shared). 32 | - check md5sum: 33 | 34 | ```bash 35 | 51f2394d279cf84f1642defd9a651e6f train_val_videos.zip 36 | 0af68454cec9d586e92805739f3911d0 test_videos.zip 37 | ``` 38 | - unzip all the videos into `data/msrvtt_ret/videos` (10k in total). 39 | - create the following soft link: 40 | 41 | ```bash 42 | ln -s data/msrvtt_ret/videos data/msrvtt_qa/videos``` 43 | - MSVD: 44 | - download from the official release: 45 | 46 | ```bash 47 | wget -nc https://www.cs.utexas.edu/users/ml/clamp/videoDescription/YouTubeClips.tar 48 | ``` 49 | - check md5sum: 50 | 51 | ```bash 52 | 9bdb20fcf14d59524a6febca9f6a8d89 YouTubeClips.tar 53 | ``` 54 | - unzip all the videos to `data/msvd_qa/videos` (1,970 videos in total). 
55 | 56 | ```bash 57 | mkdir data/msvd_qa/videos/ 58 | tar xvf YouTubeClips.tar -C data/msvd_qa/videos --strip-components=1 59 | ``` 60 | - DiDeMo: 61 | - Follow the [instructions](https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md) and download the videos from the official release [here](https://drive.google.com/drive/u/1/folders/1_oyJ5rQiZboipbMl6tkhY8v0s9zDkvJc); 62 | - unzip all the videos into `data/didemo_ret/videos`. 63 | - Note that a few videos might be missing; see [here](https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md#getting-the-videos) to download them. However, since they account for only a small portion of the training set, it is safe to ignore them. 64 | - Convert all the DiDeMo videos into `*.mp4` format using e.g. [`ffmpeg`](https://askubuntu.com/questions/396883/how-to-simply-convert-video-files-i-e-mkv-to-mp4). 65 | - We obtained 10,463 videos following these steps (with one video `77807177@N00_5753455690_1e04ccb364` missing). 66 | 67 | 68 | 69 | 3. The directory is expected to be in the structure below: 70 | ```bash 71 | . 72 | |-config_release # configuration files 73 | |-data # text annotations and raw videos 74 | |---didemo_ret 75 | |-----txt 76 | |-----videos 77 | |---msrvtt_qa/... 78 | |---msrvtt_ret/... 79 | |---msvd_qa/... 80 | |-env # scripts to install packages 81 | |-ext # external resources, e.g. bert tokenizer 82 | |-output # checkpoints for pre-trained/finetuned models 83 | |---downstreams 84 | |-----didemo_ret 85 | |-------public 86 | |---------ckpt # official finetuned checkpoints 87 | |---------log # inference log 88 | |---------results_test 89 | |-----------step_best_1_mean 90 | |-----msrvtt_qa/... 91 | |-----msrvtt_ret/... 92 | |-----msvd_qa/... 93 | |-run_scripts # bash scripts to launch experiments 94 | |-src # source code 95 | ``` 96 | 97 | ## Inference with Official Checkpoints 98 | 99 | ```bash 100 | cd run_scripts 101 | bash inf_msrvtt_ret.sh 102 | # {'text2video': {'r1': 33.9, 'r5': 60.7, 'r10': 73.2, 'medianR': 3.0, 'meanR': 27.404}} 103 | bash inf_didemo_ret.sh 104 | # {'text2video': {'r1': 35.9, 'r5': 67.5, 'r10': 78.8, 'medianR': 3.0, 'meanR': 19.125}} 105 | bash inf_msrvtt_qa.sh 106 | # {'ratios': {'what_ratio': [68.48, 49872], 'who_ratio': [27.99, 20385], 'how_ratio': [2.25, 1640], 'where_ratio': [0.34, 250], 'when_ratio': [0.93, 677]}, 'overall_acc': 42.12, 'what_acc': 36.05, 'who_acc': 52.24, 'how_acc': 85.67, 'where_acc': 42.8, 'when_acc': 78.88} 107 | bash inf_msvd_qa.sh 108 | # {'ratios': {'what_ratio': [61.93, 8150], 'who_ratio': [34.6, 4554], 'how_ratio': [2.81, 370], 'where_ratio': [0.21, 28], 'when_ratio': [0.44, 58]}, 'overall_acc': 45.91, 'what_acc': 37.02, 'who_acc': 58.59, 'how_acc': 81.62, 'where_acc': 46.43, 'when_acc': 72.41} 109 | ``` 110 | 111 | 112 | ## Downstream Task Finetuning 113 | - To finetune on downstream tasks with the pre-trained checkpoint `output/pretrain/alpro_pretrained_ckpt.pt`: 114 | 115 | ```bash 116 | cd run_scripts 117 | bash ft_msrvtt_ret.sh 118 | bash ft_didemo_ret.sh 119 | bash ft_msrvtt_qa.sh 120 | bash ft_msvd_qa.sh 121 | ``` 122 | 123 | For example, with MSRVTT retrieval: 124 | ```bash 125 | cd ALPRO/ 126 | 127 | export PYTHONPATH="$PYTHONPATH:$PWD" 128 | echo $PYTHONPATH 129 | 130 | CONFIG_PATH='config_release/msrvtt_ret.json' 131 | 132 | horovodrun -np 8 python src/tasks/run_video_retrieval.py \ # change -np to the number of GPUs. 
133 | --config $CONFIG_PATH \ 134 | --output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_ret/$(date '+%Y%m%d%H%M%S') # change to your local path to store finetuning ckpts and logs 135 | ``` 136 | - Run inference with locally-finetuned checkpoints. 137 | ```bash 138 | cd ALPRO/ 139 | 140 | export PYTHONPATH="$PYTHONPATH:$PWD" 141 | echo $PYTHONPATH 142 | 143 | STEP='best' 144 | 145 | CONFIG_PATH='config_release/msrvtt_ret.json' 146 | OUTPUT_DIR='[INPUT_YOUR_OUTPUT_PATH_HERE]' 147 | 148 | TXT_DB='data/msrvtt_ret/txt/test.jsonl' 149 | IMG_DB='data/msrvtt_ret/videos' 150 | 151 | horovodrun -np 8 python src/tasks/run_video_retrieval.py \ 152 | --do_inference 1 \ 153 | --inference_split test \ 154 | --inference_model_step $STEP \ 155 | --inference_txt_db $TXT_DB \ 156 | --inference_img_db $IMG_DB \ 157 | --inference_batch_size 64 \ 158 | --output_dir $OUTPUT_DIR \ 159 | --config $CONFIG_PATH 160 | ``` 161 | - `OUTPUT_DIR` is the path given to the `--output_dir` option in the finetuning script. 162 | - `$STEP` is a string that tells the script to use the checkpoint `$OUTPUT_DIR/ckpt/model_step_$STEP.pt` for inference. 163 | 164 | 165 | ## Pretraining 166 | 1. Download [WebVid2M](https://github.com/m-bain/frozen-in-time) and [CC-3M](https://github.com/igorbrigadir/DownloadConceptualCaptions). 167 | 168 | - Put WebVid2M videos under `data/webvid2m`; 169 | - 💡 we downsample WebVid2M videos to 10% of the original FPS to speed up video loading; 170 | - update `data/cc3m/txt/cc3m.json` with your local image paths. 171 | 172 | 2. Train the prompter: 173 | ```bash 174 | cd run_scripts && bash pt_prompter.sh 175 | ``` 176 | 177 | 3. Train the video-language model: 178 | ```bash 179 | cd run_scripts && bash pt_alpro.sh 180 | ``` 181 | If you would like to use custom prompter weights, please change `teacher_weights_path` in `config_release/pretrain_alpro.json`. 182 | 4. To finetune with pre-trained checkpoints, please change `e2e_weights_path` in the finetuning config files, e.g. `config_release/msrvtt_ret.json`. 183 | 184 | 185 | ## Citation 186 | 187 | If you find ALPRO useful for your research, please consider citing: 188 | ```bibtex 189 | @inproceedings{li2021align, 190 | title={Align and Prompt: Video-and-Language Pre-training with Entity Prompts}, 191 | author={Dongxu Li and Junnan Li and Hongdong Li and Juan Carlos Niebles and Steven C.H. Hoi}, 192 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 193 | year={2022} 194 | } 195 | ``` 196 | 197 | ## Acknowledgement 198 | We thank members at Salesforce Research for their helpful discussions. 199 | 200 | The implementation of ALPRO relies on resources from [ClipBERT](https://github.com/jayleicn/ClipBERT), 201 | [transformers](https://github.com/huggingface/transformers), and 202 | [TimeSformer](https://github.com/facebookresearch/TimeSformer/tree/main/timesformer/models). 203 | The code is implemented using [PyTorch](https://github.com/pytorch/pytorch), 204 | with multi-GPU support from [Horovod](https://github.com/horovod/horovod) and [gradient-checkpoint](https://github.com/csrhddlam/pytorch-checkpoint). We thank the original authors for open-sourcing their work and encourage ALPRO users to cite these works when applicable. 
205 | 206 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as can be, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. 8 | -------------------------------------------------------------------------------- /config_release/base_model.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "layer_norm_eps": 1e-12, 9 | "max_position_embeddings": 512, 10 | "model_type": "bert", 11 | "num_attention_heads": 12, 12 | "num_hidden_layers": 12, 13 | "pad_token_id": 0, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522, 16 | "fusion_layer": 6, 17 | "encoder_width": 768, 18 | "itc_token_type": "cls" 19 | } 20 | -------------------------------------------------------------------------------- /config_release/didemo_ret.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_datasets": [ 3 | { 4 | "name": "didemo", 5 | "txt": "data/didemo_ret/txt/train.jsonl", 6 | "img": "data/didemo_ret/videos" 7 | } 8 | ], 9 | "val_datasets": [ 10 | { 11 | "name": "didemo_retrieval", 12 | "txt": "data/didemo_ret/txt/val.jsonl", 13 | "img": "data/didemo_ret/videos" 14 | } 15 | ], 16 | "max_txt_len": 50, 17 | "crop_img_size": 224, 18 | "resize_size": 256, 19 | "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 20 | "img_pixel_std": [0.26862954, 0.26130258, 0.27577711], 21 | "img_input_format": "RGB", 22 | "num_frm": 8, 23 | "train_n_clips": 1, 24 | "max_n_example_per_group": 1, 25 | "model_config": "config_release/base_model.json", 26 | "tokenizer_dir": "ext/bert-base-uncased/", 27 | "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json", 28 | "e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt", 29 | "bert_weights_path": null, 30 | "train_batch_size": 12, 31 | "val_batch_size": 12, 32 | "gradient_accumulation_steps": 1, 33 | "num_train_epochs": 10, 34 | "min_valid_steps": 20, 35 | "num_valid": 20, 36 | "learning_rate": 4e-5, 37 | "weight_decay": 1e-3, 38 | "decay": "linear", 39 | "optim": "adamw", 40 | "betas": [0.9, 0.98], 41 | "dropout": 0.1, 42 | "grad_norm": 20.0, 43 | "seed":42, 44 | "fp16": 0, 45 | "num_workers": 4 46 | } 47 | -------------------------------------------------------------------------------- /config_release/msrvtt_qa.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_datasets": [ 3 | { 4 | "name": "msrvtt_qa", 5 | "txt": { 6 | "msrvtt_qa": "data/msrvtt_qa/txt/train.jsonl" 7 | }, 8 | "img": "data/msrvtt_qa/videos" 9 | } 10 | ], 11 | "val_datasets": [ 12 | { 13 | "name": "msrvtt_qa", 14 | "txt": { 15 | "msrvtt_qa": "data/msrvtt_qa/txt/val.jsonl" 16 | }, 17 | "img": "data/msrvtt_qa/videos" 18 | } 19 | ], 20 | "ans2label_path": "data/msrvtt_qa/txt/train_ans2label.json", 21 | "max_txt_len": 40, 22 | "crop_img_size": 224, 23 | "resize_size": 256, 24 | 
"img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 25 | "img_pixel_std": [0.26862954, 0.26130258, 0.27577711], 26 | "img_input_format": "RGB", 27 | "train_n_clips": 1, 28 | "model_config": "config_release/base_model.json", 29 | "tokenizer_dir": "ext/bert-base-uncased/", 30 | "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600_gc.json", 31 | "e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt", 32 | "num_frm": 16, 33 | "train_batch_size": 12, 34 | "val_batch_size": 12, 35 | "gradient_accumulation_steps": 2, 36 | "num_train_epochs": 10, 37 | "min_valid_steps": 50, 38 | "num_valid": 50, 39 | "learning_rate": 5e-5, 40 | "weight_decay": 1e-3, 41 | "decay": "linear", 42 | "optim": "adamw", 43 | "betas": [0.9, 0.98], 44 | "dropout": 0.1, 45 | "grad_norm": 5.0, 46 | "cnn_lr_decay": "linear", 47 | "seed":42, 48 | "fp16": 0, 49 | "classifier": "mlp", 50 | "cls_hidden_scale": 2, 51 | "task": "msrvtt_qa", 52 | "num_workers": 4 53 | } 54 | -------------------------------------------------------------------------------- /config_release/msrvtt_ret.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_datasets": [ 3 | { 4 | "name": "msrvtt", 5 | "txt": "data/msrvtt_ret/txt/train.jsonl", 6 | "img": "data/msrvtt_ret/videos" 7 | } 8 | ], 9 | "val_datasets": [ 10 | { 11 | "name": "msrvtt_retrieval", 12 | "txt": "data/msrvtt_ret/txt/val.jsonl", 13 | "img": "data/msrvtt_ret/videos" 14 | } 15 | ], 16 | "max_txt_len": 40, 17 | "crop_img_size": 224, 18 | "resize_size": 256, 19 | "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 20 | "img_pixel_std": [0.26862954, 0.26130258, 0.27577711], 21 | "img_input_format": "RGB", 22 | "train_n_clips": 1, 23 | "model_config": "config_release/base_model.json", 24 | "tokenizer_dir": "ext/bert-base-uncased/", 25 | "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json", 26 | "e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt", 27 | "num_frm": 8, 28 | "train_batch_size": 8, 29 | "val_batch_size": 8, 30 | "gradient_accumulation_steps": 1, 31 | "num_train_epochs": 5, 32 | "min_valid_steps": 100, 33 | "num_valid": 20, 34 | "learning_rate": 2.5e-5, 35 | "weight_decay": 1e-3, 36 | "decay": "linear", 37 | "optim": "adamw", 38 | "betas": [0.9, 0.98], 39 | "dropout": 0.1, 40 | "grad_norm": 5.0, 41 | "seed":42, 42 | "fp16": 0, 43 | "num_workers": 4 44 | } 45 | -------------------------------------------------------------------------------- /config_release/msvd_qa.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_datasets": [ 3 | { 4 | "name": "msvd_qa", 5 | "txt": { 6 | "msvd_qa": "data/msvd_qa/txt/train.jsonl" 7 | }, 8 | "img": "data/msvd_qa/videos" 9 | } 10 | ], 11 | "val_datasets": [ 12 | { 13 | "name": "msvd_qa", 14 | "txt": { 15 | "msvd_qa": "data/msvd_qa/txt/val.jsonl" 16 | }, 17 | "img": "data/msvd_qa/videos" 18 | } 19 | ], 20 | "ans2label_path": "data/msvd_qa/txt/train_ans2label.json", 21 | "num_labels": 2423, 22 | "max_txt_len": 40, 23 | "crop_img_size": 224, 24 | "resize_size": 256, 25 | "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 26 | "img_pixel_std": [0.26862954, 0.26130258, 0.27577711], 27 | "img_input_format": "RGB", 28 | "train_n_clips": 1, 29 | "num_frm": 16, 30 | "model_config": "config_release/base_model.json", 31 | "tokenizer_dir": "ext/bert-base-uncased/", 32 | "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600_gc.json", 33 | "e2e_weights_path": 
"output/pretrain/alpro_pretrained_ckpt.pt", 34 | "train_batch_size": 12, 35 | "val_batch_size": 12, 36 | "gradient_accumulation_steps": 2, 37 | "num_train_epochs": 15, 38 | "min_valid_steps": 50, 39 | "num_valid": 30, 40 | "learning_rate": 5e-5, 41 | "weight_decay": 1e-3, 42 | "decay": "linear", 43 | "optim": "adamw", 44 | "betas": [0.9, 0.98], 45 | "dropout": 0.1, 46 | "grad_norm": 20.0, 47 | "cnn_lr_decay": "linear", 48 | "seed":42, 49 | "fp16": 0, 50 | "save_steps_ratio": 0.05, 51 | "classifier": "mlp", 52 | "cls_hidden_scale": 2, 53 | "task": "msvd_qa", 54 | "num_workers": 4 55 | } 56 | -------------------------------------------------------------------------------- /config_release/pretrain_alpro.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_datasets": [ 3 | { 4 | "name": "webvid2m", 5 | "ann": "data/webvid2m/txt/train.pkl", 6 | "txt": null, 7 | "img": "data/webvid2m/videos" 8 | }, 9 | { 10 | "name": "cc3m", 11 | "ann": "data/cc3m/txt/cc3m.json", 12 | "txt": null, 13 | "img": null 14 | } 15 | ], 16 | "val_datasets": [ 17 | { 18 | "name": "webvid2m", 19 | "ann": "data/webvid2m/txt/val.pkl", 20 | "txt": null, 21 | "img": "data/webvid2m/videos" 22 | } 23 | ], 24 | "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 25 | "img_pixel_std": [0.26862954, 0.26130258, 0.27577711], 26 | "img_input_format": "RGB", 27 | "model_type": "pretrain", 28 | "model_config": "config_release/base_model.json", 29 | "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json", 30 | "visual_weights_path": "vit_base_patch16_224", 31 | "teacher_weights_path": "output/pretrain/prompter_pretrained.pt", 32 | "entity_file_path": "data/unigrams.txt", 33 | "tokenizer_dir": "ext/bert-base-uncased/", 34 | "max_txt_len": 30, 35 | "crop_img_size": 224, 36 | "resize_size": 256, 37 | "train_batch_size": 16, 38 | "val_batch_size": 16, 39 | "gradient_accumulation_steps": 1, 40 | "num_train_epochs": 10, 41 | "min_valid_steps": 10, 42 | "num_valid": 10, 43 | "learning_rate": 1e-4, 44 | "decay": "linear", 45 | "optim": "adamw", 46 | "betas": [0.9, 0.98], 47 | "dropout": 0.1, 48 | "weight_decay": 1e-3, 49 | "grad_norm": 20.0, 50 | "seed":42, 51 | "fp16": 0, 52 | "use_itm": 1, 53 | "use_mlm": 1, 54 | "use_itc": 1, 55 | "use_mpm": 1, 56 | "n_workers": 4, 57 | "save_steps_ratio": 0.01, 58 | "frm_sampling_strategy": "headtail", 59 | "num_frm": 4, 60 | "fps": 0.5, 61 | "debug": false, 62 | "warmup_ratio": 0.05, 63 | "log_interval": 100 64 | } 65 | -------------------------------------------------------------------------------- /config_release/pretrain_prompter.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_datasets": [ 3 | { 4 | "name": "webvid2m", 5 | "ann": "data/webvid2m/txt/train.pkl", 6 | "txt": null, 7 | "img": "data/webvid2m/videos" 8 | }, 9 | { 10 | "name": "cc3m", 11 | "ann": "data/cc3m/txt/cc3m.json", 12 | "txt": null, 13 | "img": null 14 | } 15 | ], 16 | "val_datasets": [ 17 | { 18 | "name": "webvid2m", 19 | "ann": "data/webvid2m/txt/val.pkl", 20 | "txt": null, 21 | "img": "data/webvid2m/videos" 22 | } 23 | ], 24 | "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 25 | "img_pixel_std": [0.26862954, 0.26130258, 0.27577711], 26 | "img_input_format": "RGB", 27 | "model_type": "pretrain", 28 | "model_config": "config_release/base_model.json", 29 | "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json", 30 | "visual_weights_path": "vit_base_patch16_224", 31 | "tokenizer_dir": 
"ext/bert-base-uncased/", 32 | "max_txt_len": 30, 33 | "crop_img_size": 224, 34 | "resize_size": 256, 35 | "train_batch_size": 16, 36 | "val_batch_size": 16, 37 | "gradient_accumulation_steps": 2, 38 | "num_train_epochs": 10, 39 | "min_valid_steps": 100, 40 | "num_valid": 10, 41 | "learning_rate": 1e-4, 42 | "decay": "linear", 43 | "optim": "adamw", 44 | "betas": [0.9, 0.98], 45 | "dropout": 0.1, 46 | "weight_decay": 1e-3, 47 | "grad_norm": 20.0, 48 | "seed":42, 49 | "fp16": 0, 50 | "use_itm": 0, 51 | "use_mlm": 0, 52 | "use_itc": 1, 53 | "n_workers": 4, 54 | "save_steps_ratio": 0.05, 55 | "frm_sampling_strategy": "headtail", 56 | "num_frm": 4, 57 | "debug": false, 58 | "warmup_ratio": 0.05, 59 | "log_interval": 100 60 | } 61 | -------------------------------------------------------------------------------- /config_release/timesformer_divst_8x32_224_k600.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls": "TimeSformer", 3 | "patch_size": 16, 4 | "attn_drop_rate": 0, 5 | "drop_rate": 0, 6 | "drop_path_rate": 0.1, 7 | "maxpool_kernel_size": 2, 8 | "use_maxpooling": false, 9 | "gradient_checkpointing": false 10 | } 11 | -------------------------------------------------------------------------------- /config_release/timesformer_divst_8x32_224_k600_gc.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls": "TimeSformer", 3 | "patch_size": 16, 4 | "attn_drop_rate": 0, 5 | "drop_rate": 0, 6 | "drop_path_rate": 0.1, 7 | "maxpool_kernel_size": 2, 8 | "use_maxpooling": false, 9 | "gradient_checkpointing": true 10 | } 11 | -------------------------------------------------------------------------------- /env/install_pkg.sh: -------------------------------------------------------------------------------- 1 | apt update 2 | apt install lsof 3 | 4 | # horovod 5 | HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \ 6 | pip install --no-cache-dir horovod==0.19.4 &&\ 7 | ldconfig 8 | 9 | # use the faster pillow-simd instead of the original pillow 10 | # https://github.com/uploadcare/pillow-simd 11 | pip uninstall pillow && \ 12 | CC="cc -mavx2" pip install -U --force-reinstall pillow-simd 13 | 14 | spacy download en 15 | 16 | pip install -r requirements.txt 17 | 18 | git clone https://github.com/NVIDIA/apex.git &&\ 19 | cd apex &&\ 20 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 
&&\ 21 | rm -rf ../apex 22 | 23 | -------------------------------------------------------------------------------- /env/requirements.txt: -------------------------------------------------------------------------------- 1 | ipdb 2 | joblib 3 | cytoolz 4 | lz4==2.1.9 5 | lmdb==0.97 6 | msgpack-numpy 7 | msgpack 8 | toolz 9 | transformers==4.11.3 10 | tensorboard 11 | tqdm 12 | easydict 13 | pycocotools>=2.0.1 14 | opencv-python 15 | tensorboardX==2.0 16 | av==8.0.2 17 | ujson 18 | einops 19 | decord 20 | timm 21 | -------------------------------------------------------------------------------- /pics/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALPRO/d21173f55a73b922d8ce2b06a05f78623b344fe7/pics/teaser.jpg -------------------------------------------------------------------------------- /run_scripts/clear_cuda_cache.sh: -------------------------------------------------------------------------------- 1 | for i in $(lsof /dev/nvidia* | grep python | awk '{print $2}' | sort -u); do kill -9 $i; done 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /run_scripts/ft_didemo_ret.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | CONFIG_PATH='config_release/didemo_ret.json' 7 | 8 | horovodrun -np 8 python src/tasks/run_video_retrieval.py \ 9 | --config $CONFIG_PATH \ 10 | --output_dir /export/home/workspace/experiments/alpro/finetune/didemo_ret/$(date '+%Y%m%d%H%M%S') 11 | -------------------------------------------------------------------------------- /run_scripts/ft_msrvtt_qa.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | CONFIG_PATH='config_release/msrvtt_qa.json' 7 | 8 | horovodrun -np 8 python src/tasks/run_video_qa.py \ 9 | --config $CONFIG_PATH \ 10 | --output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_qa/$(date '+%Y%m%d%H%M%S') 11 | -------------------------------------------------------------------------------- /run_scripts/ft_msrvtt_ret.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | CONFIG_PATH='config_release/msrvtt_ret.json' 7 | 8 | horovodrun -np 8 python src/tasks/run_video_retrieval.py \ 9 | --config $CONFIG_PATH \ 10 | --output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_ret/$(date '+%Y%m%d%H%M%S') -------------------------------------------------------------------------------- /run_scripts/ft_msvd_qa.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | CONFIG_PATH='config_release/msvd_qa.json' 7 | 8 | horovodrun -np 8 python src/tasks/run_video_qa.py \ 9 | --config $CONFIG_PATH \ 10 | --output_dir /export/home/workspace/experiments/alpro/finetune/msvd_qa/$(date '+%Y%m%d%H%M%S') 11 | -------------------------------------------------------------------------------- /run_scripts/inf_didemo_ret.sh: -------------------------------------------------------------------------------- 1 | cd .. 
2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | STEP='best' 7 | 8 | CONFIG_PATH='config_release/didemo_ret.json' 9 | 10 | TXT_DB='data/didemo_ret/txt/test.jsonl' 11 | IMG_DB='data/didemo_ret/videos' 12 | 13 | horovodrun -np 8 python src/tasks/run_video_retrieval.py \ 14 | --do_inference 1 \ 15 | --inference_split test \ 16 | --inference_model_step $STEP \ 17 | --inference_txt_db $TXT_DB \ 18 | --inference_img_db $IMG_DB \ 19 | --inference_batch_size 64 \ 20 | --output_dir output/downstreams/didemo_ret/public \ 21 | --config $CONFIG_PATH -------------------------------------------------------------------------------- /run_scripts/inf_msrvtt_qa.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | STEP='best' 7 | 8 | CONFIG_PATH='config_release/msrvtt_qa.json' 9 | 10 | TXT_DB='data/msrvtt_qa/txt/test.jsonl' 11 | IMG_DB='data/msrvtt_qa/videos' 12 | 13 | horovodrun -np 8 python src/tasks/run_video_qa.py \ 14 | --do_inference 1 \ 15 | --inference_split test \ 16 | --inference_model_step $STEP \ 17 | --inference_txt_db $TXT_DB \ 18 | --inference_img_db $IMG_DB \ 19 | --inference_batch_size 64 \ 20 | --output_dir output/downstreams/msrvtt_qa/public \ 21 | --config $CONFIG_PATH -------------------------------------------------------------------------------- /run_scripts/inf_msrvtt_ret.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | STEP='best' 7 | 8 | CONFIG_PATH='config_release/msrvtt_ret.json' 9 | 10 | TXT_DB='data/msrvtt_ret/txt/test.jsonl' 11 | IMG_DB='data/msrvtt_ret/videos' 12 | 13 | horovodrun -np 8 python src/tasks/run_video_retrieval.py \ 14 | --do_inference 1 \ 15 | --inference_split test \ 16 | --inference_model_step $STEP \ 17 | --inference_txt_db $TXT_DB \ 18 | --inference_img_db $IMG_DB \ 19 | --inference_batch_size 64 \ 20 | --output_dir output/downstreams/msrvtt_ret/public \ 21 | --config $CONFIG_PATH -------------------------------------------------------------------------------- /run_scripts/inf_msvd_qa.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | STEP='best' 7 | 8 | CONFIG_PATH='config_release/msvd_qa.json' 9 | 10 | TXT_DB='data/msvd_qa/txt/test.jsonl' 11 | IMG_DB='data/msvd_qa/videos' 12 | 13 | horovodrun -np 8 python src/tasks/run_video_qa.py \ 14 | --do_inference 1 \ 15 | --inference_split test \ 16 | --inference_model_step $STEP \ 17 | --inference_txt_db $TXT_DB \ 18 | --inference_img_db $IMG_DB \ 19 | --inference_batch_size 64 \ 20 | --output_dir output/downstreams/msvd_qa/public \ 21 | --config $CONFIG_PATH -------------------------------------------------------------------------------- /run_scripts/pt_alpro.sh: -------------------------------------------------------------------------------- 1 | cd .. 
2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | CONFIG_PATH='config_release/pretrain_alpro.json' 7 | 8 | horovodrun -np 16 python src/pretrain/run_pretrain_sparse.py \ 9 | --config $CONFIG_PATH \ 10 | --output_dir /export/home/workspace/experiments/alpro/vl/$(date '+%Y%m%d%H%M%S') -------------------------------------------------------------------------------- /run_scripts/pt_prompter.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | export PYTHONPATH="$PYTHONPATH:$PWD" 4 | echo $PYTHONPATH 5 | 6 | CONFIG_PATH='config_release/pretrain_prompter.json' 7 | 8 | horovodrun -np 8 python src/pretrain/run_pretrain_contrastive_only.py \ 9 | --config $CONFIG_PATH \ 10 | --output_dir /export/home/workspace/experiments/alpro/prompter/$(date '+%Y%m%d%H%M%S') -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/ALPRO/d21173f55a73b922d8ce2b06a05f78623b344fe7/src/__init__.py -------------------------------------------------------------------------------- /src/configs/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from UNITER code 3 | """ 4 | import os 5 | import sys 6 | import json 7 | import argparse 8 | 9 | from easydict import EasyDict as edict 10 | 11 | 12 | def parse_with_config(parsed_args): 13 | """This function will set args based on the input config file. 14 | (1) it only overwrites unset parameters, 15 | i.e., these parameters not set from user command line input 16 | (2) it also sets configs in the config file but declared in the parser 17 | """ 18 | # convert to EasyDict object, enabling access from attributes even for nested config 19 | # e.g., args.train_datasets[0].name 20 | args = edict(vars(parsed_args)) 21 | if args.config is not None: 22 | config_args = json.load(open(args.config)) 23 | override_keys = {arg[2:].split("=")[0] for arg in sys.argv[1:] 24 | if arg.startswith("--")} 25 | for k, v in config_args.items(): 26 | if k not in override_keys: 27 | setattr(args, k, v) 28 | del args.config 29 | return args 30 | 31 | 32 | class SharedConfigs(object): 33 | """Shared options for pre-training and downstream tasks. 34 | For each downstream task, implement a get_*_args function, 35 | see `get_pretraining_args()` 36 | 37 | Usage: 38 | >>> shared_configs = SharedConfigs() 39 | >>> pretraining_config = shared_configs.get_pretraining_args() 40 | """ 41 | 42 | def __init__(self, desc="shared config for pretraining and finetuning"): 43 | parser = argparse.ArgumentParser(description=desc) 44 | # debug parameters 45 | parser.add_argument( 46 | "--debug", type=int, choices=[0, 1], default=0, 47 | help="debug mode, output extra info & break all loops." 
48 | "0: disable, 1 enable") 49 | parser.add_argument( 50 | "--data_ratio", type=float, default=1.0, 51 | help="portion of train/val exampels to use," 52 | "e.g., overfit a small set of data") 53 | 54 | # Required parameters 55 | parser.add_argument( 56 | "--model_config", type=str, 57 | help="path to model structure config json") 58 | parser.add_argument( 59 | "--tokenizer_dir", type=str, help="path to tokenizer dir") 60 | parser.add_argument( 61 | "--output_dir", type=str, 62 | help="dir to store model checkpoints & training meta.") 63 | 64 | # data preprocessing parameters 65 | parser.add_argument( 66 | "--max_txt_len", type=int, default=20, help="max text #tokens ") 67 | # parser.add_argument( 68 | # "--max_img_size", type=int, default=448, 69 | # help="max image longer side size, shorter side will be padded with zeros") 70 | parser.add_argument( 71 | "--img_pixel_mean", type=float, default=None, 72 | nargs=3, help="image pixel mean") 73 | parser.add_argument( 74 | "--img_pixel_std", type=float, default=None, 75 | nargs=3, help="image pixel std") 76 | parser.add_argument( 77 | "--img_input_format", type=str, default="BGR", 78 | choices=["BGR", "RGB"], help="image input format is BGR for detectron2") 79 | parser.add_argument( 80 | "--max_n_example_per_group", type=int, default=1, 81 | help="max #examples (e.g., captions) paired with each image/video in an input group." 82 | "1: each image is paired with a single sent., equivalent to sample by sent.;" 83 | "X (X>1): each image can be paired with a maximum of X sent.; X>1 can be used " 84 | "to reduce image processing time, including basic transform (resize, etc) and CNN encoding" 85 | ) 86 | # video specific parameters 87 | parser.add_argument("--fps", type=int, default=1, help="video frame rate to use") 88 | parser.add_argument("--num_frm", type=int, default=3, 89 | help="#frames to use per clip -- we first sample a clip from a video, " 90 | "then uniformly sample num_frm from the clip. The length of the clip " 91 | "will be fps * num_frm") 92 | parser.add_argument("--frm_sampling_strategy", type=str, default="rand", 93 | choices=["rand", "uniform", "start", "middle", "end"], 94 | help="see src.datasets.dataset_base.extract_frames_from_video_binary for details") 95 | 96 | # MLL training settings 97 | parser.add_argument("--train_n_clips", type=int, default=3, 98 | help="#clips to sample from each video for MIL training") 99 | parser.add_argument("--score_agg_func", type=str, default="mean", 100 | choices=["mean", "max", "lse"], 101 | help="score (from multiple clips) aggregation function, lse = LogSumExp") 102 | parser.add_argument("--random_sample_clips", type=int, default=1, choices=[0, 1], 103 | help="randomly sample clips for training, otherwise use uniformly sampled clips.") 104 | 105 | # training parameters 106 | parser.add_argument( 107 | "--train_batch_size", default=128, type=int, 108 | help="Single-GPU batch size for training for Horovod.") 109 | parser.add_argument( 110 | "--val_batch_size", default=128, type=int, 111 | help="Single-GPU batch size for validation for Horovod.") 112 | parser.add_argument( 113 | "--gradient_accumulation_steps", type=int, default=1, 114 | help="#updates steps to accumulate before performing a backward/update pass." 115 | "Used to simulate larger batch size training. 
The simulated batch size " 116 | "is train_batch_size * gradient_accumulation_steps for a single GPU.") 117 | parser.add_argument("--learning_rate", default=5e-5, type=float, 118 | help="initial learning rate.") 119 | parser.add_argument( 120 | "--log_interval", default=500, type=int, 121 | help="record every a few steps on tensorboard.") 122 | parser.add_argument( 123 | "--num_valid", default=20, type=int, 124 | help="Run validation X times during training and checkpoint.") 125 | parser.add_argument( 126 | "--min_valid_steps", default=100, type=int, 127 | help="minimum #steps between two validation runs") 128 | parser.add_argument( 129 | "--save_steps_ratio", default=0.01, type=float, 130 | help="save every 0.01*global steps to resume after preemption," 131 | "not used for checkpointing.") 132 | parser.add_argument("--num_train_epochs", default=10, type=int, 133 | help="Total #training epochs.") 134 | parser.add_argument("--optim", default="adamw", 135 | choices=["adam", "adamax", "adamw"], 136 | help="optimizer") 137 | parser.add_argument("--betas", default=[0.9, 0.98], 138 | nargs=2, help="beta for adam optimizer") 139 | parser.add_argument("--decay", default="linear", 140 | choices=["linear", "invsqrt"], 141 | help="learning rate decay method") 142 | parser.add_argument("--dropout", default=0.1, type=float, 143 | help="tune dropout regularization") 144 | parser.add_argument("--weight_decay", default=1e-3, type=float, 145 | help="weight decay (L2) regularization") 146 | parser.add_argument("--grad_norm", default=2.0, type=float, 147 | help="gradient clipping (-1 for no clipping)") 148 | parser.add_argument( 149 | "--warmup_ratio", default=0.1, type=float, 150 | help="to perform linear learning rate warmup for. (invsqrt decay)") 151 | parser.add_argument("--transformer_lr_mul", default=1.0, type=float, 152 | help="lr_mul for transformer") 153 | parser.add_argument("--step_decay_epochs", type=int, 154 | nargs="+", help="multi_step decay epochs") 155 | # model arch 156 | parser.add_argument( 157 | "--model_type", type=str, default="pretrain", 158 | help="type of e2e model to use. Support only 'pretrain' for now. ") 159 | parser.add_argument( 160 | "--timesformer_model_cfg", type=str, default="", 161 | help="path to timesformer model cfg yaml") 162 | 163 | # checkpoint 164 | parser.add_argument("--e2e_weights_path", type=str, 165 | help="path to e2e model weights") 166 | parser.add_argument( 167 | "--clip_init", default=0, type=int, choices=[0, 1], 168 | help="1 for using clip ckpt for init.") 169 | parser.add_argument("--bert_weights_path", type=str, 170 | help="path to BERT weights, only use for pretraining") 171 | 172 | # inference only, please include substring `inference' 173 | # in the option to avoid been overwrite by loaded options, 174 | # see start_inference() in run_vqa_w_hvd.py 175 | parser.add_argument("--inference_model_step", default=-1, type=str, 176 | help="pretrained model checkpoint step") 177 | parser.add_argument( 178 | "--do_inference", default=0, type=int, choices=[0, 1], 179 | help="perform inference run. 0: disable, 1 enable") 180 | parser.add_argument( 181 | "--inference_split", default="val", 182 | help="For val, the data should have ground-truth associated it." 
183 | "For test*, the data comes with no ground-truth.") 184 | parser.add_argument("--inference_txt_db", type=str, 185 | help="path to txt_db file for inference") 186 | parser.add_argument("--inference_img_db", type=str, 187 | help="path to img_db file for inference") 188 | parser.add_argument("--inference_batch_size", type=int, default=64, 189 | help="single-GPU batch size for inference") 190 | parser.add_argument("--inference_n_clips", type=int, default=1, 191 | help="uniformly sample `ensemble_n_clips` clips, " 192 | "each contains `num_frm` frames. When it == 1, " 193 | "use the frm_sampling_strategy to sample num_frm frames." 194 | "When it > 1, ignore frm_sampling_strategy, " 195 | "uniformly sample N clips, each clips num_frm frames.") 196 | 197 | # device parameters 198 | parser.add_argument("--seed", type=int, default=42, 199 | help="random seed for initialization") 200 | parser.add_argument( 201 | "--fp16", type=int, choices=[0, 1], default=0, 202 | help="Use 16-bit float precision instead of 32-bit." 203 | "0: disable, 1 enable") 204 | parser.add_argument("--n_workers", type=int, default=4, 205 | help="#workers for data loading") 206 | parser.add_argument("--pin_mem", type=int, choices=[0, 1], default=1, 207 | help="pin memory. 0: disable, 1 enable") 208 | 209 | # can use config files, will only overwrite unset parameters 210 | parser.add_argument("--config", help="JSON config files") 211 | self.parser = parser 212 | 213 | def parse_args(self): 214 | parsed_args = self.parser.parse_args() 215 | args = parse_with_config(parsed_args) 216 | 217 | # convert to all [0, 1] options to bool, including these task specific ones 218 | zero_one_options = [ 219 | "fp16", "pin_mem", "use_itm", "use_mlm", "use_itc", "debug", #"freeze_cnn", 220 | "do_inference", 221 | ] 222 | for option in zero_one_options: 223 | if hasattr(args, option): 224 | setattr(args, option, bool(getattr(args, option))) 225 | 226 | # basic checks 227 | # This is handled at TrainingRestorer 228 | # if exists(args.output_dir) and os.listdir(args.output_dir): 229 | # raise ValueError(f"Output directory ({args.output_dir}) " 230 | # f"already exists and is not empty.") 231 | if args.step_decay_epochs and args.decay != "multi_step": 232 | Warning( 233 | f"--step_decay_epochs epochs set to {args.step_decay_epochs}" 234 | f"but will not be effective, as --decay set to be {args.decay}") 235 | 236 | assert args.gradient_accumulation_steps >= 1, \ 237 | f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps} " 238 | 239 | assert 1 >= args.data_ratio > 0, \ 240 | f"--data_ratio should be [1.0, 0), but get {args.data_ratio}" 241 | 242 | return args 243 | 244 | def get_sparse_pretraining_args(self): 245 | # pre-training args 246 | self.parser.add_argument( 247 | "--use_itm", type=int, choices=[0, 1], default=0, 248 | help="enable itm loss. 0: disable, 1 enable") 249 | self.parser.add_argument( 250 | "--use_mlm", type=int, choices=[0, 1], default=0, 251 | help="enable mlm loss. 0: disable, 1 enable") 252 | self.parser.add_argument( 253 | "--use_itc", type=int, choices=[0, 1], default=0, 254 | help="enable itc loss. 
0: disable, 1 enable") 255 | 256 | # sparse pretraining-specific settings 257 | self.parser.add_argument( 258 | "--crop_img_size", type=int, default=256, 259 | help="crop size during pre-training.") 260 | self.parser.add_argument( 261 | "--resize_size", type=int, default=288, 262 | help="resize frames to square, ignoring aspect ratio.") 263 | 264 | # MPM-specific 265 | self.parser.add_argument( 266 | "--use_mpm", type=int, choices=[0, 1], default=0, 267 | help="enable mpm loss. 0: disable, 1 enable") 268 | self.parser.add_argument("--teacher_weights_path", type=str, 269 | help="path to teacher model weights, only use for pretraining.") 270 | self.parser.add_argument("--entity_file_path", type=str, 271 | help="path to selected NOUN entities.") 272 | self.parser.add_argument( 273 | "--num_entities", type=int, default=1000, 274 | help="maximum entities to consider for pseudo labels.") 275 | 276 | args = self.parse_args() 277 | return args 278 | 279 | def get_video_retrieval_args(self): 280 | self.parser.add_argument("--eval_retrieval_batch_size", type=int, default=256, 281 | help="batch size for retrieval, since each batch will only have one image, " 282 | "retrieval allows larger batch size") 283 | 284 | args = self.parse_args() 285 | return args 286 | 287 | def get_nlvl_args(self): 288 | args = self.parse_args() 289 | 290 | return args 291 | 292 | 293 | def get_vqa_args(self): 294 | self.parser.add_argument("--ans2label_path", type=str, 295 | help="path to {answer: label} file") 296 | self.parser.add_argument("--loss_type", type=str, default="bce", 297 | help="loss type") 298 | self.parser.add_argument("--classifier", type=str, default="mlp", 299 | choices=["mlp", "linear"], 300 | help="classifier type") 301 | self.parser.add_argument( 302 | "--cls_hidden_scale", type=int, default=2, 303 | help="scaler of the intermediate linear layer dimension for mlp classifier") 304 | self.parser.add_argument("--num_labels", type=int, default=3129, 305 | help="#labels/output-dim for classifier") 306 | return self.parse_args() 307 | 308 | def get_video_qa_args(self): 309 | self.parser.add_argument( 310 | "--task", type=str, 311 | choices=["action", "transition", "frameqa", "msrvtt_qa"], 312 | help="TGIF-QA tasks and MSRVTT-QA") 313 | self.parser.add_argument("--loss_type", type=str, default="ce", 314 | help="loss type, will be overwritten later") 315 | self.parser.add_argument("--classifier", type=str, default="mlp", 316 | choices=["mlp", "linear"], 317 | help="classifier type") 318 | self.parser.add_argument( 319 | "--cls_hidden_scale", type=int, default=2, 320 | help="scaler of the intermediate linear layer dimension for mlp classifier") 321 | # for frameQA msrvtt_qa 322 | self.parser.add_argument("--ans2label_path", type=str, default=None, 323 | help="path to {answer: label} file") 324 | 325 | # manually setup config by task type 326 | args = self.parse_args() 327 | if args.max_n_example_per_group != 1: 328 | Warning(f"For TGIF-QA, most GIF is only paired with a single example, no need to" 329 | f"use max_n_example_per_group={args.max_n_example_per_group}" 330 | f"larger than 1. 
Automatically reset to 1.") 331 | args.max_n_example_per_group = 1 332 | if os.path.exists(args.ans2label_path): 333 | num_answers = len(json.load(open(args.ans2label_path, "r"))) 334 | else: 335 | num_answers = 0 336 | 337 | if args.task in ["msrvtt_qa", "msvd_qa"]: 338 | args.num_labels = max(num_answers, 1500) 339 | args.loss_type = "ce" 340 | else: 341 | raise NotImplementedError 342 | return args 343 | 344 | 345 | shared_configs = SharedConfigs() 346 | -------------------------------------------------------------------------------- /src/datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified from UNITER codebase 3 | 4 | A meta data loader for sampling from different datasets / training tasks 5 | A prefetch loader to speedup data loading 6 | """ 7 | import random 8 | 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from src.utils.distributed import any_broadcast 12 | 13 | 14 | class MetaLoader(object): 15 | """ wraps multiple data loader """ 16 | def __init__(self, loaders, accum_steps=1, distributed=False): 17 | assert isinstance(loaders, dict) 18 | self.name2loader = {} 19 | self.name2iter = {} 20 | self.sampling_pools = [] 21 | n_batches_in_epoch = 0 22 | for n, l in loaders.items(): 23 | if isinstance(l, tuple): 24 | l, r = l 25 | elif isinstance(l, DataLoader): 26 | r = 1 27 | else: 28 | raise ValueError() 29 | n_batches_in_epoch += len(l.dataset) * r / l.batch_size 30 | self.name2loader[n] = l 31 | self.name2iter[n] = iter(l) 32 | self.sampling_pools.extend([n]*r) 33 | self.n_batches_in_epoch = n_batches_in_epoch 34 | self.accum_steps = accum_steps 35 | self.distributed = distributed 36 | self.step = 0 37 | 38 | def __iter__(self): 39 | """ this iterator will run indefinitely """ 40 | task = self.sampling_pools[0] 41 | while True: 42 | if self.step % self.accum_steps == 0: 43 | task = random.choice(self.sampling_pools) 44 | if self.distributed: 45 | # make sure all process is training same task 46 | task = any_broadcast(task, 0) 47 | self.step += 1 48 | iter_ = self.name2iter[task] 49 | try: 50 | batch = next(iter_) 51 | except StopIteration: 52 | iter_ = iter(self.name2loader[task]) 53 | batch = next(iter_) 54 | self.name2iter[task] = iter_ 55 | 56 | yield task, batch 57 | 58 | 59 | def move_to_cuda(batch): 60 | if isinstance(batch, torch.Tensor): 61 | return batch.cuda(non_blocking=True) 62 | elif isinstance(batch, list): 63 | new_batch = [move_to_cuda(t) for t in batch] 64 | elif isinstance(batch, tuple): 65 | new_batch = tuple(move_to_cuda(t) for t in batch) 66 | elif isinstance(batch, dict): 67 | new_batch = {n: move_to_cuda(t) for n, t in batch.items()} 68 | else: 69 | return batch 70 | return new_batch 71 | 72 | 73 | def record_cuda_stream(batch): 74 | if isinstance(batch, torch.Tensor): 75 | batch.record_stream(torch.cuda.current_stream()) 76 | elif isinstance(batch, list) or isinstance(batch, tuple): 77 | for t in batch: 78 | record_cuda_stream(t) 79 | elif isinstance(batch, dict): 80 | for t in batch.values(): 81 | record_cuda_stream(t) 82 | else: 83 | pass 84 | 85 | 86 | class PrefetchLoader(object): 87 | """ 88 | overlap compute and cuda data transfer 89 | (copied and then modified from nvidia apex) 90 | """ 91 | def __init__(self, loader, img_normalize=None): 92 | self.loader = loader 93 | self.stream = torch.cuda.Stream() 94 | self.img_normalize = img_normalize 95 | 96 | def __iter__(self): 97 | loader_it = iter(self.loader) 98 | self.preload(loader_it) 99 | batch = 
self.next(loader_it) 100 | while batch is not None: 101 | is_tuple = isinstance(batch, tuple) 102 | if is_tuple: 103 | task, batch = batch 104 | batch["visual_inputs"] = batch["visual_inputs"].float() 105 | if self.img_normalize is not None: 106 | batch["visual_inputs"] = self.img_normalize( 107 | batch["visual_inputs"]) 108 | if "crop_visual_inputs" in batch: 109 | batch["crop_visual_inputs"] = batch["crop_visual_inputs"].float() 110 | batch["crop_visual_inputs"] = self.img_normalize( 111 | batch["crop_visual_inputs"]) 112 | if "context_visual_inputs" in batch: 113 | batch["context_visual_inputs"] = batch["context_visual_inputs"].float() 114 | batch["context_visual_inputs"] = self.img_normalize( 115 | batch["context_visual_inputs"]) 116 | if is_tuple: 117 | yield task, batch 118 | else: 119 | yield batch 120 | batch = self.next(loader_it) 121 | 122 | def __len__(self): 123 | return len(self.loader) 124 | 125 | def preload(self, it): 126 | try: 127 | self.batch = next(it) 128 | except StopIteration: 129 | self.batch = None 130 | return 131 | # if record_stream() doesn't work, another option is to make sure 132 | # device inputs are created on the main stream. 133 | # self.next_input_gpu = torch.empty_like(self.next_input, 134 | # device='cuda') 135 | # self.next_target_gpu = torch.empty_like(self.next_target, 136 | # device='cuda') 137 | # Need to make sure the memory allocated for next_* is not still in use 138 | # by the main stream at the time we start copying to next_*: 139 | # self.stream.wait_stream(torch.cuda.current_stream()) 140 | with torch.cuda.stream(self.stream): 141 | self.batch = move_to_cuda(self.batch) 142 | # more code for the alternative if record_stream() doesn't work: 143 | # copy_ will record the use of the pinned source tensor in this 144 | # side stream. 
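            # NOTE: `next_input` / `next_target` in the commented-out lines below are
            # placeholder names kept from the original NVIDIA Apex prefetcher recipe
            # this class was adapted from; here the whole (possibly nested) batch is
            # stored in `self.batch` instead.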
145 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 146 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 147 | # self.next_input = self.next_input_gpu 148 | # self.next_target = self.next_target_gpu 149 | 150 | def next(self, it): 151 | torch.cuda.current_stream().wait_stream(self.stream) 152 | batch = self.batch 153 | if batch is not None: 154 | record_cuda_stream(batch) 155 | self.preload(it) 156 | return batch 157 | 158 | def __getattr__(self, name): 159 | method = self.loader.__getattribute__(name) 160 | return method 161 | 162 | 163 | class InfiniteIterator(object): 164 | """iterate an iterable oobject infinitely""" 165 | def __init__(self, iterable): 166 | self.iterable = iterable 167 | self.iterator = iter(iterable) 168 | 169 | def __iter__(self): 170 | while True: 171 | try: 172 | batch = next(self.iterator) 173 | except StopIteration: 174 | self.iterator = iter(self.iterable) 175 | batch = next(self.iterator) 176 | yield batch 177 | -------------------------------------------------------------------------------- /src/datasets/dataset_base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from PIL import Image 3 | import io 4 | import av 5 | import torch 6 | import numpy as np 7 | import lmdb 8 | import random 9 | import decord 10 | from decord import VideoReader 11 | from src.datasets.data_utils import ( 12 | ImageResize, ImagePad, image_to_tensor) 13 | from src.utils.load_save import LOGGER 14 | 15 | decord.bridge.set_bridge("torch") 16 | 17 | 18 | class AlproBaseDataset(Dataset): 19 | """ 20 | datalist: list(dicts) # lightly pre-processed 21 | { 22 | "type": "image", 23 | "filepath": "/abs/path/to/COCO_val2014_000000401092.jpg", 24 | "text": "A plate of food and a beverage are on a table.", 25 | # should be tokenized and digitized first? 26 | ... 27 | } 28 | tokenizer: 29 | max_img_size: int, 30 | max_txt_len: int, max text sequence length, including special tokens. 31 | fps: float, frame per second 32 | num_frm: #frames to use as input. 
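    frm_sampling_strategy: str, frame sampling strategy, e.g. "rand", "uniform" or "headtail".
    img_db_type: str, one of {"lmdb", "rawvideo"}.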
33 | """ 34 | 35 | def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type='lmdb', fps=3, num_frm=3, 36 | frm_sampling_strategy="rand", max_img_size=-1, max_txt_len=20): 37 | self.fps = fps 38 | self.num_frm = num_frm 39 | self.frm_sampling_strategy = frm_sampling_strategy 40 | self.datalist = datalist 41 | self.tokenizer = tokenizer 42 | self.max_txt_len = max_txt_len 43 | self.max_img_size = max_img_size 44 | self.img_resize = ImageResize( 45 | max_img_size, 46 | "bilinear") # longer side will be resized to 1000 47 | self.img_pad = ImagePad( 48 | max_img_size, max_img_size) # pad to 1000 * 1000 49 | 50 | self.img_db_type = img_db_type 51 | 52 | assert img_db_type in ['lmdb', 'rawvideo'], "Invalid type for img_db_type, expected {'lmdb', 'rawvideo'}, found {}.".format(img_db_type) 53 | 54 | if self.img_db_type == 'lmdb': 55 | self.env = lmdb.open( 56 | img_lmdb_dir, readonly=True, 57 | create=False) # readahead=not _check_distributed() 58 | self.txn = self.env.begin(buffers=True) 59 | else: 60 | self.img_db_dir = img_lmdb_dir 61 | 62 | def __len__(self): 63 | return len(self.datalist) 64 | 65 | def __getitem__(self, index): 66 | raise NotImplementedError 67 | 68 | def _load_img(self, img_id): 69 | """Load and apply transformation to image 70 | 71 | Returns: 72 | torch.float, in [0, 255], (n_frm=1, c, h, w) 73 | """ 74 | raw_img = load_decompress_img_from_lmdb_value( 75 | self.txn.get(str(img_id).encode("utf-8")) 76 | ) 77 | image_np = np.array(raw_img, dtype=np.uint8) # (h, w, c) 78 | raw_img_tensor = image_to_tensor( 79 | image_np, keepdim=False).float() # (c, h, w) [0, 255] 80 | resized_img = self.img_resize(raw_img_tensor) 81 | transformed_img = self.img_pad( 82 | resized_img) # (n_frm=1, c, h, w) 83 | return transformed_img 84 | 85 | @classmethod 86 | def _is_extreme_aspect_ratio(cls, tensor, max_ratio=5.): 87 | """ find extreme aspect ratio, where longer side / shorter side > max_ratio 88 | Args: 89 | tensor: (*, H, W) 90 | max_ratio: float, max ratio (>1). 91 | """ 92 | h, w = tensor.shape[-2:] 93 | return h / float(w) > max_ratio or h / float(w) < 1 / max_ratio 94 | 95 | def _load_video(self, video_id, num_clips=None, clip_idx=None, 96 | safeguard_duration=False, video_max_pts=None): 97 | """Load and sample frames from video. 98 | Apply transformation to the sampled frames. 99 | 100 | Sample a clip: 101 | - random: set num_clips and clip_idx to be None 102 | - uniform: set num_clips=N, clip_idx=idx. e.g., num_clips=3 103 | and clip_idx=1 will first segment the video into 3 clips, 104 | then sample the 2nd clip. 105 | 106 | Returns: 107 | torch.float, in [0, 255], (n_frm=T, c, h, w) 108 | """ 109 | assert (num_clips is None) == (clip_idx is None), "Both None, or both not None" 110 | # (T, C, H, W) [0, 255] 111 | io_stream = io.BytesIO(self.txn.get(str(video_id).encode("utf-8"))) 112 | raw_sampled_frms, video_max_pts = extract_frames_from_video_binary( 113 | io_stream, 114 | target_fps=self.fps, 115 | num_frames=self.num_frm, 116 | multi_thread_decode=False, 117 | sampling_strategy=self.frm_sampling_strategy, 118 | num_clips=num_clips, 119 | clip_idx=clip_idx, 120 | safeguard_duration=safeguard_duration, 121 | video_max_pts=video_max_pts 122 | ) 123 | 124 | if raw_sampled_frms is None: 125 | return None, None 126 | elif self._is_extreme_aspect_ratio(raw_sampled_frms, max_ratio=5.): 127 | print( 128 | f"Found extreme aspect ratio for video id {video_id}. 
Skip it") 129 | return None, None 130 | 131 | raw_sampled_frms = raw_sampled_frms.float() 132 | resized_frms = self.img_resize(raw_sampled_frms) 133 | padded_frms = self.img_pad(resized_frms) 134 | return padded_frms, video_max_pts 135 | 136 | 137 | def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1): 138 | try: 139 | if not height or not width: 140 | vr = VideoReader(video_path) 141 | else: 142 | vr = VideoReader(video_path, width=width, height=height) 143 | 144 | vlen = len(vr) 145 | 146 | if start_time or end_time: 147 | assert fps > 0, 'must provide video fps if specifying start and end time.' 148 | 149 | start_idx = min(int(start_time * fps), vlen) 150 | end_idx = min(int(end_time * fps), vlen) 151 | else: 152 | start_idx, end_idx = 0, vlen 153 | 154 | if self.frm_sampling_strategy == 'uniform': 155 | frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int) 156 | elif self.frm_sampling_strategy == 'nlvl_uniform': 157 | frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm).astype(int) 158 | elif self.frm_sampling_strategy == 'nlvl_rand': 159 | frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm).astype(int) 160 | 161 | # generate some random perturbations 162 | strides = [frame_indices[i] - frame_indices[i-1] for i in range(1, len(frame_indices))] + [vlen - frame_indices[-1]] 163 | pertube = np.array([np.random.randint(0, stride) for stride in strides]) 164 | 165 | frame_indices = frame_indices + pertube 166 | 167 | elif self.frm_sampling_strategy == 'rand': 168 | frame_indices = sorted(random.sample(range(vlen), self.num_frm)) 169 | elif self.frm_sampling_strategy == 'headtail': 170 | frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2)) 171 | frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2)) 172 | frame_indices = frame_indices_head + frame_indices_tail 173 | else: 174 | raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy)) 175 | 176 | raw_sample_frms = vr.get_batch(frame_indices) 177 | except Exception as e: 178 | return None 179 | 180 | raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2) 181 | 182 | return raw_sample_frms 183 | 184 | def img_collate(imgs): 185 | """ 186 | Args: 187 | imgs: 188 | 189 | Returns: 190 | torch.tensor, (B, 3, H, W) 191 | """ 192 | w = imgs[0].width 193 | h = imgs[0].height 194 | tensor = torch.zeros( 195 | (len(imgs), 3, h, w), dtype=torch.uint8).contiguous() 196 | for i, img in enumerate(imgs): 197 | nump_array = np.array(img, dtype=np.uint8) 198 | if (nump_array.ndim < 3): 199 | nump_array = np.expand_dims(nump_array, axis=-1) 200 | # (H, W, 3) --> (3, H, W) 201 | nump_array = np.rollaxis(nump_array, 2) 202 | tensor[i] += torch.from_numpy(nump_array) 203 | return tensor 204 | -------------------------------------------------------------------------------- /src/datasets/dataset_pretrain_sparse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | import torch 6 | import spacy 7 | from torch.utils.data.dataloader import default_collate 8 | from src.utils.logger import LOGGER 9 | from src.utils.basic_utils import flat_list_of_lists, save_frames_grid 10 | from src.datasets.data_utils import VideoRandomSquareCrop, VideoResizeSquare, mask_batch_text_tokens, select_batch_text_pivots 11 | from src.datasets.dataset_base import AlproBaseDataset, img_collate 12 | 13 
| from src.datasets.randaugment import TemporalConsistentRandomAugment, RandomAugment 14 | 15 | from torch.utils.data import Dataset 16 | 17 | from torchvision import transforms 18 | from PIL import Image 19 | import numpy as np 20 | 21 | 22 | class AlproPretrainSparseDataset(AlproBaseDataset): 23 | """ 24 | datalist: list(tuples) each tuple is (img_id, list(dicts)), 25 | each dict { 26 | "type": "image", 27 | "filepath": "/abs/path/to/COCO_val2014_000000401092.jpg", 28 | "text": "A plate of food and a beverage are on a table.", # should be tokenized and digitized first? 29 | ... 30 | } 31 | tokenizer: 32 | max_img_size: int, 33 | max_txt_len: int, max text sequence length, including special tokens. 34 | vis_format: str, image or video, used to decide data loading method. 35 | """ 36 | def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type, txt_dir, 37 | video_fmt='.mp4', crop_size=256, resize_size=288, fps=3, num_frm=3, frm_sampling_strategy="rand", 38 | max_img_size=1000, max_txt_len=20, 39 | use_itm=True, is_train=True): 40 | super(AlproPretrainSparseDataset, self).__init__( 41 | datalist, tokenizer, img_lmdb_dir, 42 | img_db_type=img_db_type, 43 | fps=fps, 44 | num_frm=num_frm, 45 | frm_sampling_strategy=frm_sampling_strategy, 46 | max_img_size=max_img_size, 47 | max_txt_len=max_txt_len) 48 | self.use_itm = use_itm 49 | 50 | self.txt_dir = txt_dir 51 | self.video_fmt = video_fmt 52 | 53 | self.crop_size = crop_size 54 | self.video_random_cropper = VideoRandomSquareCrop(crop_size) 55 | 56 | self.resize_size = resize_size 57 | 58 | self.is_train = is_train 59 | 60 | if self.is_train: 61 | self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip']) 62 | else: 63 | self.randaug = None 64 | 65 | def __len__(self): 66 | return len(self.datalist) 67 | 68 | def __getitem__(self, index): 69 | start_time = None 70 | end_time = None 71 | 72 | # fetch video 73 | num_retries = 10 # skip error videos 74 | 75 | for _ in range(num_retries): 76 | data_sample = self.datalist.iloc[index] 77 | 78 | video_id = str(data_sample.video_id) 79 | txt_len = int(data_sample.txt_len) 80 | 81 | if hasattr(data_sample, 'text'): 82 | text = data_sample.text.strip() 83 | else: 84 | raise NotImplementedError("Un-supported text annotation format.") 85 | 86 | # fetch video 87 | video_path = os.path.join(self.img_db_dir, video_id + self.video_fmt) 88 | 89 | # read with retries 90 | for i in range(3): 91 | img_array = self._load_video_from_path_decord(video_path, height=self.resize_size, width=self.resize_size) 92 | 93 | if img_array is not None: 94 | break 95 | 96 | if img_array is not None: 97 | t, c, h, w = img_array.shape 98 | 99 | # Select a random video if the current video was not able to access. 100 | if img_array is None: 101 | LOGGER.info(f"Failed to load examples with video: {video_path}. 
" 102 | f"Will randomly sample an example as a replacement.") 103 | index = random.randint(0, len(self) - 1) 104 | continue 105 | else: 106 | # square crop 107 | img_array = self.video_random_cropper(img_array) 108 | 109 | if self.randaug: 110 | img_array = self.randaug(img_array.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 111 | 112 | break 113 | else: 114 | raise RuntimeError(f"Failed to fetch video after {num_retries} retries.") 115 | 116 | examples = [{'text_str': text, 'itm_label': 1}] 117 | 118 | return dict( 119 | img=img_array, # (T, C, H, W) 120 | examples=examples, 121 | n_examples=len(examples), # used to create image feature copies. 122 | type='video' 123 | ) 124 | 125 | class PretrainImageTextDataset(Dataset): 126 | def __init__(self, datalist, tokenizer, is_train=True, crop_size=256, resize_size=288, num_frm=4, max_txt_len=40): 127 | self.datalist = datalist 128 | self.max_txt_len = max_txt_len 129 | 130 | self.crop_size = crop_size 131 | self.resize_size = resize_size 132 | self.num_frms = num_frm 133 | 134 | self.is_train = is_train 135 | 136 | self.transform = transforms.Compose([ 137 | transforms.RandomResizedCrop(self.crop_size, scale=(0.2, 1.0), interpolation=Image.BICUBIC), 138 | transforms.RandomHorizontalFlip(), 139 | RandomAugment(2,7,isPIL=True,augs=['Identity','Brightness','Sharpness', 140 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']) 141 | ]) 142 | 143 | def __len__(self): 144 | return len(self.datalist) 145 | 146 | def __getitem__(self, index): 147 | start_time = None 148 | end_time = None 149 | 150 | # fetch video 151 | num_retries = 10 # skip error videos 152 | 153 | for _ in range(num_retries): 154 | data_sample = self.datalist[index] 155 | 156 | try: 157 | if type(data_sample['caption']) == list: 158 | text = random.choice(data_sample['caption']) 159 | else: 160 | text = data_sample['caption'] 161 | 162 | img_path = data_sample['image'] 163 | img_arr = Image.open(img_path).convert('RGB') 164 | img_arr = self.transform(img_arr) 165 | img_arr = np.asarray(img_arr, dtype=np.float32).transpose(2, 0, 1) 166 | img_arr = torch.from_numpy(img_arr).unsqueeze(0) 167 | img_arr = img_arr.repeat(self.num_frms, 1, 1, 1) 168 | 169 | except Exception as e: 170 | img_arr = None 171 | 172 | if img_arr is not None: 173 | t, c, h, w = img_arr.shape 174 | 175 | # Select a random video if the current video was not able to access. 176 | if img_arr is None: 177 | LOGGER.info(f"Failed to load examples with image: {img_path}. " 178 | f"Will randomly sample an example as a replacement.") 179 | index = random.randint(0, len(self) - 1) 180 | continue 181 | else: 182 | break 183 | else: 184 | raise RuntimeError(f"Failed to fetch image after {num_retries} retries.") 185 | 186 | examples = [{'text_str': text, 'itm_label': 1}] 187 | 188 | return dict( 189 | img=img_arr, # (T, C, H, W) 190 | examples=examples, 191 | n_examples=len(examples), # used to create image feature copies. 192 | type='img' 193 | ) 194 | 195 | 196 | class PretrainCollator(object): 197 | """is_train is kept here if we want to remove 198 | the randomness during validation of MLM accuracy. 
199 | In that case, instantiate two PretrainCollator""" 200 | def __init__(self, tokenizer, 201 | mlm=True, mlm_probability=0.15, 202 | patch_size=16, 203 | mpm=True, 204 | max_length=20, is_train=True): 205 | self.tokenizer = tokenizer 206 | self.mlm = mlm 207 | self.mlm_probability = mlm_probability 208 | self.max_length = max_length 209 | self.is_train = is_train 210 | 211 | self.mpm = mpm 212 | self.patch_size = patch_size 213 | 214 | def collate_batch(self, batch): 215 | if isinstance(batch[0]["img"], torch.Tensor): 216 | v_collate = default_collate 217 | else: 218 | v_collate = img_collate 219 | visual_inputs = v_collate([d["img"] for d in batch]) # (B, #frm=1 or T, 3, H, W) 220 | # group data 221 | text_examples = flat_list_of_lists([d["examples"] for d in batch]) 222 | n_examples_list = [d["n_examples"] for d in batch] # (B, ) 223 | # group elements data 224 | batch_enc = self.tokenizer.batch_encode_plus( 225 | [d["text_str"] for d in text_examples], 226 | max_length=self.max_length, 227 | padding='max_length', 228 | return_tensors="pt", 229 | truncation=True 230 | ) 231 | text_input_ids = batch_enc.input_ids # (B, L) 232 | text_input_ids_no_mask = text_input_ids.clone() 233 | 234 | if self.mlm: 235 | text_input_ids, mlm_labels = mask_batch_text_tokens( 236 | text_input_ids, self.tokenizer, 237 | is_train=self.is_train) # make mlm data 238 | else: 239 | text_input_ids, mlm_labels = text_input_ids, None 240 | 241 | text_input_mask = batch_enc.attention_mask # (B, L) 242 | itm_labels = default_collate( 243 | [d["itm_label"] for d in text_examples]) # (B, ) 244 | 245 | erase_elems = [random_erase(e, patch_size=self.patch_size) for e in visual_inputs.clone()] 246 | 247 | if self.mpm: 248 | crop_visual_inputs = v_collate([elems[0] for elems in erase_elems]) 249 | mpm_masks = v_collate([elems[1] for elems in erase_elems]) 250 | context_visual_inputs = v_collate([elems[2] for elems in erase_elems]) 251 | 252 | return dict( 253 | visual_inputs=visual_inputs, # (B, #frm=1 or T, H, W, C) 254 | crop_visual_inputs=crop_visual_inputs, 255 | context_visual_inputs=context_visual_inputs, 256 | mpm_mask=mpm_masks, 257 | text_input_ids=text_input_ids_no_mask, 258 | mlm_text_input_ids=text_input_ids, 259 | mlm_labels=mlm_labels, 260 | text_input_mask=text_input_mask, # used to exclude [PAD] token 261 | itm_labels=itm_labels, 262 | n_examples_list=n_examples_list, # used to create image feature copies. 263 | type=batch[0]['type'] 264 | ) 265 | else: 266 | return dict( 267 | visual_inputs=visual_inputs, # (B, #frm=1 or T, H, W, C) 268 | text_input_ids=text_input_ids_no_mask, 269 | mlm_text_input_ids=text_input_ids, 270 | mlm_labels=mlm_labels, 271 | text_input_mask=text_input_mask, # used to exclude [PAD] token 272 | itm_labels=itm_labels, 273 | n_examples_list=n_examples_list, # used to create image feature copies. 
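                # MPM-specific fields (crop_visual_inputs, context_visual_inputs,
                # mpm_mask) are intentionally omitted in this branch, since self.mpm is False.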
274 | type=batch[0]['type'] 275 | ) 276 | 277 | def random_erase(input_img, patch_size, s_l=0.3, s_h=0.5, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255): 278 | assert input_img.ndim == 4 279 | img_t, img_c, img_h, img_w = input_img.shape 280 | 281 | while True: 282 | s = np.random.uniform(s_l, s_h) * img_h * img_w 283 | r = np.random.uniform(r_1, r_2) 284 | w = int(np.sqrt(s / r)) 285 | h = int(np.sqrt(s * r)) 286 | left = np.random.randint(0, img_w) 287 | top = np.random.randint(0, img_h) 288 | 289 | w = w - w % patch_size 290 | h = h - h % patch_size 291 | 292 | left = left - left % patch_size 293 | top = top - top % patch_size 294 | 295 | if left + w <= img_w and top + h <= img_h: 296 | break 297 | 298 | context_img = input_img.clone() 299 | context_img[:, :, top: top + h, left: left + w] = 0 300 | 301 | input_img = input_img[:, :, top: top + h, left: left + w] 302 | pad = (left, img_w - left - w, top, img_h - top - h) 303 | input_img = torch.nn.functional.pad(input_img, pad=pad, mode='constant', value=0.0) 304 | 305 | img_masks = torch.ones_like(input_img) 306 | img_masks[:, :, top: top+h, left: left+w] = 0 307 | 308 | img_masks = torch.nn.functional.avg_pool2d(img_masks.float(), kernel_size=(patch_size, patch_size), stride=patch_size) 309 | img_masks = torch.mean(img_masks, dim=(0, 1)) 310 | 311 | return input_img, img_masks, context_img -------------------------------------------------------------------------------- /src/datasets/dataset_video_qa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | import copy 6 | from torch.utils.data.dataloader import default_collate 7 | from src.utils.basic_utils import flat_list_of_lists 8 | from src.utils.load_save import LOGGER 9 | from src.datasets.dataset_base import AlproBaseDataset 10 | from src.datasets.randaugment import TemporalConsistentRandomAugment 11 | 12 | 13 | class AlproVideoQADataset(AlproBaseDataset): 14 | """ This should work for both train and test (where labels are not available). 15 | task_type: str, one of [action, frameqa, transition] 16 | where action and transition are multiple-choice QA, 17 | frameqa is opened QA similar to VQA. 18 | datalist: list(tuples) each tuple is (img_id, list(dicts)), 19 | each dict 20 | tokenizer: 21 | max_img_size: int, 22 | max_txt_len: int, max text sequence length, including special tokens. 
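    ans2label: dict, {answer string: label index} mapping used for open-ended QA.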
23 | return_label: bool, whether return label in __getitem__ 24 | random_sample_clips: 25 | """ 26 | open_ended_qa_names = ["frameqa", "msrvtt_qa", "msvd_qa"] 27 | 28 | def __init__(self, task_type, datalist, tokenizer, img_lmdb_dir, 29 | fps=3, num_frm=3, frm_sampling_strategy="rand", 30 | max_img_size=1000, max_txt_len=20, ans2label=None, 31 | ensemble_n_clips=1, return_label=True, is_train=False, random_sample_clips=True, 32 | video_fmt='.mp4', img_db_type='lmdb'): 33 | super(AlproVideoQADataset, self).__init__( 34 | datalist, tokenizer, img_lmdb_dir, img_db_type=img_db_type, 35 | fps=fps, num_frm=num_frm, 36 | frm_sampling_strategy=frm_sampling_strategy, 37 | max_img_size=max_img_size, max_txt_len=max_txt_len) 38 | self.ensemble_n_clips = ensemble_n_clips 39 | self.return_label = return_label 40 | self.is_train = is_train 41 | self.task_type = task_type 42 | self.ans2label = ans2label 43 | self.num_labels = len(ans2label) 44 | self.random_sample_clips = random_sample_clips 45 | self.label2ans = {v: k for k, v in ans2label.items()} 46 | self.qid2data = {d["question_id"]: d for group in datalist for d in group[1]} 47 | 48 | self.video_fmt = video_fmt 49 | 50 | if self.is_train: 51 | self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip']) 52 | else: 53 | self.randaug = None 54 | 55 | def __len__(self): 56 | return len(self.datalist) 57 | 58 | 59 | def __getitem__(self, index): 60 | # skip error videos: 61 | num_retries = 5 62 | for _ in range(num_retries): 63 | vid_id, examples = self.datalist[index] # one video with multiple examples 64 | if self.ensemble_n_clips > 1: 65 | raise NotImplementedError('Do not support multiple clips for now.') 66 | else: 67 | video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt) 68 | vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size) 69 | 70 | # Select a random video if the current video was not able to access. 71 | if vid_frm_array is None: 72 | LOGGER.info(f"Failed to load examples with video: {vid_id}. " 73 | f"Will randomly sample an example as a replacement.") 74 | index = random.randint(0, len(self) - 1) 75 | continue 76 | 77 | if self.randaug: 78 | vid_frm_array = self.randaug(vid_frm_array.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 79 | 80 | examples = [self._get_single_example(e) for e in examples] 81 | return dict( 82 | vid=vid_frm_array, 83 | examples=examples, 84 | n_examples=len(examples) # used to create image feature copies. 
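                # each example dict carries q_str, question_id and (optionally) label,
                # see _get_single_example below.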
85 | ) 86 | else: 87 | raise RuntimeError(f"Failed to fetch video after {num_retries} retries.") 88 | 89 | def _get_single_example(self, data): 90 | example = dict( 91 | q_str=data["question"], 92 | question_id=data["question_id"], 93 | label=data["answer"] 94 | ) 95 | if self.task_type in self.open_ended_qa_names: 96 | if self.return_label: 97 | example["label"] = self.ans2label[example["label"]] 98 | if not self.return_label: 99 | example["label"] = None 100 | return example 101 | 102 | def evaluate_qa(self, results): 103 | """ 104 | Args: 105 | results: list(dict), 106 | each dict is 107 | { 108 | "question_id": int, 109 | "answer": int or float, either answer_idx (int) 110 | } 111 | Returns: 112 | TGIF-QA score 113 | """ 114 | preds = [] 115 | gts = [] 116 | # for frameQA 117 | answer_types = [] 118 | answer_type2idx = dict( 119 | frameqa={"object": 0, "number": 1, "color": 2, "location": 3}, 120 | msrvtt_qa={k: idx for idx, k in enumerate(["what", "who", "how", "where", "when"])}, 121 | msvd_qa={k: idx for idx, k in enumerate(["what", "who", "how", "where", "when"])} 122 | ) 123 | 124 | qid2pred_ans = {r["question_id"]: r["answer"] for r in results} 125 | if self.task_type in self.open_ended_qa_names: # convert ans_idx, int --> str 126 | qid2pred_ans = {k: self.label2ans[v] for k, v in qid2pred_ans.items()} 127 | 128 | for qid, pred_ans in qid2pred_ans.items(): 129 | preds.append(pred_ans) 130 | 131 | gt_data = self.qid2data[qid] 132 | gt_ans = gt_data["answer"] 133 | if self.task_type in self.open_ended_qa_names: 134 | answer_types.append(answer_type2idx[self.task_type][gt_data["answer_type"]]) 135 | gts.append(gt_ans) 136 | 137 | preds = np.array(preds) 138 | gts = np.array(gts) 139 | metrics = dict() 140 | # preds and gts are array of strings 141 | metrics["overall_acc"] = float(np.mean(preds == gts)) 142 | if self.task_type in self.open_ended_qa_names: 143 | answer_types = np.array(answer_types) 144 | ratios = dict() 145 | for ans_type, ans_type_idx in answer_type2idx[self.task_type].items(): 146 | answer_type_mask = answer_types == ans_type_idx 147 | answer_type_corrects = ( 148 | preds[answer_type_mask] == gts[answer_type_mask]) 149 | metrics[f"{ans_type}_acc"] = float( 150 | np.mean(answer_type_corrects)) if len(answer_type_corrects) != 0 else 0 151 | ratios[f"{ans_type}_ratio"] = [ 152 | 1. * len(answer_type_corrects) / len(answer_types), 153 | len(answer_type_corrects)] 154 | metrics["ratios"] = ratios 155 | return metrics 156 | 157 | 158 | class VideoQACollator(object): 159 | def __init__(self, tokenizer, max_length=20, task_type="action", n_options=5): 160 | self.tokenizer = tokenizer 161 | self.max_length = max_length 162 | self.task_type = task_type 163 | self.n_options = n_options 164 | 165 | def collate_batch(self, batch): 166 | v_collate = default_collate 167 | visual_inputs = v_collate([d["vid"] for d in batch]) # (B, T, 3, H, W) 168 | # group data 169 | text_examples = flat_list_of_lists([d["examples"] for d in batch]) 170 | n_examples_list = [d["n_examples"] for d in batch] # (B, ) 171 | # group elements data 172 | # directly concatenate question and option as a single seq. 
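        # Illustrative sketch (hypothetical strings, not from a real annotation file):
        #   q_str = "what does the man do"
        #   options_str_list = ["run", "jump", ...]
        # For "action"/"transition" this yields n_options concatenated sequences such as
        # "what does the man do run"; for open-ended tasks only q_str is tokenized.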
173 | if self.task_type in ["action", "transition"]: 174 | text_str_list = flat_list_of_lists( 175 | [[d["q_str"] + " " + d["options_str_list"][i] for i in range(self.n_options)] 176 | for d in text_examples] 177 | ) # (B * n_options, ) 178 | else: 179 | text_str_list = [d["q_str"] for d in text_examples] # (B, ) 180 | batch_enc = self.tokenizer.batch_encode_plus( 181 | text_str_list, 182 | max_length=self.max_length, 183 | padding='max_length', 184 | return_tensors="pt", 185 | truncation=True 186 | ) 187 | text_input_ids = batch_enc.input_ids # (B, L) 188 | text_input_mask = batch_enc.attention_mask # (B, L) 189 | 190 | labels = default_collate([int(d["label"]) for d in text_examples]) \ 191 | if text_examples[0]["label"] is not None else None # (B, #ans) 192 | question_ids = [d["question_id"] for d in text_examples] 193 | return dict( 194 | visual_inputs=visual_inputs, # (B, #frm, H, W, C) 195 | text_input_ids=text_input_ids, 196 | text_input_mask=text_input_mask, 197 | question_ids=question_ids, 198 | labels=labels, 199 | n_examples_list=n_examples_list # used to create image feature copies. 200 | ) 201 | -------------------------------------------------------------------------------- /src/datasets/dataset_video_retrieval.py: -------------------------------------------------------------------------------- 1 | import random 2 | import copy 3 | import os 4 | import torch 5 | import numpy as np 6 | from torch.utils.data.dataloader import default_collate 7 | from src.utils.basic_utils import flat_list_of_lists 8 | from src.utils.load_save import LOGGER 9 | from src.datasets.dataset_base import AlproBaseDataset 10 | from src.datasets.randaugment import TemporalConsistentRandomAugment 11 | 12 | 13 | class AlproVideoRetrievalDataset(AlproBaseDataset): 14 | """ This should work for both train and test (where labels are not available). 15 | datalist: list(tuples) each tuple is (img_id, list(dicts)), 16 | each dict 17 | tokenizer: 18 | max_img_size: int, 19 | max_txt_len: int, max text sequence length, including special tokens. 
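    ensemble_n_clips: int, number of clips per video; only 1 is currently supported (see __getitem__).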
20 | random_sample_clips: bool, whether using randomly sampled N clips or always use uniformly sampled N clips 21 | """ 22 | def __init__(self, datalist, tokenizer, img_lmdb_dir, 23 | fps=3, num_frm=3, frm_sampling_strategy="rand", 24 | max_img_size=1000, max_txt_len=40, itm_neg_size=1, 25 | ensemble_n_clips=1, random_sample_clips=True, 26 | video_fmt='.mp4', img_db_type='lmdb', is_train=False): 27 | super(AlproVideoRetrievalDataset, self).__init__( 28 | datalist, tokenizer, img_lmdb_dir, img_db_type=img_db_type, 29 | fps=fps, num_frm=num_frm, 30 | frm_sampling_strategy=frm_sampling_strategy, 31 | max_img_size=max_img_size, max_txt_len=max_txt_len) 32 | self.ensemble_n_clips = ensemble_n_clips 33 | self.num_labels = 2 34 | self.itm_neg_size = itm_neg_size 35 | self.random_sample_clips = random_sample_clips 36 | self.id2data = { 37 | d["id"]: d for group in datalist for d in group[1]} 38 | 39 | self.is_train = is_train 40 | self.video_fmt = video_fmt 41 | 42 | if self.is_train: 43 | self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip']) 44 | else: 45 | self.randaug = None 46 | 47 | def __len__(self): 48 | return len(self.datalist) 49 | 50 | def __getitem__(self, index): 51 | # skip error videos: 52 | num_retries = 5 53 | for _ in range(num_retries): 54 | vid_id, examples = self.datalist[index] # one video with multiple examples 55 | if self.ensemble_n_clips > 1: 56 | raise NotImplementedError('Do not support multiple clips for now.') 57 | else: 58 | video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt) 59 | vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size) 60 | 61 | # Select a random video if the current video was not able to access. 62 | if vid_frm_array is None: 63 | LOGGER.info(f"Failed to load examples with video: {vid_id}. " 64 | f"Will randomly sample an example as a replacement.") 65 | index = random.randint(0, len(self) - 1) 66 | continue 67 | sampled_examples = [] 68 | for e in examples: 69 | s = self._get_single_example(e, index) 70 | if isinstance(s, dict): 71 | sampled_examples.append(s) 72 | else: 73 | sampled_examples.extend(s) 74 | return dict( 75 | vid=vid_frm_array, 76 | examples=sampled_examples, 77 | n_examples=len(sampled_examples) # used to create image feature copies. 78 | ) 79 | else: 80 | raise RuntimeError( 81 | f"Failed to fetch video after {num_retries} retries.") 82 | 83 | def _get_single_example(self, data, index): 84 | examples = [] 85 | 86 | text_str = data["txt"] 87 | itm_label = 1 # positive pair 88 | examples.append(dict( 89 | text_str=text_str, 90 | itm_label=itm_label 91 | )) 92 | return examples 93 | 94 | 95 | class VideoRetrievalCollator(object): 96 | def __init__(self, tokenizer, max_length=40): 97 | self.tokenizer = tokenizer 98 | self.max_length = max_length 99 | 100 | def collate_batch(self, batch): 101 | # FIXME there is a chance that two captions associated with the same video are batched together. Might need to fix. 102 | v_collate = default_collate 103 | visual_inputs = v_collate([d["vid"] for d in batch]) # (B, T, 3, H, W) 104 | # group data 105 | text_examples = flat_list_of_lists([d["examples"] for d in batch]) 106 | n_examples_list = [d["n_examples"] for d in batch] # (B, ) 107 | # group elements data 108 | # directly concatenate question and option as a single seq. 
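        # (unlike the QA collator, there are no options here; each caption string is tokenized as-is)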
109 | text_str_list = [d["text_str"] for d in text_examples] # (B, ) 110 | batch_enc = self.tokenizer.batch_encode_plus( 111 | text_str_list, 112 | max_length=self.max_length, 113 | padding='max_length', 114 | return_tensors="pt", 115 | truncation=True 116 | ) 117 | text_input_ids = batch_enc.input_ids # (B, L) 118 | text_input_mask = batch_enc.attention_mask # (B, L) 119 | 120 | if "itm_label" in text_examples[0]: 121 | itm_labels = default_collate( 122 | [d["itm_label"] for d in text_examples]) # (B, ) 123 | else: 124 | itm_labels = None 125 | 126 | if "id" in text_examples[0]: 127 | caption_ids = [d["id"] for d in text_examples] # (B, ) 128 | else: 129 | caption_ids = None 130 | collated_batch = dict( 131 | visual_inputs=visual_inputs, # (B, #frm, H, W, C) 132 | text_input_ids=text_input_ids, 133 | text_input_mask=text_input_mask, 134 | caption_ids=caption_ids, # list(int), example ids, 135 | labels=itm_labels, 136 | n_examples_list=n_examples_list # used to create image feature copies. 137 | ) 138 | if "vid_id" in batch[0] and len(batch) == 1: 139 | collated_batch["vid_id"] = batch[0]["vid_id"] 140 | return collated_batch 141 | 142 | 143 | class AlproVideoRetrievalEvalDataset(AlproBaseDataset): 144 | """ Sample by video/image, calculate scores between each video with all the text 145 | and loop through all the videos. Each batch will only contain a single video, 146 | but multiple text. 147 | 148 | datalist: list(dict), each dict 149 | tokenizer: 150 | max_img_size: int, 151 | max_txt_len: int, max text sequence length, including special tokens. 152 | """ 153 | def __init__(self, datalist, tokenizer, img_lmdb_dir, 154 | fps=3, num_frm=3, frm_sampling_strategy="rand", 155 | max_img_size=1000, max_txt_len=40, ensemble_n_clips=1, 156 | video_fmt='.mp4', img_db_type='lmdb'): 157 | self.ensemble_n_clips = ensemble_n_clips 158 | super(AlproVideoRetrievalEvalDataset, self).__init__( 159 | datalist, tokenizer, img_lmdb_dir, 160 | fps=fps, num_frm=num_frm, 161 | frm_sampling_strategy=frm_sampling_strategy, 162 | max_img_size=max_img_size, max_txt_len=max_txt_len, 163 | img_db_type=img_db_type) 164 | # id is unique id per caption/example 165 | for i, d in enumerate(self.datalist): 166 | assert i == d["id"] 167 | self.gt_cap_id2vid_id = {d["id"]: d["vid_id"] for d in datalist} 168 | self.cap_id2data = {d["id"]: d for d in datalist} 169 | self.batches, self.text_batch = self._prepare_batches_by_video() 170 | self.id2data = {d["id"]: d for d in self.datalist} 171 | 172 | self.video_fmt = video_fmt 173 | 174 | def __len__(self): 175 | return len(self.batches) 176 | 177 | def __getitem__(self, index): 178 | # skip error videos: 179 | batch = dict() 180 | 181 | batch["vid_id"] = self.batches[index]["vid_id"] # one video with multiple examples 182 | batch["examples"] = self.text_batch["examples"] 183 | batch["n_examples"] = self.text_batch["n_examples"] 184 | batch["ids"] = self.text_batch["ids"] 185 | 186 | if self.ensemble_n_clips > 1: 187 | raise NotImplementedError('Do not support multiple clips for now.') 188 | else: 189 | # if self.is_train and self.random_sample_clips: 190 | vid_id = batch["vid_id"] 191 | 192 | video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt) 193 | vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size) 194 | 195 | batch["vid"] = vid_frm_array 196 | return batch 197 | 198 | def _prepare_batches_by_video(self): 199 | """create batches where each batch contains a single video with multiple text""" 200 | 
text_list = [] 201 | for d in self.datalist: 202 | text_list.append(dict( 203 | text_str=d["txt"], 204 | id=d["id"], 205 | )) 206 | text_batch = dict( 207 | vid_id=None, 208 | examples=text_list, 209 | n_examples=len(text_list), 210 | ids=[d["id"] for d in text_list] 211 | ) 212 | 213 | # make 1000 batches for 1000video x 1000text combinations. 214 | # each batch contains 1video x 1000text 215 | batches = [] 216 | for idx, d in enumerate(self.datalist): 217 | #_batch = copy.deepcopy(text_batch) 218 | _batch = dict() 219 | _batch["vid_id"] = d["vid_id"] 220 | batches.append(_batch) 221 | return batches, text_batch 222 | -------------------------------------------------------------------------------- /src/datasets/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | 6 | ## aug functions 7 | def identity_func(img): 8 | return img 9 | 10 | 11 | def autocontrast_func(img, cutoff=0): 12 | ''' 13 | same output as PIL.ImageOps.autocontrast 14 | ''' 15 | n_bins = 256 16 | 17 | def tune_channel(ch): 18 | n = ch.size 19 | cut = cutoff * n // 100 20 | if cut == 0: 21 | high, low = ch.max(), ch.min() 22 | else: 23 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 24 | low = np.argwhere(np.cumsum(hist) > cut) 25 | low = 0 if low.shape[0] == 0 else low[0] 26 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 27 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 28 | if high <= low: 29 | table = np.arange(n_bins) 30 | else: 31 | scale = (n_bins - 1) / (high - low) 32 | offset = -low * scale 33 | table = np.arange(n_bins) * scale + offset 34 | table[table < 0] = 0 35 | table[table > n_bins - 1] = n_bins - 1 36 | table = table.clip(0, 255).astype(np.uint8) 37 | return table[ch] 38 | 39 | channels = [tune_channel(ch) for ch in cv2.split(img)] 40 | out = cv2.merge(channels) 41 | return out 42 | 43 | 44 | def equalize_func(img): 45 | ''' 46 | same output as PIL.ImageOps.equalize 47 | PIL's implementation is different from cv2.equalize 48 | ''' 49 | n_bins = 256 50 | 51 | def tune_channel(ch): 52 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 53 | non_zero_hist = hist[hist != 0].reshape(-1) 54 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 55 | if step == 0: return ch 56 | n = np.empty_like(hist) 57 | n[0] = step // 2 58 | n[1:] = hist[:-1] 59 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 60 | return table[ch] 61 | 62 | channels = [tune_channel(ch) for ch in cv2.split(img)] 63 | out = cv2.merge(channels) 64 | return out 65 | 66 | 67 | def rotate_func(img, degree, fill=(0, 0, 0)): 68 | ''' 69 | like PIL, rotate by degree, not radians 70 | ''' 71 | H, W = img.shape[0], img.shape[1] 72 | center = W / 2, H / 2 73 | M = cv2.getRotationMatrix2D(center, degree, 1) 74 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 75 | return out 76 | 77 | 78 | def horizontal_flip_func(img): 79 | ''' 80 | [dxli] 81 | horizontally flip an image. 
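    implemented with cv2.flip(img, 1), i.e. a flip around the vertical axis.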
82 | ''' 83 | out = cv2.flip(img, 1) 84 | 85 | return out 86 | 87 | 88 | def solarize_func(img, thresh=128): 89 | ''' 90 | same output as PIL.ImageOps.posterize 91 | ''' 92 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 93 | table = table.clip(0, 255).astype(np.uint8) 94 | out = table[img] 95 | return out 96 | 97 | 98 | def color_func(img, factor): 99 | ''' 100 | same output as PIL.ImageEnhance.Color 101 | ''' 102 | ## implementation according to PIL definition, quite slow 103 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 104 | # out = blend(degenerate, img, factor) 105 | # M = ( 106 | # np.eye(3) * factor 107 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 108 | # )[np.newaxis, np.newaxis, :] 109 | M = ( 110 | np.float32([ 111 | [0.886, -0.114, -0.114], 112 | [-0.587, 0.413, -0.587], 113 | [-0.299, -0.299, 0.701]]) * factor 114 | + np.float32([[0.114], [0.587], [0.299]]) 115 | ) 116 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 117 | return out 118 | 119 | 120 | def contrast_func(img, factor): 121 | """ 122 | same output as PIL.ImageEnhance.Contrast 123 | """ 124 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 125 | table = np.array([( 126 | el - mean) * factor + mean 127 | for el in range(256) 128 | ]).clip(0, 255).astype(np.uint8) 129 | out = table[img] 130 | return out 131 | 132 | 133 | def brightness_func(img, factor): 134 | ''' 135 | same output as PIL.ImageEnhance.Contrast 136 | ''' 137 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 138 | out = table[img] 139 | return out 140 | 141 | 142 | def sharpness_func(img, factor): 143 | ''' 144 | The differences the this result and PIL are all on the 4 boundaries, the center 145 | areas are same 146 | ''' 147 | kernel = np.ones((3, 3), dtype=np.float32) 148 | kernel[1][1] = 5 149 | kernel /= 13 150 | degenerate = cv2.filter2D(img, -1, kernel) 151 | if factor == 0.0: 152 | out = degenerate 153 | elif factor == 1.0: 154 | out = img 155 | else: 156 | out = img.astype(np.float32) 157 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 158 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 159 | out = out.astype(np.uint8) 160 | return out 161 | 162 | 163 | def shear_x_func(img, factor, fill=(0, 0, 0)): 164 | H, W = img.shape[0], img.shape[1] 165 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 166 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 167 | return out 168 | 169 | 170 | def translate_x_func(img, offset, fill=(0, 0, 0)): 171 | ''' 172 | same output as PIL.Image.transform 173 | ''' 174 | H, W = img.shape[0], img.shape[1] 175 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 176 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 177 | return out 178 | 179 | 180 | def translate_y_func(img, offset, fill=(0, 0, 0)): 181 | ''' 182 | same output as PIL.Image.transform 183 | ''' 184 | H, W = img.shape[0], img.shape[1] 185 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 186 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 187 | return out 188 | 189 | 190 | def posterize_func(img, bits): 191 | ''' 192 | same output as PIL.ImageOps.posterize 193 | ''' 194 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 195 | return out 196 | 197 | 198 | def shear_y_func(img, factor, fill=(0, 0, 0)): 199 | H, W = img.shape[0], 
img.shape[1] 200 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 201 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 202 | return out 203 | 204 | 205 | # def cutout_func(img, pad_size, replace=(0, 0, 0)): 206 | # replace = np.array(replace, dtype=np.uint8) 207 | # H, W = img.shape[0], img.shape[1] 208 | # rh, rw = np.random.random(2) 209 | # pad_size = pad_size // 2 210 | # ch, cw = int(rh * H), int(rw * W) 211 | # x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 212 | # y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 213 | # out = img.copy() 214 | # out[x1:x2, y1:y2, :] = replace 215 | # return out 216 | 217 | 218 | ### level to args 219 | def enhance_level_to_args(MAX_LEVEL): 220 | def level_to_args(level): 221 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 222 | return level_to_args 223 | 224 | 225 | def shear_level_to_args(MAX_LEVEL, replace_value): 226 | def level_to_args(level): 227 | level = (level / MAX_LEVEL) * 0.3 228 | # if np.random.random() > 0.5: level = -level 229 | return (level, replace_value) 230 | 231 | return level_to_args 232 | 233 | 234 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 235 | def level_to_args(level): 236 | level = (level / MAX_LEVEL) * float(translate_const) 237 | # if np.random.random() > 0.5: level = -level 238 | return (level, replace_value) 239 | 240 | return level_to_args 241 | 242 | 243 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 244 | def level_to_args(level): 245 | level = int((level / MAX_LEVEL) * cutout_const) 246 | return (level, replace_value) 247 | 248 | return level_to_args 249 | 250 | 251 | def solarize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 256) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def none_level_to_args(level): 259 | return () 260 | 261 | 262 | def posterize_level_to_args(MAX_LEVEL): 263 | def level_to_args(level): 264 | level = int((level / MAX_LEVEL) * 4) 265 | return (level, ) 266 | return level_to_args 267 | 268 | 269 | def rotate_level_to_args(MAX_LEVEL, replace_value): 270 | def level_to_args(level): 271 | level = (level / MAX_LEVEL) * 30 272 | # if np.random.random() < 0.5: 273 | # level = -level 274 | return (level, replace_value) 275 | 276 | return level_to_args 277 | 278 | 279 | func_dict = { 280 | 'Identity': identity_func, 281 | # 'AutoContrast': autocontrast_func, 282 | 'Equalize': equalize_func, 283 | 'Rotate': rotate_func, 284 | 'Solarize': solarize_func, 285 | 'Color': color_func, 286 | 'Contrast': contrast_func, 287 | 'Brightness': brightness_func, 288 | 'Sharpness': sharpness_func, 289 | 'ShearX': shear_x_func, 290 | 'TranslateX': translate_x_func, 291 | 'TranslateY': translate_y_func, 292 | 'Posterize': posterize_func, 293 | 'ShearY': shear_y_func, 294 | 'HorizontalFlip': horizontal_flip_func # [dxli] 295 | } 296 | 297 | translate_const = 10 298 | MAX_LEVEL = 10 299 | replace_value = (128, 128, 128) 300 | arg_dict = { 301 | 'Identity': none_level_to_args, 302 | # 'AutoContrast': none_level_to_args, 303 | 'Equalize': none_level_to_args, 304 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 305 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 306 | 'Color': enhance_level_to_args(MAX_LEVEL), 307 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 308 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 309 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 310 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 311 | 
'TranslateX': translate_level_to_args( 312 | translate_const, MAX_LEVEL, replace_value 313 | ), 314 | 'TranslateY': translate_level_to_args( 315 | translate_const, MAX_LEVEL, replace_value 316 | ), 317 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 318 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 319 | 'HorizontalFlip': none_level_to_args # [dxli] 320 | } 321 | 322 | 323 | class TemporalConsistentRandomAugment(object): 324 | 325 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): 326 | self.N = N 327 | self.M = M 328 | self.p = p 329 | self.tensor_in_tensor_out = tensor_in_tensor_out 330 | if augs: 331 | self.augs = augs 332 | else: 333 | self.augs = list(arg_dict.keys()) 334 | 335 | def get_random_ops(self): 336 | sampled_ops = np.random.choice(self.augs, self.N, replace=False) 337 | # return [(op, 0.5, self.M) for op in sampled_ops] 338 | return [(op, self.M) for op in sampled_ops] 339 | 340 | def __call__(self, frames): 341 | assert frames.shape[-1] == 3, 'Expecting last dimension for 3-channels RGB (b, h, w, c).' 342 | 343 | if self.tensor_in_tensor_out: 344 | frames = frames.numpy().astype(np.uint8) 345 | 346 | num_frames = frames.shape[0] 347 | 348 | ops = num_frames * [self.get_random_ops()] 349 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] 350 | 351 | frames = torch.stack(list(map(self._aug, frames, ops, apply_or_not)), dim=0).float() 352 | 353 | return frames 354 | 355 | def _aug(self, img, ops, apply_or_not): 356 | for i, (name, level) in enumerate(ops): 357 | if not apply_or_not[i]: 358 | continue 359 | args = arg_dict[name](level) 360 | img = func_dict[name](img, *args) 361 | return torch.from_numpy(img) 362 | 363 | class RandomAugment(object): 364 | 365 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 366 | self.N = N 367 | self.M = M 368 | self.isPIL = isPIL 369 | if augs: 370 | self.augs = augs 371 | else: 372 | self.augs = list(arg_dict.keys()) 373 | 374 | def get_random_ops(self): 375 | sampled_ops = np.random.choice(self.augs, self.N) 376 | return [(op, 0.5, self.M) for op in sampled_ops] 377 | 378 | def __call__(self, img): 379 | if self.isPIL: 380 | img = np.array(img) 381 | ops = self.get_random_ops() 382 | for name, prob, level in ops: 383 | if np.random.random() > prob: 384 | continue 385 | args = arg_dict[name](level) 386 | img = func_dict[name](img, *args) 387 | return img 388 | 389 | 390 | def save_frames_grid(img_array, out_path): 391 | import torch 392 | from torchvision.utils import make_grid 393 | from PIL import Image 394 | 395 | if len(img_array.shape) == 3: 396 | img_array = img_array.unsqueeze(0) 397 | elif len(img_array.shape) == 5: 398 | b, t, c, h, w = img_array.shape 399 | img_array = img_array.view(-1, c, h, w) 400 | elif len(img_array.shape) == 4: 401 | pass 402 | else: 403 | raise NotImplementedError('Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored.') 404 | 405 | assert img_array.shape[1] == 3, "Exepcting input shape of (H, W, 3), i.e. RGB-only." 
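    # Usage sketch (mirrors the __main__ demo further below; the output path is illustrative):
    #   save_frames_grid(frames.permute(0, 3, 1, 2), 'preview.jpg')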
406 | 407 | grid = make_grid(img_array) 408 | ndarr = grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy() 409 | 410 | img = Image.fromarray(ndarr) 411 | 412 | img.save(out_path) 413 | 414 | 415 | def stack(data, dim=0): 416 | shape = data[0].shape # need to handle empty list 417 | shape = shape[:dim] + (len(data),) + shape[dim:] 418 | x = torch.cat(data, dim=dim) 419 | x = x.reshape(shape) 420 | # need to handle case where dim=-1 421 | # which is not handled here yet 422 | # but can be done with transposition 423 | return x 424 | 425 | 426 | if __name__ == '__main__': 427 | import decord, os 428 | from decord import VideoReader 429 | decord.bridge.set_bridge('torch') 430 | 431 | root_dir = '/export/share/dongxuli/data/webvid2m/postprocess/downsampled_videos' 432 | video_id = '1058234725.mp4' 433 | 434 | video_path = os.path.join(root_dir, video_id) 435 | vr = VideoReader(video_path) 436 | 437 | frames = vr.get_batch([1, 3, 5, 7, 9]) 438 | frames = frames 439 | 440 | # a = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast', 'Equalize','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']) 441 | a = TemporalConsistentRandomAugment(N=1, M=5, augs=['HorizontalFlip']) 442 | 443 | print(frames[0].shape) 444 | save_frames_grid(frames.permute(0, 3, 1, 2), 'before.jpg') 445 | 446 | after_frames = a(frames) 447 | print(after_frames.shape) 448 | 449 | save_frames_grid(after_frames.permute(0, 3, 1, 2), 'after.jpg') -------------------------------------------------------------------------------- /src/modeling/timesformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # from .build import MODEL_REGISTRY, build_model # noqa 4 | # from .custom_video_model_builder import * # noqa 5 | # from .video_model_builder import ResNet, SlowFast # noqa 6 | -------------------------------------------------------------------------------- /src/modeling/timesformer/conv2d_same.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Ross Wightman 2 | # Conv2d w/ Same Padding 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from typing import Tuple, Optional 8 | 9 | import math 10 | from typing import List, Tuple 11 | #from .padding import pad_same, get_padding_value 12 | 13 | # Dynamically pad input x with 'SAME' padding for conv with specified args 14 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 15 | ih, iw = x.size()[-2:] 16 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 17 | if pad_h > 0 or pad_w > 0: 18 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 19 | return x 20 | 21 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 22 | def get_same_padding(x: int, k: int, s: int, d: int): 23 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 24 | 25 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 26 | dynamic = False 27 | if isinstance(padding, str): 28 | # for any string padding, the padding will be calculated for you, one of three ways 29 | padding = padding.lower() 30 | if padding == 'same': 31 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 32 | if is_static_pad(kernel_size, **kwargs): 33 | # static case, no extra 
overhead 34 | padding = get_padding(kernel_size, **kwargs) 35 | else: 36 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 37 | padding = 0 38 | dynamic = True 39 | elif padding == 'valid': 40 | # 'VALID' padding, same as padding=0 41 | padding = 0 42 | else: 43 | # Default to PyTorch style 'same'-ish symmetric padding 44 | padding = get_padding(kernel_size, **kwargs) 45 | return padding, dynamic 46 | 47 | def conv2d_same( 48 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 49 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 50 | x = pad_same(x, weight.shape[-2:], stride, dilation) 51 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 52 | 53 | 54 | class Conv2dSame(nn.Conv2d): 55 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 56 | """ 57 | 58 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 59 | padding=0, dilation=1, groups=1, bias=True): 60 | super(Conv2dSame, self).__init__( 61 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 62 | 63 | def forward(self, x): 64 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 65 | 66 | 67 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 68 | padding = kwargs.pop('padding', '') 69 | kwargs.setdefault('bias', False) 70 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 71 | if is_dynamic: 72 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 73 | else: 74 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 75 | -------------------------------------------------------------------------------- /src/modeling/timesformer/features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Ross Wightman 2 | 3 | from collections import OrderedDict, defaultdict 4 | from copy import deepcopy 5 | from functools import partial 6 | from typing import Dict, List, Tuple 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class FeatureInfo: 13 | 14 | def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): 15 | prev_reduction = 1 16 | for fi in feature_info: 17 | # sanity check the mandatory fields, there may be additional fields depending on the model 18 | assert 'num_chs' in fi and fi['num_chs'] > 0 19 | assert 'reduction' in fi and fi['reduction'] >= prev_reduction 20 | prev_reduction = fi['reduction'] 21 | assert 'module' in fi 22 | self.out_indices = out_indices 23 | self.info = feature_info 24 | 25 | def from_other(self, out_indices: Tuple[int]): 26 | return FeatureInfo(deepcopy(self.info), out_indices) 27 | 28 | def get(self, key, idx=None): 29 | """ Get value by key at specified index (indices) 30 | if idx == None, returns value for key at each output index 31 | if idx is an integer, return value for that feature module index (ignoring output indices) 32 | if idx is a list/tupple, return value for each module index (ignoring output indices) 33 | """ 34 | if idx is None: 35 | return [self.info[i][key] for i in self.out_indices] 36 | if isinstance(idx, (tuple, list)): 37 | return [self.info[i][key] for i in idx] 38 | else: 39 | return self.info[idx][key] 40 | 41 | def get_dicts(self, keys=None, idx=None): 42 | """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) 43 | """ 44 | if idx is None: 45 | if keys is None: 46 | return 
[self.info[i] for i in self.out_indices] 47 | else: 48 | return [{k: self.info[i][k] for k in keys} for i in self.out_indices] 49 | if isinstance(idx, (tuple, list)): 50 | return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] 51 | else: 52 | return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} 53 | 54 | def channels(self, idx=None): 55 | """ feature channels accessor 56 | """ 57 | return self.get('num_chs', idx) 58 | 59 | def reduction(self, idx=None): 60 | """ feature reduction (output stride) accessor 61 | """ 62 | return self.get('reduction', idx) 63 | 64 | def module_name(self, idx=None): 65 | """ feature module name accessor 66 | """ 67 | return self.get('module', idx) 68 | 69 | def __getitem__(self, item): 70 | return self.info[item] 71 | 72 | def __len__(self): 73 | return len(self.info) 74 | 75 | 76 | class FeatureHooks: 77 | """ Feature Hook Helper 78 | This module helps with the setup and extraction of hooks for extracting features from 79 | internal nodes in a model by node name. This works quite well in eager Python but needs 80 | redesign for torcscript. 81 | """ 82 | 83 | def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): 84 | # setup feature hooks 85 | modules = {k: v for k, v in named_modules} 86 | for i, h in enumerate(hooks): 87 | hook_name = h['module'] 88 | m = modules[hook_name] 89 | hook_id = out_map[i] if out_map else hook_name 90 | hook_fn = partial(self._collect_output_hook, hook_id) 91 | hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type 92 | if hook_type == 'forward_pre': 93 | m.register_forward_pre_hook(hook_fn) 94 | elif hook_type == 'forward': 95 | m.register_forward_hook(hook_fn) 96 | else: 97 | assert False, "Unsupported hook type" 98 | self._feature_outputs = defaultdict(OrderedDict) 99 | 100 | def _collect_output_hook(self, hook_id, *args): 101 | x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre 102 | if isinstance(x, tuple): 103 | x = x[0] # unwrap input tuple 104 | self._feature_outputs[x.device][hook_id] = x 105 | 106 | def get_output(self, device) -> Dict[str, torch.tensor]: 107 | output = self._feature_outputs[device] 108 | self._feature_outputs[device] = OrderedDict() # clear after reading 109 | return output 110 | 111 | 112 | def _module_list(module, flatten_sequential=False): 113 | # a yield/iter would be better for this but wouldn't be compatible with torchscript 114 | ml = [] 115 | for name, module in module.named_children(): 116 | if flatten_sequential and isinstance(module, nn.Sequential): 117 | # first level of Sequential containers is flattened into containing model 118 | for child_name, child_module in module.named_children(): 119 | combined = [name, child_name] 120 | ml.append(('_'.join(combined), '.'.join(combined), child_module)) 121 | else: 122 | ml.append((name, name, module)) 123 | return ml 124 | 125 | 126 | def _get_feature_info(net, out_indices): 127 | feature_info = getattr(net, 'feature_info') 128 | if isinstance(feature_info, FeatureInfo): 129 | return feature_info.from_other(out_indices) 130 | elif isinstance(feature_info, (list, tuple)): 131 | return FeatureInfo(net.feature_info, out_indices) 132 | else: 133 | assert False, "Provided feature_info is not valid" 134 | 135 | 136 | def _get_return_layers(feature_info, out_map): 137 | module_names = feature_info.module_name() 138 | return_layers = {} 139 | for i, name in enumerate(module_names): 140 | return_layers[name] = out_map[i] if 
out_map is not None else feature_info.out_indices[i] 141 | return return_layers 142 | 143 | 144 | class FeatureDictNet(nn.ModuleDict): 145 | """ Feature extractor with OrderedDict return 146 | Wrap a model and extract features as specified by the out indices, the network is 147 | partially re-built from contained modules. 148 | There is a strong assumption that the modules have been registered into the model in the same 149 | order as they are used. There should be no reuse of the same nn.Module more than once, including 150 | trivial modules like `self.relu = nn.ReLU`. 151 | Only submodules that are directly assigned to the model class (`model.feature1`) or at most 152 | one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. 153 | All Sequential containers that are directly assigned to the original model will have their 154 | modules assigned to this module with the name `model.features.1` being changed to `model.features_1` 155 | Arguments: 156 | model (nn.Module): model from which we will extract the features 157 | out_indices (tuple[int]): model output indices to extract features for 158 | out_map (sequence): list or tuple specifying desired return id for each out index, 159 | otherwise str(index) is used 160 | feature_concat (bool): whether to concatenate intermediate features that are lists or tuples 161 | vs select element [0] 162 | flatten_sequential (bool): whether to flatten sequential modules assigned to model 163 | """ 164 | def __init__( 165 | self, model, 166 | out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): 167 | super(FeatureDictNet, self).__init__() 168 | self.feature_info = _get_feature_info(model, out_indices) 169 | self.concat = feature_concat 170 | self.return_layers = {} 171 | return_layers = _get_return_layers(self.feature_info, out_map) 172 | modules = _module_list(model, flatten_sequential=flatten_sequential) 173 | remaining = set(return_layers.keys()) 174 | layers = OrderedDict() 175 | for new_name, old_name, module in modules: 176 | layers[new_name] = module 177 | if old_name in remaining: 178 | # return id has to be consistently str type for torchscript 179 | self.return_layers[new_name] = str(return_layers[old_name]) 180 | remaining.remove(old_name) 181 | if not remaining: 182 | break 183 | assert not remaining and len(self.return_layers) == len(return_layers), \ 184 | f'Return layers ({remaining}) are not present in model' 185 | self.update(layers) 186 | 187 | def _collect(self, x) -> (Dict[str, torch.Tensor]): 188 | out = OrderedDict() 189 | for name, module in self.items(): 190 | x = module(x) 191 | if name in self.return_layers: 192 | out_id = self.return_layers[name] 193 | if isinstance(x, (tuple, list)): 194 | # If model tap is a tuple or list, concat or select first element 195 | # FIXME this may need to be more generic / flexible for some nets 196 | out[out_id] = torch.cat(x, 1) if self.concat else x[0] 197 | else: 198 | out[out_id] = x 199 | return out 200 | 201 | def forward(self, x) -> Dict[str, torch.Tensor]: 202 | return self._collect(x) 203 | 204 | 205 | class FeatureListNet(FeatureDictNet): 206 | """ Feature extractor with list return 207 | See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. 208 | In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. 
209 | """ 210 | def __init__( 211 | self, model, 212 | out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): 213 | super(FeatureListNet, self).__init__( 214 | model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, 215 | flatten_sequential=flatten_sequential) 216 | 217 | def forward(self, x) -> (List[torch.Tensor]): 218 | return list(self._collect(x).values()) 219 | 220 | 221 | class FeatureHookNet(nn.ModuleDict): 222 | """ FeatureHookNet 223 | Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. 224 | If `no_rewrite` is True, features are extracted via hooks without modifying the underlying 225 | network in any way. 226 | If `no_rewrite` is False, the model will be re-written as in the 227 | FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. 228 | FIXME this does not currently work with Torchscript, see FeatureHooks class 229 | """ 230 | def __init__( 231 | self, model, 232 | out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, 233 | feature_concat=False, flatten_sequential=False, default_hook_type='forward'): 234 | super(FeatureHookNet, self).__init__() 235 | assert not torch.jit.is_scripting() 236 | self.feature_info = _get_feature_info(model, out_indices) 237 | self.out_as_dict = out_as_dict 238 | layers = OrderedDict() 239 | hooks = [] 240 | if no_rewrite: 241 | assert not flatten_sequential 242 | if hasattr(model, 'reset_classifier'): # make sure classifier is removed? 243 | model.reset_classifier(0) 244 | layers['body'] = model 245 | hooks.extend(self.feature_info.get_dicts()) 246 | else: 247 | modules = _module_list(model, flatten_sequential=flatten_sequential) 248 | remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type 249 | for f in self.feature_info.get_dicts()} 250 | for new_name, old_name, module in modules: 251 | layers[new_name] = module 252 | for fn, fm in module.named_modules(prefix=old_name): 253 | if fn in remaining: 254 | hooks.append(dict(module=fn, hook_type=remaining[fn])) 255 | del remaining[fn] 256 | if not remaining: 257 | break 258 | assert not remaining, f'Return layers ({remaining}) are not present in model' 259 | self.update(layers) 260 | self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) 261 | 262 | def forward(self, x): 263 | for name, module in self.items(): 264 | x = module(x) 265 | out = self.hooks.get_output(x.device) 266 | return out if self.out_as_dict else list(out.values()) 267 | -------------------------------------------------------------------------------- /src/modeling/timesformer/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | # Copyright 2020 Ross Wightman 3 | # Modified model creation / weight loading / state_dict helpers 4 | 5 | import logging 6 | import os 7 | import sys 8 | import math 9 | from collections import OrderedDict 10 | from copy import deepcopy 11 | from typing import Callable 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.utils.model_zoo as model_zoo 16 | import torch.nn.functional as F 17 | 18 | from src.modeling.timesformer.features import FeatureListNet, FeatureDictNet, FeatureHookNet 19 | from src.modeling.timesformer.conv2d_same import Conv2dSame 20 | from src.modeling.timesformer.linear import Linear 21 | 22 | from horovod import torch as hvd 23 | 24 | _logger = logging.getLogger() 25 | 26 | def load_state_dict(checkpoint_path, use_ema=False): 27 | if checkpoint_path and os.path.isfile(checkpoint_path): 28 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 29 | state_dict_key = 'state_dict' 30 | if isinstance(checkpoint, dict): 31 | if use_ema and 'state_dict_ema' in checkpoint: 32 | state_dict_key = 'state_dict_ema' 33 | if state_dict_key and state_dict_key in checkpoint: 34 | new_state_dict = OrderedDict() 35 | for k, v in checkpoint[state_dict_key].items(): 36 | # strip `module.` prefix 37 | name = k[7:] if k.startswith('module') else k 38 | new_state_dict[name] = v 39 | state_dict = new_state_dict 40 | elif 'model_state' in checkpoint: 41 | state_dict_key = 'model_state' 42 | new_state_dict = OrderedDict() 43 | for k, v in checkpoint[state_dict_key].items(): 44 | # strip `model.` prefix 45 | name = k[6:] if k.startswith('model') else k 46 | new_state_dict[name] = v 47 | state_dict = new_state_dict 48 | else: 49 | state_dict = checkpoint 50 | _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) 51 | return state_dict 52 | else: 53 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 54 | raise FileNotFoundError() 55 | 56 | 57 | def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): 58 | state_dict = load_state_dict(checkpoint_path, use_ema) 59 | model.load_state_dict(state_dict, strict=strict) 60 | 61 | 62 | # def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): 63 | # resume_epoch = None 64 | # if os.path.isfile(checkpoint_path): 65 | # checkpoint = torch.load(checkpoint_path, map_location='cpu') 66 | # if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 67 | # if log_info: 68 | # _logger.info('Restoring model state from checkpoint...') 69 | # new_state_dict = OrderedDict() 70 | # for k, v in checkpoint['state_dict'].items(): 71 | # name = k[7:] if k.startswith('module') else k 72 | # new_state_dict[name] = v 73 | # model.load_state_dict(new_state_dict) 74 | 75 | # if optimizer is not None and 'optimizer' in checkpoint: 76 | # if log_info: 77 | # _logger.info('Restoring optimizer state from checkpoint...') 78 | # optimizer.load_state_dict(checkpoint['optimizer']) 79 | 80 | # if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: 81 | # if log_info: 82 | # _logger.info('Restoring AMP loss scaler state from checkpoint...') 83 | # loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) 84 | 85 | # if 'epoch' in checkpoint: 86 | # resume_epoch = checkpoint['epoch'] 87 | # if 'version' in checkpoint and checkpoint['version'] > 1: 88 | # resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save 89 | 90 | # if log_info: 91 | # _logger.info("Loaded checkpoint '{}' (epoch 
{})".format(checkpoint_path, checkpoint['epoch'])) 92 | # else: 93 | # model.load_state_dict(checkpoint) 94 | # if log_info: 95 | # _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) 96 | # return resume_epoch 97 | # else: 98 | # _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 99 | # raise FileNotFoundError() 100 | 101 | 102 | def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_frames=8, num_patches=196, attention_type='divided_space_time', pretrained_model="", strict=True): 103 | if cfg is None: 104 | cfg = getattr(model, 'default_cfg') 105 | if cfg is None or 'url' not in cfg or not cfg['url']: 106 | _logger.warning("Pretrained model URL is invalid, using random initialization.") 107 | return 108 | 109 | if len(pretrained_model) == 0: 110 | if cfg is None: 111 | _logger.info(f"loading from default config {model.default_cfg}.") 112 | state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') 113 | else: 114 | try: 115 | state_dict = load_state_dict(pretrained_model)['model'] 116 | except: 117 | state_dict = load_state_dict(pretrained_model) 118 | 119 | 120 | if filter_fn is not None: 121 | state_dict = filter_fn(state_dict) 122 | 123 | if in_chans == 1: 124 | conv1_name = cfg['first_conv'] 125 | _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) 126 | conv1_weight = state_dict[conv1_name + '.weight'] 127 | conv1_type = conv1_weight.dtype 128 | conv1_weight = conv1_weight.float() 129 | O, I, J, K = conv1_weight.shape 130 | if I > 3: 131 | assert conv1_weight.shape[1] % 3 == 0 132 | # For models with space2depth stems 133 | conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) 134 | conv1_weight = conv1_weight.sum(dim=2, keepdim=False) 135 | else: 136 | conv1_weight = conv1_weight.sum(dim=1, keepdim=True) 137 | conv1_weight = conv1_weight.to(conv1_type) 138 | state_dict[conv1_name + '.weight'] = conv1_weight 139 | elif in_chans != 3: 140 | conv1_name = cfg['first_conv'] 141 | conv1_weight = state_dict[conv1_name + '.weight'] 142 | conv1_type = conv1_weight.dtype 143 | conv1_weight = conv1_weight.float() 144 | O, I, J, K = conv1_weight.shape 145 | if I != 3: 146 | _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name) 147 | del state_dict[conv1_name + '.weight'] 148 | strict = False 149 | else: 150 | _logger.info('Repeating first conv (%s) weights in channel dim.' 
% conv1_name) 151 | repeat = int(math.ceil(in_chans / 3)) 152 | conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] 153 | conv1_weight *= (3 / float(in_chans)) 154 | conv1_weight = conv1_weight.to(conv1_type) 155 | state_dict[conv1_name + '.weight'] = conv1_weight 156 | 157 | 158 | classifier_name = cfg['classifier'] 159 | if num_classes == 1000 and cfg['num_classes'] == 1001: 160 | # special case for imagenet trained models with extra background class in pretrained weights 161 | classifier_weight = state_dict[classifier_name + '.weight'] 162 | state_dict[classifier_name + '.weight'] = classifier_weight[1:] 163 | classifier_bias = state_dict[classifier_name + '.bias'] 164 | state_dict[classifier_name + '.bias'] = classifier_bias[1:] 165 | elif num_classes != state_dict[classifier_name + '.weight'].size(0): 166 | #print('Removing the last fully connected layer due to dimensions mismatch ('+str(num_classes)+ ' != '+str(state_dict[classifier_name + '.weight'].size(0))+').', flush=True) 167 | # completely discard fully connected for all other differences between pretrained and created model 168 | del state_dict[classifier_name + '.weight'] 169 | del state_dict[classifier_name + '.bias'] 170 | strict = False 171 | 172 | 173 | ## Resizing the positional embeddings in case they don't match 174 | _logger.info(f"Resizing spatial position embedding from {state_dict['pos_embed'].size(1)} to {num_patches + 1}") 175 | if num_patches + 1 != state_dict['pos_embed'].size(1): 176 | pos_embed = state_dict['pos_embed'] 177 | cls_pos_embed = pos_embed[0,0,:].unsqueeze(0).unsqueeze(1) 178 | other_pos_embed = pos_embed[0,1:,:].unsqueeze(0).transpose(1, 2) 179 | new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest') 180 | new_pos_embed = new_pos_embed.transpose(1, 2) 181 | new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) 182 | state_dict['pos_embed'] = new_pos_embed 183 | 184 | ## Resizing time embeddings in case they don't match 185 | if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1): 186 | _logger.info(f"Resizing temporal position embedding from {state_dict['time_embed'].size(1)} to {num_frames}") 187 | time_embed = state_dict['time_embed'].transpose(1, 2) 188 | new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest') 189 | state_dict['time_embed'] = new_time_embed.transpose(1, 2) 190 | 191 | ## Initializing temporal attention 192 | if attention_type == 'divided_space_time': 193 | new_state_dict = state_dict.copy() 194 | for key in state_dict: 195 | if 'blocks' in key and 'attn' in key: 196 | new_key = key.replace('attn','temporal_attn') 197 | if not new_key in state_dict: 198 | new_state_dict[new_key] = state_dict[key] 199 | else: 200 | new_state_dict[new_key] = state_dict[new_key] 201 | if 'blocks' in key and 'norm1' in key: 202 | new_key = key.replace('norm1','temporal_norm1') 203 | if not new_key in state_dict: 204 | new_state_dict[new_key] = state_dict[key] 205 | else: 206 | new_state_dict[new_key] = state_dict[new_key] 207 | state_dict = new_state_dict 208 | 209 | ## Loading the weights 210 | model.load_state_dict(state_dict, strict=False) 211 | 212 | 213 | def load_pretrained_CLIP_ViT(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs): 214 | if hvd.rank() == 0: 215 | _logger.info(f"Loading CLIP ViT-B/16 checkpoints.") 216 | loaded_state_dict = torch.load(pretrained_model) 217 | 218 | ## Initializing temporal attention 219 | new_state_dict = 
loaded_state_dict.copy() 220 | for key in loaded_state_dict: 221 | if 'blocks' in key and 'attn' in key: 222 | new_key = key.replace('attn','temporal_attn') 223 | if not new_key in loaded_state_dict: 224 | new_state_dict[new_key] = loaded_state_dict[key] 225 | else: 226 | new_state_dict[new_key] = loaded_state_dict[new_key] 227 | if 'blocks' in key and 'norm1' in key: 228 | new_key = key.replace('norm1','temporal_norm1') 229 | if not new_key in loaded_state_dict: 230 | new_state_dict[new_key] = loaded_state_dict[key] 231 | else: 232 | new_state_dict[new_key] = loaded_state_dict[new_key] 233 | 234 | loaded_state_dict = new_state_dict 235 | 236 | loaded_keys = loaded_state_dict.keys() 237 | model_keys = model.state_dict().keys() 238 | 239 | load_not_in_model = [k for k in loaded_keys if k not in model_keys] 240 | model_not_in_load = [k for k in model_keys if k not in loaded_keys] 241 | 242 | toload = dict() 243 | mismatched_shape_keys = [] 244 | for k in model_keys: 245 | if k in loaded_keys: 246 | if model.state_dict()[k].shape != loaded_state_dict[k].shape: 247 | mismatched_shape_keys.append(k) 248 | else: 249 | toload[k] = loaded_state_dict[k] 250 | 251 | if hvd.rank() == 0: 252 | _logger.info("Keys in loaded but not in model:") 253 | _logger.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}") 254 | _logger.info("Keys in model but not in loaded:") 255 | _logger.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}") 256 | _logger.info("Keys in model and loaded, but shape mismatched:") 257 | _logger.info(f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}") 258 | 259 | model.load_state_dict(toload, strict=False) 260 | 261 | 262 | def load_pretrained_imagenet(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs): 263 | import timm 264 | 265 | if hvd.rank() == 0: 266 | _logger.info(f"Loading vit_base_patch16_224 checkpoints.") 267 | loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True).state_dict() 268 | 269 | del loaded_state_dict['head.weight'] 270 | del loaded_state_dict['head.bias'] 271 | 272 | ## Initializing temporal attention 273 | new_state_dict = loaded_state_dict.copy() 274 | for key in loaded_state_dict: 275 | if 'blocks' in key and 'attn' in key: 276 | new_key = key.replace('attn','temporal_attn') 277 | if not new_key in loaded_state_dict: 278 | new_state_dict[new_key] = loaded_state_dict[key] 279 | else: 280 | new_state_dict[new_key] = loaded_state_dict[new_key] 281 | if 'blocks' in key and 'norm1' in key: 282 | new_key = key.replace('norm1','temporal_norm1') 283 | if not new_key in loaded_state_dict: 284 | new_state_dict[new_key] = loaded_state_dict[key] 285 | else: 286 | new_state_dict[new_key] = loaded_state_dict[new_key] 287 | 288 | loaded_state_dict = new_state_dict 289 | 290 | loaded_keys = loaded_state_dict.keys() 291 | model_keys = model.state_dict().keys() 292 | 293 | load_not_in_model = [k for k in loaded_keys if k not in model_keys] 294 | model_not_in_load = [k for k in model_keys if k not in loaded_keys] 295 | 296 | toload = dict() 297 | mismatched_shape_keys = [] 298 | for k in model_keys: 299 | if k in loaded_keys: 300 | if model.state_dict()[k].shape != loaded_state_dict[k].shape: 301 | mismatched_shape_keys.append(k) 302 | else: 303 | toload[k] = loaded_state_dict[k] 304 | 305 | if hvd.rank() == 0: 306 | _logger.info("Keys in loaded but not in model:") 307 | _logger.info(f"In total {len(load_not_in_model)}, 
{sorted(load_not_in_model)}") 308 | _logger.info("Keys in model but not in loaded:") 309 | _logger.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}") 310 | _logger.info("Keys in model and loaded, but shape mismatched:") 311 | _logger.info(f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}") 312 | 313 | model.load_state_dict(toload, strict=False) 314 | 315 | def load_pretrained_kinetics(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs): 316 | if cfg is None: 317 | cfg = getattr(model, 'default_cfg') 318 | if cfg is None or 'url' not in cfg or not cfg['url']: 319 | _logger.warning("Pretrained model URL is invalid, using random initialization.") 320 | return 321 | 322 | assert len(pretrained_model) > 0, "Path to pre-trained Kinetics weights not provided." 323 | 324 | state_dict = load_state_dict(pretrained_model) 325 | 326 | classifier_name = cfg['classifier'] 327 | if ignore_classifier: 328 | 329 | classifier_weight_key = classifier_name + '.weight' 330 | classifier_bias_key = classifier_name + '.bias' 331 | 332 | state_dict[classifier_weight_key] = model.state_dict()[classifier_weight_key] 333 | state_dict[classifier_bias_key] = model.state_dict()[classifier_bias_key] 334 | 335 | else: 336 | raise NotImplementedError('[dxli] Not supporting loading Kinetics-pretrained ckpt with classifier.') 337 | 338 | ## Resizing the positional embeddings in case they don't match 339 | if num_patches + 1 != state_dict['pos_embed'].size(1): 340 | new_pos_embed = resize_spatial_embedding(state_dict, 'pos_embed', num_patches) 341 | state_dict['pos_embed'] = new_pos_embed 342 | 343 | ## Resizing time embeddings in case they don't match 344 | if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1): 345 | state_dict['time_embed'] = resize_temporal_embedding(state_dict, 'time_embed', num_frames) 346 | 347 | ## Loading the weights 348 | try: 349 | model.load_state_dict(state_dict, strict=True) 350 | _logger.info('Succeeded in loading Kinetics pre-trained weights.') 351 | except: 352 | _logger.error('Error in loading Kinetics pre-trained weights.') 353 | 354 | 355 | def resize_spatial_embedding(state_dict, key, num_patches): 356 | _logger.info(f"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}") 357 | 358 | pos_embed = state_dict[key] 359 | 360 | cls_pos_embed = pos_embed[0,0,:].unsqueeze(0).unsqueeze(1) 361 | other_pos_embed = pos_embed[0,1:,:].unsqueeze(0).transpose(1, 2) 362 | 363 | new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest') 364 | new_pos_embed = new_pos_embed.transpose(1, 2) 365 | new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) 366 | 367 | return new_pos_embed 368 | 369 | 370 | def resize_temporal_embedding(state_dict, key, num_frames): 371 | _logger.info(f"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}") 372 | 373 | time_embed = state_dict[key].transpose(1, 2) 374 | new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest') 375 | 376 | return new_time_embed.transpose(1, 2) -------------------------------------------------------------------------------- /src/modeling/timesformer/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | class Linear(nn.Linear): 8 | def 
forward(self, input: torch.Tensor) -> torch.Tensor: 9 | if torch.jit.is_scripting(): 10 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 11 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 12 | else: 13 | return F.linear(input, self.weight, self.bias) 14 | -------------------------------------------------------------------------------- /src/modeling/timesformer/operators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | # """Custom operators.""" 4 | 5 | # import torch 6 | # import torch.nn as nn 7 | 8 | 9 | # class Swish(nn.Module): 10 | # """Swish activation function: x * sigmoid(x).""" 11 | 12 | # def __init__(self): 13 | # super(Swish, self).__init__() 14 | 15 | # def forward(self, x): 16 | # return SwishEfficient.apply(x) 17 | 18 | 19 | # class SwishEfficient(torch.autograd.Function): 20 | # """Swish activation function: x * sigmoid(x).""" 21 | 22 | # @staticmethod 23 | # def forward(ctx, x): 24 | # result = x * torch.sigmoid(x) 25 | # ctx.save_for_backward(x) 26 | # return result 27 | 28 | # @staticmethod 29 | # def backward(ctx, grad_output): 30 | # x = ctx.saved_variables[0] 31 | # sigmoid_x = torch.sigmoid(x) 32 | # return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x))) 33 | 34 | 35 | # class SE(nn.Module): 36 | # """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid.""" 37 | 38 | # def _round_width(self, width, multiplier, min_width=8, divisor=8): 39 | # """ 40 | # Round width of filters based on width multiplier 41 | # Args: 42 | # width (int): the channel dimensions of the input. 43 | # multiplier (float): the multiplication factor. 44 | # min_width (int): the minimum width after multiplication. 45 | # divisor (int): the new width should be dividable by divisor. 46 | # """ 47 | # if not multiplier: 48 | # return width 49 | 50 | # width *= multiplier 51 | # min_width = min_width or divisor 52 | # width_out = max( 53 | # min_width, int(width + divisor / 2) // divisor * divisor 54 | # ) 55 | # if width_out < 0.9 * width: 56 | # width_out += divisor 57 | # return int(width_out) 58 | 59 | # def __init__(self, dim_in, ratio, relu_act=True): 60 | # """ 61 | # Args: 62 | # dim_in (int): the channel dimensions of the input. 63 | # ratio (float): the channel reduction ratio for squeeze. 64 | # relu_act (bool): whether to use ReLU activation instead 65 | # of Swish (default). 66 | # divisor (int): the new width should be dividable by divisor. 
67 | # """ 68 | # super(SE, self).__init__() 69 | # self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 70 | # dim_fc = self._round_width(dim_in, ratio) 71 | # self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True) 72 | # self.fc1_act = nn.ReLU() if relu_act else Swish() 73 | # self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True) 74 | 75 | # self.fc2_sig = nn.Sigmoid() 76 | 77 | # def forward(self, x): 78 | # x_in = x 79 | # for module in self.children(): 80 | # x = module(x) 81 | # return x_in * x 82 | -------------------------------------------------------------------------------- /src/modeling/timesformer/vit_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Ross Wightman 2 | # Various utility functions 3 | 4 | import torch 5 | import torch.nn as nn 6 | from functools import partial 7 | import math 8 | import warnings 9 | import torch.nn.functional as F 10 | 11 | from src.modeling.timesformer.helpers import load_pretrained 12 | from itertools import repeat 13 | import collections.abc as container_abcs 14 | 15 | DEFAULT_CROP_PCT = 0.875 16 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 17 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 18 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 19 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 20 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 21 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 22 | 23 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 24 | def norm_cdf(x): 25 | # Computes standard normal cumulative distribution function 26 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 27 | 28 | if (mean < a - 2 * std) or (mean > b + 2 * std): 29 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 30 | "The distribution of values may be incorrect.", 31 | stacklevel=2) 32 | 33 | with torch.no_grad(): 34 | # Values are generated by using a truncated uniform distribution and 35 | # then using the inverse CDF for the normal distribution. 36 | # Get upper and lower cdf values 37 | l = norm_cdf((a - mean) / std) 38 | u = norm_cdf((b - mean) / std) 39 | 40 | # Uniformly fill tensor with values from [l, u], then translate to 41 | # [2l-1, 2u-1]. 42 | tensor.uniform_(2 * l - 1, 2 * u - 1) 43 | 44 | # Use inverse cdf transform for normal distribution to get truncated 45 | # standard normal 46 | tensor.erfinv_() 47 | 48 | # Transform to proper mean, std 49 | tensor.mul_(std * math.sqrt(2.)) 50 | tensor.add_(mean) 51 | 52 | # Clamp to ensure it's in the proper range 53 | tensor.clamp_(min=a, max=b) 54 | return tensor 55 | 56 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 57 | # type: (Tensor, float, float, float, float) -> Tensor 58 | r"""Fills the input Tensor with values drawn from a truncated 59 | normal distribution. The values are effectively drawn from the 60 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 61 | with values outside :math:`[a, b]` redrawn until they are within 62 | the bounds. The method used for generating the random values works 63 | best when :math:`a \leq \text{mean} \leq b`. 
64 | Args: 65 | tensor: an n-dimensional `torch.Tensor` 66 | mean: the mean of the normal distribution 67 | std: the standard deviation of the normal distribution 68 | a: the minimum cutoff value 69 | b: the maximum cutoff value 70 | Examples: 71 | >>> w = torch.empty(3, 5) 72 | >>> nn.init.trunc_normal_(w) 73 | """ 74 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 75 | 76 | # From PyTorch internals 77 | def _ntuple(n): 78 | def parse(x): 79 | if isinstance(x, container_abcs.Iterable): 80 | return x 81 | return tuple(repeat(x, n)) 82 | return parse 83 | to_2tuple = _ntuple(2) 84 | 85 | # Calculate symmetric padding for a convolution 86 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 87 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 88 | return padding 89 | 90 | def get_padding_value(padding, kernel_size, **kwargs): 91 | dynamic = False 92 | if isinstance(padding, str): 93 | # for any string padding, the padding will be calculated for you, one of three ways 94 | padding = padding.lower() 95 | if padding == 'same': 96 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 97 | if is_static_pad(kernel_size, **kwargs): 98 | # static case, no extra overhead 99 | padding = get_padding(kernel_size, **kwargs) 100 | else: 101 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 102 | padding = 0 103 | dynamic = True 104 | elif padding == 'valid': 105 | # 'VALID' padding, same as padding=0 106 | padding = 0 107 | else: 108 | # Default to PyTorch style 'same'-ish symmetric padding 109 | padding = get_padding(kernel_size, **kwargs) 110 | return padding, dynamic 111 | 112 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 113 | def get_same_padding(x: int, k: int, s: int, d: int): 114 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 115 | 116 | 117 | # Can SAME padding for given args be done statically? 118 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 119 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 120 | 121 | 122 | # Dynamically pad input x with 'SAME' padding for conv with specified args 123 | #def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 124 | def pad_same(x, k, s, d=(1, 1), value=0): 125 | ih, iw = x.size()[-2:] 126 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 127 | if pad_h > 0 or pad_w > 0: 128 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 129 | return x 130 | 131 | def adaptive_pool_feat_mult(pool_type='avg'): 132 | if pool_type == 'catavgmax': 133 | return 2 134 | else: 135 | return 1 136 | 137 | def drop_path(x, drop_prob: float = 0., training: bool = False): 138 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 139 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 140 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 141 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 142 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 143 | 'survival rate' as the argument. 144 | """ 145 | if drop_prob == 0.
or not training: 146 | return x 147 | keep_prob = 1 - drop_prob 148 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 149 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 150 | random_tensor.floor_() # binarize 151 | output = x.div(keep_prob) * random_tensor 152 | return output 153 | 154 | class DropPath(nn.Module): 155 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 156 | """ 157 | def __init__(self, drop_prob=None): 158 | super(DropPath, self).__init__() 159 | self.drop_prob = drop_prob 160 | 161 | def forward(self, x): 162 | return drop_path(x, self.drop_prob, self.training) 163 | -------------------------------------------------------------------------------- /src/optimization/adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamW optimizer (weight decay fix) 3 | copied from huggingface 4 | """ 5 | import math 6 | 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | 11 | class AdamW(Optimizer): 12 | """ Implements Adam algorithm with weight decay fix. 13 | Parameters: 14 | lr (float): learning rate. Default 1e-3. 15 | betas (tuple of 2 floats): Adam's beta parameters (b1, b2). 16 | Default: (0.9, 0.999) 17 | eps (float): Adam's epsilon. Default: 1e-6 18 | weight_decay (float): Weight decay. Default: 0.0 19 | correct_bias (bool): can be set to False to avoid correcting bias 20 | in Adam (e.g. like in Bert TF repository). Default True. 21 | """ 22 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 23 | weight_decay=0.0, correct_bias=True): 24 | if lr < 0.0: 25 | raise ValueError( 26 | "Invalid learning rate: {} - should be >= 0.0".format(lr)) 27 | if not 0.0 <= betas[0] < 1.0: 28 | raise ValueError("Invalid beta parameter: {} - " 29 | "should be in [0.0, 1.0[".format(betas[0])) 30 | if not 0.0 <= betas[1] < 1.0: 31 | raise ValueError("Invalid beta parameter: {} - " 32 | "should be in [0.0, 1.0[".format(betas[1])) 33 | if not 0.0 <= eps: 34 | raise ValueError("Invalid epsilon value: {} - " 35 | "should be >= 0.0".format(eps)) 36 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 37 | correct_bias=correct_bias) 38 | super(AdamW, self).__init__(params, defaults) 39 | 40 | def step(self, closure=None): 41 | """Performs a single optimization step. 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss.
45 | """ 46 | loss = None 47 | if closure is not None: 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | for p in group['params']: 52 | if p.grad is None: 53 | continue 54 | grad = p.grad.data 55 | if grad.is_sparse: 56 | raise RuntimeError( 57 | 'Adam does not support sparse ' 58 | 'gradients, please consider SparseAdam instead') 59 | 60 | state = self.state[p] 61 | 62 | # State initialization 63 | if len(state) == 0: 64 | state['step'] = 0 65 | # Exponential moving average of gradient values 66 | state['exp_avg'] = torch.zeros_like(p.data) 67 | # Exponential moving average of squared gradient values 68 | state['exp_avg_sq'] = torch.zeros_like(p.data) 69 | 70 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 71 | beta1, beta2 = group['betas'] 72 | 73 | state['step'] += 1 74 | 75 | # Decay the first and second moment running average coefficient 76 | # In-place operations to update the averages at the same time 77 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 78 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 79 | denom = exp_avg_sq.sqrt().add_(group['eps']) 80 | 81 | step_size = group['lr'] 82 | if group['correct_bias']: # No bias correction for Bert 83 | bias_correction1 = 1.0 - beta1 ** state['step'] 84 | bias_correction2 = 1.0 - beta2 ** state['step'] 85 | step_size = (step_size * math.sqrt(bias_correction2) 86 | / bias_correction1) 87 | 88 | p.data.addcdiv_(-step_size, exp_avg, denom) 89 | 90 | # Just adding the square of the weights to the loss function is 91 | # *not* the correct way of using L2 regularization/weight decay 92 | # with Adam, since that will interact with the m and v 93 | # parameters in strange ways. 94 | # 95 | # Instead we want to decay the weights in a manner that doesn't 96 | # interact with the m/v parameters. This is equivalent to 97 | # adding the square of the weights to the loss with plain 98 | # (non-momentum) SGD. 
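# Concretely, the update below is: p becomes p - lr * weight_decay * p, a decoupled decay applied after the Adam step and independent of the exp_avg / exp_avg_sq statistics.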
99 | # Add weight decay at the end (fixed version) 100 | if group['weight_decay'] > 0.0: 101 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 102 | 103 | return loss 104 | -------------------------------------------------------------------------------- /src/optimization/sched.py: -------------------------------------------------------------------------------- 1 | """ 2 | optimizer learning rate scheduling helpers 3 | """ 4 | from math import ceil 5 | from collections import Counter 6 | 7 | 8 | def noam_schedule(step, warmup_step=4000): 9 | if step <= warmup_step: 10 | return step / warmup_step 11 | return (warmup_step ** 0.5) * (step ** -0.5) 12 | 13 | 14 | def warmup_linear(step, warmup_step, tot_step): 15 | if step < warmup_step: 16 | return step / warmup_step 17 | return max(0, (tot_step-step)/(tot_step-warmup_step)) 18 | 19 | 20 | def multi_step_schedule(n_epoch, milestones, gamma=0.5): 21 | milestones = list(sorted(milestones)) 22 | for i, m in enumerate(milestones): 23 | if n_epoch < m: 24 | return gamma**i 25 | return gamma**(len(milestones)+1) 26 | 27 | 28 | def get_lr_sched(global_step, decay, learning_rate, 29 | num_train_steps, warmup_ratio=0.1, 30 | decay_epochs=[], multi_step_epoch=-1): 31 | warmup_steps = int(warmup_ratio*num_train_steps) 32 | if decay == 'linear': 33 | lr_this_step = learning_rate * warmup_linear( 34 | global_step, warmup_steps, num_train_steps) 35 | elif decay == 'invsqrt': 36 | lr_this_step = learning_rate * noam_schedule( 37 | global_step, warmup_steps) 38 | elif decay == 'constant': 39 | lr_this_step = learning_rate 40 | elif decay == "multi_step": 41 | assert multi_step_epoch >= 0 42 | lr_this_step = learning_rate * multi_step_schedule( 43 | multi_step_epoch, decay_epochs) 44 | if lr_this_step <= 0: 45 | # save guard for possible miscalculation of train steps 46 | lr_this_step = 1e-8 47 | return lr_this_step 48 | -------------------------------------------------------------------------------- /src/optimization/utils.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Adam, Adamax, SGD 2 | from src.optimization.adamw import AdamW 3 | 4 | 5 | def setup_e2e_optimizer(model, opts): 6 | if opts.optim == 'adam': 7 | OptimCls = Adam 8 | elif opts.optim == 'adamax': 9 | OptimCls = Adamax 10 | elif opts.optim == 'adamw': 11 | OptimCls = AdamW 12 | else: 13 | raise ValueError('invalid optimizer') 14 | optimizer = OptimCls(model.parameters(), lr=opts.learning_rate, betas=opts.betas) 15 | 16 | return optimizer 17 | -------------------------------------------------------------------------------- /src/pretrain/run_pretrain_contrastive_only.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import time 5 | import random 6 | import pprint 7 | import math 8 | import json 9 | from transformers import BertConfig, BertTokenizerFast 10 | 11 | from src.datasets.dataset_pretrain_sparse import AlproPretrainSparseDataset, PretrainImageTextDataset, PretrainCollator 12 | from src.datasets.dataloader import MetaLoader, PrefetchLoader 13 | from src.datasets.data_utils import ImageNorm, mk_input_group 14 | from torch.utils.data import DataLoader 15 | from torch.nn.utils import clip_grad_norm_ 16 | from src.configs.config import shared_configs 17 | from src.utils.misc import set_random_seed, NoOp, zero_none_grad 18 | from src.utils.logger import LOGGER, TB_LOGGER, add_log_to_file, RunningMeter 19 | from src.utils.basic_utils import 
load_jsonl, load_json, read_dataframe 20 | from src.utils.load_save import (ModelSaver, 21 | save_training_meta, 22 | load_state_dict_with_pos_embed_resizing) 23 | from src.utils.load_save import E2E_TrainingRestorer as TrainingRestorer 24 | from src.optimization.sched import get_lr_sched 25 | from src.optimization.utils import setup_e2e_optimizer 26 | from collections import defaultdict 27 | from tqdm import tqdm 28 | from os.path import join 29 | from apex import amp 30 | from torch.utils.data.distributed import DistributedSampler 31 | import horovod.torch as hvd 32 | from src.utils.distributed import all_gather_list 33 | 34 | from src.modeling.alpro_models import Prompter 35 | 36 | 37 | def mk_captions_pretrain_dataloader(dataset_name, anno_path, video_dir, txt_dir, cfg, tokenizer, 38 | is_train=True, max_txt_len=80): 39 | # make a list(dict), where each dict {vis_id: int, txt: str} 40 | if dataset_name == "webvid2m": 41 | datalist = read_dataframe(anno_path) 42 | 43 | datalist = datalist[datalist['txt_len'] < max_txt_len] 44 | LOGGER.info('Found {} entries for webvid2m'.format(len(datalist))) 45 | 46 | elif dataset_name == "cc3m": 47 | datalist = json.load(open(anno_path)) 48 | LOGGER.info('Found {} entries for cc3m'.format(len(datalist))) 49 | 50 | else: 51 | raise ValueError("Invalid dataset_name") 52 | 53 | if dataset_name in ["webvid2m"]: 54 | frm_sampling_strategy = cfg.frm_sampling_strategy 55 | if not is_train and frm_sampling_strategy == "rand": 56 | frm_sampling_strategy = "uniform" 57 | dataset = AlproPretrainSparseDataset( 58 | datalist=datalist, 59 | tokenizer=tokenizer, 60 | img_lmdb_dir=video_dir, 61 | img_db_type='rawvideo', 62 | txt_dir=txt_dir, 63 | crop_size=cfg.crop_img_size, 64 | resize_size=cfg.resize_size, 65 | max_txt_len=cfg.max_txt_len, 66 | use_itm=cfg.use_itm, 67 | fps=cfg.fps, 68 | num_frm=cfg.num_frm, 69 | frm_sampling_strategy=frm_sampling_strategy, 70 | is_train=is_train 71 | # vis_format=vis_format 72 | ) 73 | elif dataset_name in ["cc3m"]: 74 | dataset = PretrainImageTextDataset(datalist=datalist, 75 | tokenizer=tokenizer, 76 | crop_size=cfg.crop_img_size, 77 | resize_size=cfg.resize_size, 78 | max_txt_len=cfg.max_txt_len, 79 | num_frm=cfg.num_frm 80 | ) 81 | 82 | LOGGER.info(f"[{dataset_name}] is_train {is_train} " 83 | f"dataset size {len(dataset)}, ") 84 | batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size 85 | sampler = DistributedSampler( 86 | dataset, num_replicas=hvd.size(), rank=hvd.rank(), 87 | shuffle=is_train) 88 | data_collator = PretrainCollator(tokenizer=tokenizer, 89 | mlm=cfg.use_mlm, 90 | mlm_probability=0.15, 91 | max_length=cfg.max_txt_len, 92 | is_train=is_train) 93 | dataloader = DataLoader(dataset, 94 | batch_size=batch_size, 95 | shuffle=False, 96 | sampler=sampler, 97 | num_workers=cfg.n_workers, 98 | pin_memory=cfg.pin_mem, 99 | collate_fn=data_collator.collate_batch) 100 | 101 | return dataloader 102 | 103 | 104 | def setup_dataloaders(cfg, tokenizer): 105 | LOGGER.info("Init. 
train_loader and val_loader...") 106 | 107 | train_loaders = {} 108 | for db in cfg.train_datasets: 109 | train_loaders[db.name] = mk_captions_pretrain_dataloader( 110 | dataset_name=db.name, 111 | anno_path=db.ann, video_dir=db.img, txt_dir=db.txt, 112 | cfg=cfg, tokenizer=tokenizer, is_train=True 113 | ) 114 | 115 | val_loaders = {} 116 | for db in cfg.val_datasets: 117 | val_loaders[db.name] = mk_captions_pretrain_dataloader( 118 | dataset_name=db.name, 119 | anno_path=db.ann, video_dir=db.img, txt_dir=db.txt, 120 | cfg=cfg, tokenizer=tokenizer, is_train=False 121 | ) 122 | return train_loaders, val_loaders 123 | 124 | 125 | def setup_model(cfg, device=None): 126 | LOGGER.info("Setup model...") 127 | # has to be a BertConfig instance 128 | model_cfg = load_json(cfg.model_config) 129 | model_cfg = BertConfig(**model_cfg) 130 | # add model-specific config 131 | add_attr_list = [ 132 | "max_n_example_per_group", 133 | "num_entities" 134 | ] 135 | for k in add_attr_list: 136 | setattr(model_cfg, k, cfg[k]) 137 | LOGGER.info(f"model_cfg {pprint.pformat(model_cfg.to_dict())}") 138 | 139 | LOGGER.info("setup e2e model") 140 | 141 | if cfg.model_type == 'pretrain': 142 | # initialize cnn config 143 | video_enc_cfg = load_json(cfg.visual_model_cfg) 144 | 145 | video_enc_cfg['num_frm'] = cfg.num_frm 146 | video_enc_cfg['img_size'] = cfg.crop_img_size 147 | 148 | model = Prompter( 149 | model_cfg, 150 | input_format=cfg.img_input_format, 151 | video_enc_cfg=video_enc_cfg 152 | ) 153 | if cfg.e2e_weights_path: 154 | LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}") 155 | num_patches = (cfg.crop_img_size // video_enc_cfg['patch_size']) ** 2 156 | # NOTE strict if False if loaded from ALBEF ckpt 157 | load_state_dict_with_pos_embed_resizing(model, 158 | cfg.e2e_weights_path, 159 | num_patches=num_patches, 160 | num_frames=cfg.num_frm, 161 | strict=not cfg.albef_init 162 | ) 163 | 164 | else: 165 | LOGGER.info(f"Loading visual weights from {cfg.visual_weights_path}") 166 | LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}") 167 | model.load_separate_ckpt( 168 | visual_weights_path=cfg.visual_weights_path, 169 | bert_weights_path=cfg.bert_weights_path 170 | ) 171 | else: 172 | raise NotImplementedError(f"cfg.model_type not found {cfg.model_type}.") 173 | 174 | # if cfg.freeze_cnn: 175 | # model.freeze_cnn_backbone() 176 | 177 | LOGGER.info("Moving model to device") 178 | model.to(device) 179 | LOGGER.info("Completed moving model to device.") 180 | 181 | LOGGER.info("Setup model done!") 182 | return model 183 | 184 | 185 | def forward_step(cfg, model, batch): 186 | """shared for training and validation""" 187 | # used to make visual feature copies 188 | if not cfg.use_itm: 189 | batch["itm_labels"] = None 190 | outputs = model(batch) # dict 191 | return outputs 192 | 193 | 194 | @torch.no_grad() 195 | def validate(model, val_loader, cfg): 196 | model.eval() 197 | 198 | n_itc_ex = 0 199 | n_t2i_corrects = 0 200 | n_i2t_corrects = 0 201 | 202 | itc_loss = 0 203 | st = time.time() 204 | val_log = {'valid/itc_loss': 0, 205 | 'valid/i2t_acc': 0, 206 | 'valid/t2i_acc': 0 207 | } 208 | 209 | debug_step = 5 210 | val_loaders = val_loader if isinstance(val_loader, dict) else { 211 | "unnamed_val_loader": val_loader} 212 | 213 | total_val_iters = 0 214 | 215 | LOGGER.info(f"In total {len(val_loaders)} val loaders") 216 | for loader_name, val_loader in val_loaders.items(): 217 | LOGGER.info(f"Loop val_loader {loader_name}.") 218 | 219 | total_val_iters += len(val_loader) 220 | for 
val_step, batch in enumerate(val_loader): 221 | # use iter to reset MetaLoader 222 | # forward pass 223 | outputs = forward_step(cfg, model, batch) 224 | 225 | assert not cfg.use_itm and not cfg.use_mlm 226 | 227 | if cfg.use_itc: 228 | itc_loss += outputs["itc_loss"].sum().item() 229 | 230 | if cfg.debug and val_step >= debug_step: 231 | break 232 | 233 | # Gather across all processes 234 | all_gather_itc_loss = all_gather_list(itc_loss) 235 | itc_loss = sum(all_gather_itc_loss) 236 | 237 | # FIXME check this whether take mean? 238 | assert cfg.use_itc, 'cfg.use_itc is False for contrastive-only pretraining.' 239 | 240 | val_log.update({ 241 | 'valid/itc_loss': float(itc_loss), 242 | }) 243 | 244 | n_itc_ex += len(outputs["itc_labels"]) 245 | n_t2i_corrects += ( 246 | outputs["t2i_scores"].max( 247 | dim=-1)[1] == outputs["itc_labels"]).sum().item() 248 | n_i2t_corrects += ( 249 | outputs["i2t_scores"].max( 250 | dim=-1)[1] == outputs["itc_labels"]).sum().item() 251 | 252 | n_i2t_corrects = sum(all_gather_list(n_i2t_corrects)) 253 | n_t2i_corrects = sum(all_gather_list(n_t2i_corrects)) 254 | 255 | n_itc_ex = sum(all_gather_list(n_itc_ex)) 256 | 257 | if n_itc_ex != 0: 258 | val_log.update({ 259 | 'valid/i2t_acc': float(n_i2t_corrects / n_itc_ex), 260 | 'valid/t2i_acc': float(n_t2i_corrects / n_itc_ex) 261 | }) 262 | 263 | TB_LOGGER.log_scalar_dict(val_log) 264 | LOGGER.info(f"validation finished in {int(time.time() - st)} seconds, ") 265 | 266 | LOGGER.info("[itc_loss]: {} ".format(itc_loss)) 267 | LOGGER.info("In total, {} validation iters.".format(total_val_iters)) 268 | 269 | model.train() 270 | return val_log 271 | 272 | 273 | def start_training(): 274 | cfg = shared_configs.get_sparse_pretraining_args() 275 | set_random_seed(cfg.seed) 276 | 277 | n_gpu = hvd.size() 278 | # device = torch.device("cuda", hvd.local_rank()) 279 | # torch.cuda.set_device(hvd.local_rank()) 280 | 281 | # This resolves the issue GPU 0 always has more processes running and more GPU-RAM. 282 | # c.f. https://github.com/horovod/horovod/issues/2625#issuecomment-868134876 283 | os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank()) 284 | device = torch.device("cuda", 0) 285 | torch.cuda.set_device(0) 286 | 287 | if hvd.rank() != 0: 288 | LOGGER.disabled = True 289 | LOGGER.info(f"device: {device} n_gpu: {n_gpu}, " 290 | f"rank: {hvd.rank()}, 16-bits training: {cfg.fp16}") 291 | 292 | model = setup_model(cfg, device=device) 293 | model.train() 294 | 295 | optimizer = setup_e2e_optimizer(model, cfg) 296 | 297 | # Horovod: (optional) compression algorithm.compressin 298 | compression = hvd.Compression.none 299 | optimizer = hvd.DistributedOptimizer( 300 | optimizer, named_parameters=model.named_parameters(), 301 | compression=compression) 302 | 303 | # Horovod: broadcast parameters & optimizer state. 
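# Rank 0 acts as the source of truth: the broadcasts below make every worker start from identical model weights and optimizer state before the first update.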
304 | compression = hvd.Compression.none 305 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 306 | hvd.broadcast_optimizer_state(optimizer, root_rank=0) 307 | 308 | model, optimizer = amp.initialize( 309 | model, optimizer, enabled=cfg.fp16, opt_level='O2', 310 | keep_batchnorm_fp32=True) 311 | 312 | # prepare data 313 | tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir) 314 | train_loaders, val_loaders = setup_dataloaders(cfg, tokenizer) 315 | train_loader = MetaLoader(train_loaders, 316 | accum_steps=cfg.gradient_accumulation_steps, 317 | distributed=n_gpu > 1) 318 | img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std) 319 | train_loader = PrefetchLoader(train_loader, img_norm) 320 | val_loaders = {k: PrefetchLoader(v, img_norm) 321 | for k, v in val_loaders.items()} 322 | 323 | # compute the number of steps and update cfg 324 | total_train_batch_size = int( 325 | n_gpu * cfg.train_batch_size * 326 | cfg.gradient_accumulation_steps * cfg.max_n_example_per_group) 327 | total_n_epochs = cfg.num_train_epochs 328 | cfg.num_train_steps = int(math.ceil( 329 | 1. * train_loader.n_batches_in_epoch * total_n_epochs / 330 | (n_gpu * cfg.gradient_accumulation_steps))) 331 | cfg.valid_steps = int(math.ceil( 332 | 1. * cfg.num_train_steps / cfg.num_valid / 333 | cfg.min_valid_steps)) * cfg.min_valid_steps 334 | actual_num_valid = int(math.floor( 335 | 1. * cfg.num_train_steps / cfg.valid_steps)) + 1 336 | 337 | # restore 338 | restorer = TrainingRestorer(cfg, model, optimizer) 339 | global_step = restorer.global_step 340 | TB_LOGGER.global_step = global_step 341 | if hvd.rank() == 0: 342 | LOGGER.info("Saving training meta...") 343 | save_training_meta(cfg) 344 | LOGGER.info("Saving training done...") 345 | TB_LOGGER.create(join(cfg.output_dir, 'log')) 346 | pbar = tqdm(total=cfg.num_train_steps) 347 | model_saver = ModelSaver(join(cfg.output_dir, "ckpt")) 348 | add_log_to_file(join(cfg.output_dir, "log", "log.txt")) 349 | else: 350 | LOGGER.disabled = True 351 | pbar = NoOp() 352 | model_saver = NoOp() 353 | restorer = NoOp() 354 | 355 | if global_step > 0: 356 | pbar.update(global_step) 357 | 358 | LOGGER.info(cfg) 359 | LOGGER.info("Starting training...") 360 | LOGGER.info(f"***** Running training with {n_gpu} GPUs *****") 361 | LOGGER.info(f" Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}") 362 | LOGGER.info(f" max_n_example_per_group = {cfg.max_n_example_per_group}") 363 | LOGGER.info(f" Accumulate steps = {cfg.gradient_accumulation_steps}") 364 | LOGGER.info(f" Total batch size = #GPUs * Single-GPU batch size * " 365 | f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}") 366 | LOGGER.info(f" Total #batches - single epoch = {train_loader.n_batches_in_epoch}.") 367 | LOGGER.info(f" Total #steps = {cfg.num_train_steps}") 368 | LOGGER.info(f" Total #epochs = {total_n_epochs}.") 369 | LOGGER.info(f" Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times") 370 | 371 | 372 | # quick hack for amp delay_unscale bug 373 | with optimizer.skip_synchronize(): 374 | optimizer.zero_grad() 375 | if global_step == 0: 376 | optimizer.step() 377 | debug_step = 5 378 | 379 | tasks = [] 380 | for name, flag in zip(["itc"], [cfg.use_itc]): 381 | if flag: 382 | tasks.append(name) 383 | task2loss = {t: RunningMeter(f'train_loss/{t}') 384 | for t in tasks} 385 | task2loss["loss"] = RunningMeter('train_loss/loss') 386 | 387 | train_log = {'train/i2t_acc': 0, 388 | 'train/t2i_acc': 0} 389 | 390 | for step, (task, 
batch) in enumerate(train_loader): 391 | # forward pass 392 | outputs = forward_step(cfg, model, batch) 393 | # mlm_loss, itm_loss, itc_loss, mpm_loss = 0, 0, 0, 0 394 | itc_loss = 0 395 | 396 | assert not cfg.use_mlm and not cfg.use_itm 397 | 398 | if cfg.use_itc: 399 | n_itc_ex = len(outputs["itc_labels"]) 400 | n_t2i_corrects = ( 401 | outputs["t2i_scores"].max( 402 | dim=-1)[1] == outputs["itc_labels"]).sum().item() 403 | n_i2t_corrects = ( 404 | outputs["i2t_scores"].max( 405 | dim=-1)[1] == outputs["itc_labels"]).sum().item() 406 | 407 | train_log.update({ 408 | 'train/t2i_acc': float(n_t2i_corrects / n_itc_ex), 409 | 'train/i2t_acc': float(n_i2t_corrects / n_itc_ex), 410 | # 'train/mpm_acc': mpm_acc 411 | }) 412 | 413 | itc_loss = outputs["itc_loss"] 414 | task2loss["itc"](itc_loss.item()) 415 | 416 | loss = itc_loss 417 | task2loss["loss"](loss.item()) 418 | 419 | delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0 420 | with amp.scale_loss( 421 | loss, optimizer, delay_unscale=delay_unscale 422 | ) as scaled_loss: 423 | scaled_loss.backward() 424 | zero_none_grad(model) 425 | optimizer.synchronize() 426 | 427 | # optimizer 428 | if (step + 1) % cfg.gradient_accumulation_steps == 0: 429 | global_step += 1 430 | if (step + 1) % cfg.log_interval == 0: 431 | TB_LOGGER.log_scalar_dict({l.name: l.val 432 | for l in task2loss.values() 433 | if l.val is not None}) 434 | n_epoch = int(1. * n_gpu * cfg.gradient_accumulation_steps * 435 | global_step / train_loader.n_batches_in_epoch) 436 | 437 | # learning rate scheduling for the whole model 438 | lr_this_step = get_lr_sched( 439 | global_step, cfg.decay, cfg.learning_rate, 440 | cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio, 441 | decay_epochs=cfg.step_decay_epochs, 442 | multi_step_epoch=n_epoch) 443 | 444 | # Hardcoded param group length 445 | # assert len(optimizer.param_groups) == 8 446 | for pg_n, param_group in enumerate( 447 | optimizer.param_groups): 448 | param_group['lr'] = lr_this_step 449 | 450 | if (step + 1) % cfg.log_interval == 0: 451 | TB_LOGGER.add_scalar( 452 | "train/lr", lr_this_step, global_step) 453 | 454 | # update model params 455 | if cfg.grad_norm != -1: 456 | # import pdb; pdb.set_trace() 457 | grad_norm = clip_grad_norm_( 458 | amp.master_params(optimizer), cfg.grad_norm) 459 | if (step + 1) % cfg.log_interval == 0: 460 | TB_LOGGER.add_scalar("train/grad_norm", grad_norm, global_step) 461 | TB_LOGGER.step() 462 | 463 | # Check if there is None grad 464 | none_grads = [ 465 | p[0] for p in model.named_parameters() 466 | if p[1].requires_grad and p[1].grad is None] 467 | 468 | assert len(none_grads) == 0, f"{none_grads}" 469 | 470 | with optimizer.skip_synchronize(): 471 | optimizer.step() 472 | optimizer.zero_grad() 473 | restorer.step() 474 | pbar.update(1) 475 | 476 | # validate and checkpoint 477 | if global_step % cfg.valid_steps == 0: 478 | LOGGER.info(f'Step {global_step}: start validation') 479 | validate(model, val_loaders, cfg) 480 | model_saver.save(step=global_step, model=model) 481 | if global_step >= cfg.num_train_steps: 482 | break 483 | 484 | if cfg.debug and global_step >= debug_step: 485 | break 486 | 487 | if global_step % cfg.valid_steps != 0: 488 | LOGGER.info(f'Step {global_step}: start validation') 489 | validate(model, val_loaders, cfg) 490 | model_saver.save(step=global_step, model=model) 491 | 492 | 493 | if __name__ == '__main__': 494 | # Initialize Horovod 495 | hvd.init() 496 | start_training() 497 | 
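For reference, the learning-rate values applied in the loop above come from get_lr_sched in src/optimization/sched.py (linear warmup to the base rate, then linear decay, with non-positive values floored at 1e-8). The minimal standalone sketch below illustrates that schedule; base_lr and total_steps are purely illustrative and do not come from any config in this repository.

from src.optimization.sched import get_lr_sched

base_lr, total_steps = 1e-4, 1000  # illustrative values only, not from any released config
for step in (0, 50, 100, 550, 1000):
    lr = get_lr_sched(step, decay="linear", learning_rate=base_lr,
                      num_train_steps=total_steps, warmup_ratio=0.1)
    print(step, lr)
# warmup_steps = int(0.1 * 1000) = 100, so this prints approximately:
#    0 -> 1e-08   (a non-positive lr is floored to 1e-8 as a safeguard)
#   50 -> 5e-05   (halfway through the linear warmup)
#  100 -> 1e-04   (peak learning rate at the end of warmup)
#  550 -> 5e-05   (halfway through the linear decay)
# 1000 -> 1e-08   (fully decayed, floored again)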
-------------------------------------------------------------------------------- /src/utils/basic_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson as json 3 | import zipfile 4 | import numpy as np 5 | import pickle 6 | 7 | import pandas as pd 8 | 9 | 10 | def load_pickle(filename): 11 | with open(filename, "rb") as f: 12 | return pickle.load(f) 13 | 14 | 15 | def save_pickle(data, filename): 16 | with open(filename, "wb") as f: 17 | pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) 18 | 19 | 20 | def load_json(filename): 21 | with open(filename, "r") as f: 22 | return json.load(f) 23 | 24 | 25 | def save_json(data, filename, save_pretty=False, sort_keys=False): 26 | with open(filename, "w") as f: 27 | if save_pretty: 28 | f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) 29 | else: 30 | json.dump(data, f) 31 | 32 | 33 | def load_jsonl(filename): 34 | with open(filename, "r") as f: 35 | return [json.loads(l.strip("\n")) for l in f.readlines()] 36 | 37 | 38 | def save_jsonl(data, filename): 39 | """data is a list""" 40 | with open(filename, "w") as f: 41 | f.write("\n".join([json.dumps(e) for e in data])) 42 | 43 | 44 | def concat_json_list(filepaths, save_path): 45 | json_lists = [] 46 | for p in filepaths: 47 | json_lists += load_json(p) 48 | save_json(json_lists, save_path) 49 | 50 | 51 | def save_lines(list_of_str, filepath): 52 | with open(filepath, "w") as f: 53 | f.write("\n".join(list_of_str)) 54 | 55 | 56 | def read_lines(filepath): 57 | with open(filepath, "r") as f: 58 | return [e.strip("\n") for e in f.readlines()] 59 | 60 | 61 | def mkdirp(p): 62 | if not os.path.exists(p): 63 | os.makedirs(p) 64 | 65 | 66 | def flat_list_of_lists(l): 67 | """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" 68 | return [item for sublist in l for item in sublist] 69 | 70 | 71 | def convert_to_seconds(hms_time): 72 | """ convert '00:01:12' to 72 seconds. 73 | :hms_time (str): time in colon-separated string, e.g. '00:01:12' 74 | :return (float): time in seconds, e.g. 72 75 | """ 76 | times = [float(t) for t in hms_time.split(":")] 77 | return times[0] * 3600 + times[1] * 60 + times[2] 78 | 79 | 80 | def get_video_name_from_url(url): 81 | return url.split("/")[-1][:-4] 82 | 83 | 84 | def merge_dicts(list_dicts): 85 | merged_dict = list_dicts[0].copy() 86 | for i in range(1, len(list_dicts)): 87 | merged_dict.update(list_dicts[i]) 88 | return merged_dict 89 | 90 | 91 | def l2_normalize_np_array(np_array, eps=1e-5): 92 | """np_array: np.ndarray, (*, D), where the last dim will be normalized""" 93 | return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps) 94 | 95 | 96 | def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, exclude_extensions=None, 97 | exclude_dirs_substring=None): 98 | """make a zip file of src_dir, save it to save_path. 99 | exclude_dirs entries will be excluded if they are subdirs of src_dir. 100 | An enclosing_dir is added if specified.
101 | """ 102 | abs_src = os.path.abspath(src_dir) 103 | with zipfile.ZipFile(save_path, "w") as zf: 104 | for dirname, subdirs, files in os.walk(src_dir): 105 | if exclude_dirs is not None: 106 | for e_p in exclude_dirs: 107 | if e_p in subdirs: 108 | subdirs.remove(e_p) 109 | if exclude_dirs_substring is not None: 110 | to_rm = [] 111 | for d in subdirs: 112 | if exclude_dirs_substring in d: 113 | to_rm.append(d) 114 | for e in to_rm: 115 | subdirs.remove(e) 116 | arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:]) 117 | zf.write(dirname, arcname) 118 | for filename in files: 119 | if exclude_extensions is not None: 120 | if os.path.splitext(filename)[1] in exclude_extensions: 121 | continue # do not zip it 122 | absname = os.path.join(dirname, filename) 123 | arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:]) 124 | zf.write(absname, arcname) 125 | 126 | 127 | class AverageMeter(object): 128 | """Computes and stores the average and current/max/min value""" 129 | def __init__(self): 130 | self.val = 0 131 | self.avg = 0 132 | self.sum = 0 133 | self.count = 0 134 | self.max = -1e10 135 | self.min = 1e10 136 | self.reset() 137 | 138 | def reset(self): 139 | self.val = 0 140 | self.avg = 0 141 | self.sum = 0 142 | self.count = 0 143 | self.max = -1e10 144 | self.min = 1e10 145 | 146 | def update(self, val, n=1): 147 | self.max = max(val, self.max) 148 | self.min = min(val, self.min) 149 | self.val = val 150 | self.sum += val * n 151 | self.count += n 152 | self.avg = self.sum / self.count 153 | 154 | 155 | def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True): 156 | """Dissect an array (N, D) into a list of sub-arrays, 157 | np_array.shape[0] == sum(lengths). Output is a list of nd arrays, singleton dimension is kept""" 158 | if assert_equal: 159 | assert len(np_array) == sum(lengths) 160 | length_indices = [0, ] 161 | for i in range(len(lengths)): 162 | length_indices.append(length_indices[i] + lengths[i]) 163 | if dim == 0: 164 | array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))] 165 | elif dim == 1: 166 | array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))] 167 | elif dim == 2: 168 | array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))] 169 | else: 170 | raise NotImplementedError 171 | return array_list 172 | 173 | 174 | def get_ratio_from_counter(counter_obj, threshold=200): 175 | keys = counter_obj.keys() 176 | values = counter_obj.values() 177 | filtered_values = [counter_obj[k] for k in keys if k > threshold] 178 | return float(sum(filtered_values)) / sum(values) 179 | 180 | 181 | def get_rounded_percentage(float_number, n_floats=2): 182 | return round(float_number * 100, n_floats) 183 | 184 | 185 | def read_dataframe(pkl_path): 186 | return pd.read_pickle(pkl_path) 187 | 188 | 189 | def save_frames_grid(img_array, out_path): 190 | import torch 191 | from torchvision.utils import make_grid 192 | from PIL import Image 193 | 194 | if len(img_array.shape) == 3: 195 | img_array = img_array.unsqueeze(0) 196 | elif len(img_array.shape) == 5: 197 | b, t, c, h, w = img_array.shape 198 | img_array = img_array.view(-1, c, h, w) 199 | elif len(img_array.shape) == 4: 200 | pass 201 | else: 202 | raise NotImplementedError('Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored.') 203 | 204 | assert img_array.shape[1] == 3, "Expecting input shape of (3, H, W), i.e. RGB-only."
205 | 206 | grid = make_grid(img_array) 207 | ndarr = grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy() 208 | 209 | img = Image.fromarray(ndarr) 210 | 211 | img.save(out_path) 212 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | distributed API using Horovod 5 | Modified from OpenNMT's native pytorch distributed utils 6 | (https://github.com/OpenNMT/OpenNMT-py) 7 | """ 8 | 9 | import math 10 | import pickle 11 | 12 | import torch 13 | from horovod import torch as hvd 14 | from horovod.torch.mpi_ops import rank, size 15 | 16 | 17 | def all_reduce_and_rescale_tensors(tensors, rescale_denom): 18 | """All-reduce and rescale tensors at once (as a flattened tensor) 19 | Args: 20 | tensors: list of Tensors to all-reduce 21 | rescale_denom: denominator for rescaling summed Tensors 22 | """ 23 | # buffer size in bytes, determine equiv. # of elements based on data type 24 | sz = sum(t.numel() for t in tensors) 25 | buffer_t = tensors[0].new(sz).zero_() 26 | 27 | # copy tensors into buffer_t 28 | offset = 0 29 | for t in tensors: 30 | numel = t.numel() 31 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 32 | offset += numel 33 | 34 | # all-reduce and rescale 35 | hvd.allreduce_(buffer_t[:offset]) 36 | buffer_t.div_(rescale_denom) 37 | 38 | # copy all-reduced buffer back into tensors 39 | offset = 0 40 | for t in tensors: 41 | numel = t.numel() 42 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 43 | offset += numel 44 | 45 | 46 | def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom, 47 | buffer_size=10485760): 48 | """All-reduce and rescale tensors in chunks of the specified size. 49 | Args: 50 | tensors: list of Tensors to all-reduce 51 | rescale_denom: denominator for rescaling summed Tensors 52 | buffer_size: all-reduce chunk size in bytes 53 | """ 54 | # buffer size in bytes, determine equiv. # of elements based on data type 55 | buffer_t = tensors[0].new( 56 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 57 | buffer = [] 58 | 59 | def all_reduce_buffer(): 60 | # copy tensors into buffer_t 61 | offset = 0 62 | for t in buffer: 63 | numel = t.numel() 64 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 65 | offset += numel 66 | 67 | # all-reduce and rescale 68 | hvd.allreduce_(buffer_t[:offset]) 69 | buffer_t.div_(rescale_denom) 70 | 71 | # copy all-reduced buffer back into tensors 72 | offset = 0 73 | for t in buffer: 74 | numel = t.numel() 75 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 76 | offset += numel 77 | 78 | filled = 0 79 | for t in tensors: 80 | sz = t.numel() * t.element_size() 81 | if sz > buffer_size: 82 | # tensor is bigger than buffer, all-reduce and rescale directly 83 | hvd.allreduce_(t) 84 | t.div_(rescale_denom) 85 | elif filled + sz > buffer_size: 86 | # buffer is full, all-reduce and replace buffer with grad 87 | all_reduce_buffer() 88 | buffer = [t] 89 | filled = sz 90 | else: 91 | # add tensor to buffer 92 | buffer.append(t) 93 | filled += sz 94 | 95 | if len(buffer) > 0: 96 | all_reduce_buffer() 97 | 98 | 99 | def broadcast_tensors(tensors, root_rank, buffer_size=10485760): 100 | """broadcast tensors in chunks of the specified size. 
101 | Args: 102 | tensors: list of Tensors to broadcast 103 | root_rank: rank to broadcast 104 | buffer_size: all-reduce chunk size in bytes 105 | """ 106 | # buffer size in bytes, determine equiv. # of elements based on data type 107 | buffer_t = tensors[0].new( 108 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 109 | buffer = [] 110 | 111 | def broadcast_buffer(): 112 | # copy tensors into buffer_t 113 | offset = 0 114 | for t in buffer: 115 | numel = t.numel() 116 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 117 | offset += numel 118 | 119 | # broadcast 120 | hvd.broadcast_(buffer_t[:offset], root_rank) 121 | 122 | # copy all-reduced buffer back into tensors 123 | offset = 0 124 | for t in buffer: 125 | numel = t.numel() 126 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 127 | offset += numel 128 | 129 | filled = 0 130 | for t in tensors: 131 | sz = t.numel() * t.element_size() 132 | if sz > buffer_size: 133 | # tensor is bigger than buffer, broadcast directly 134 | hvd.broadcast_(t, root_rank) 135 | elif filled + sz > buffer_size: 136 | # buffer is full, broadcast and replace buffer with tensor 137 | broadcast_buffer() 138 | buffer = [t] 139 | filled = sz 140 | else: 141 | # add tensor to buffer 142 | buffer.append(t) 143 | filled += sz 144 | 145 | if len(buffer) > 0: 146 | broadcast_buffer() 147 | 148 | 149 | def all_gather_list(data, max_size=4096): 150 | """Gathers arbitrary data from all nodes into a list.""" 151 | world_size = hvd.size() 152 | if not hasattr(all_gather_list, '_in_buffer') or \ 153 | max_size != all_gather_list._in_buffer.size(): 154 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 155 | in_buffer = all_gather_list._in_buffer 156 | 157 | enc = pickle.dumps(data) 158 | enc_size = len(enc) 159 | if enc_size + 2 > max_size: 160 | raise ValueError( 161 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 162 | assert max_size < 255*256 163 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 164 | in_buffer[1] = enc_size % 255 165 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 166 | 167 | # FIXME cannot create buffer 168 | out = hvd.allgather(in_buffer.cuda()) 169 | 170 | results = [] 171 | for i in range(0, max_size*world_size, max_size): 172 | out_buffer = out[i:i+max_size] 173 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 174 | 175 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 176 | result = pickle.loads(bytes_list) 177 | results.append(result) 178 | return results 179 | 180 | 181 | def any_broadcast(data, root_rank, max_size=4096): 182 | """broadcast arbitrary data from root_rank to all nodes.""" 183 | if not hasattr(any_broadcast, '_in_buffer') or \ 184 | max_size != any_broadcast._in_buffer.size(): 185 | any_broadcast._buffer = torch.cuda.ByteTensor(max_size) 186 | buffer_ = any_broadcast._buffer 187 | 188 | enc = pickle.dumps(data) 189 | enc_size = len(enc) 190 | if enc_size + 2 > max_size: 191 | raise ValueError( 192 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 193 | assert max_size < 255*256 194 | buffer_[0] = enc_size // 255 # this encoding works for max_size < 65k 195 | buffer_[1] = enc_size % 255 196 | buffer_[2:enc_size+2] = torch.ByteTensor(list(enc)) 197 | 198 | hvd.broadcast_(buffer_, root_rank) 199 | 200 | size = (255 * buffer_[0].item()) + buffer_[1].item() 201 | 202 | bytes_list = bytes(buffer_[2:size+2].tolist()) 203 | result = pickle.loads(bytes_list) 204 | return result 205 | 206 | def allgather_object(obj, name=None): 207 | """ 208 | 
Serializes and allgathers an object from all other processes. 209 | 210 | Arguments: 211 | obj: An object capable of being serialized without losing any context. 212 | name: Optional name to use during allgather, will default to the class 213 | type. 214 | 215 | Returns: 216 | The list of objects that were allgathered across all ranks. 217 | """ 218 | import io 219 | import cloudpickle 220 | 221 | if name is None: 222 | name = type(obj).__name__ 223 | 224 | def load(byte_array): 225 | buf = io.BytesIO(byte_array.tobytes()) 226 | return cloudpickle.load(buf) 227 | 228 | b = io.BytesIO() 229 | cloudpickle.dump(obj, b) 230 | 231 | t = torch.ByteTensor(bytearray(b.getvalue())) 232 | sz = torch.IntTensor([t.shape[0]]) 233 | 234 | sizes = hvd.allgather(sz, name=name + '.sz').numpy() 235 | gathered = hvd.allgather(t, name=name + '.t').numpy() 236 | 237 | def select(i): 238 | start = sum(sizes[:i]) 239 | end = start + sizes[i] 240 | return gathered[start:end] 241 | 242 | return [load(select(i)) for i in range(size())] -------------------------------------------------------------------------------- /src/utils/grad_ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | 4 | 5 | def detach_variable(inputs): 6 | if isinstance(inputs, tuple): 7 | out = [] 8 | for inp in inputs: 9 | x = inp.detach() 10 | x.requires_grad = inp.requires_grad 11 | out.append(x) 12 | return tuple(out) 13 | else: 14 | raise RuntimeError( 15 | "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__) 16 | 17 | 18 | def check_backward_validity(inputs): 19 | if not any(inp.requires_grad for inp in inputs): 20 | warnings.warn("None of the inputs have requires_grad=True. Gradients will be None") 21 | 22 | 23 | class CheckpointFunction(torch.autograd.Function): 24 | @staticmethod 25 | def forward(ctx, run_function, length, *args): 26 | ctx.run_function = run_function 27 | ctx.input_tensors = list(args[:length]) 28 | ctx.input_params = list(args[length:]) 29 | with torch.no_grad(): 30 | output_tensors = ctx.run_function(*ctx.input_tensors) 31 | return output_tensors 32 | 33 | @staticmethod 34 | def backward(ctx, *output_grads): 35 | for i in range(len(ctx.input_tensors)): 36 | temp = ctx.input_tensors[i] 37 | ctx.input_tensors[i] = temp.detach() 38 | ctx.input_tensors[i].requires_grad = temp.requires_grad 39 | with torch.enable_grad(): 40 | output_tensors = ctx.run_function(*ctx.input_tensors) 41 | input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True) 42 | return (None, None) + input_grads -------------------------------------------------------------------------------- /src/utils/load_save.py: -------------------------------------------------------------------------------- 1 | """ 2 | saving utilities 3 | """ 4 | import json 5 | import os 6 | from os.path import dirname, exists, join, realpath 7 | import subprocess 8 | from apex import amp 9 | from easydict import EasyDict as edict 10 | 11 | import torch 12 | from src.utils.basic_utils import save_json, make_zipfile, load_json 13 | from src.utils.logger import LOGGER 14 | from typing import Any, Dict, Union 15 | 16 | from src.modeling.timesformer.helpers import resize_spatial_embedding, resize_temporal_embedding 17 | 18 | 19 | def save_training_meta(args): 20 | # args is an EasyDict object, treat it the same as a normal dict 21 | os.makedirs(join(args.output_dir, 'log'), exist_ok=True) 22 | 
os.makedirs(join(args.output_dir, 'ckpt'), exist_ok=True) 23 | 24 | # training args 25 | save_args_path = join(args.output_dir, 'log', 'args.json') 26 | save_json(args, save_args_path, save_pretty=True) 27 | 28 | # model args 29 | model_config = json.load(open(args.model_config)) 30 | save_model_config_path = join(args.output_dir, 'log', 'model_config.json') 31 | save_json(model_config, save_model_config_path, save_pretty=True) 32 | 33 | # save a copy of the codebase. !!!Do not store heavy file in your codebase when using it. 34 | code_dir = dirname(dirname(dirname(os.path.realpath(__file__)))) 35 | code_zip_filename = os.path.join(args.output_dir, "code.zip") 36 | LOGGER.info(f"Saving code from {code_dir} to {code_zip_filename}...") 37 | make_zipfile(code_dir, code_zip_filename, 38 | enclosing_dir="code", 39 | exclude_dirs_substring="results", 40 | exclude_dirs=["__pycache__", "output", "data", "ext"], 41 | exclude_extensions=[".pyc", ".ipynb", ".swap", ".pt"]) 42 | LOGGER.info(f"Saving code done.") 43 | 44 | 45 | class ModelSaver(object): 46 | def __init__(self, output_dir): 47 | self.output_dir = output_dir 48 | self.max_save_load_trial = 10 49 | 50 | def save(self, step, model, optimizer=None, prefix="model"): 51 | model_path = join(self.output_dir, f"{prefix}_step_{step}.pt") 52 | state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v 53 | for k, v in model.state_dict().items()} 54 | # with retrial, as azure blob fails occasionally. 55 | save_trial = 0 56 | while save_trial < self.max_save_load_trial: 57 | try: 58 | LOGGER.info(f"ModelSaver save trial NO. {save_trial}") 59 | torch.save(state_dict, model_path) 60 | if optimizer is not None: 61 | optimizer_state_dict = \ 62 | {k: v.cpu() if isinstance(v, torch.Tensor) else v 63 | for k, v in optimizer.state_dict().items()} 64 | dump = {'step': step, 'optimizer': optimizer_state_dict} 65 | torch.save( 66 | dump, 67 | f'{self.output_dir}/{prefix}_step_{step}_train_state.pt') 68 | break 69 | except Exception as e: 70 | save_trial += 1 71 | 72 | 73 | def load_state_dict_with_pos_embed_resizing(model, loaded_state_dict_or_path, 74 | num_patches, num_frames, 75 | spatial_embed_key='visual_encoder.model.pos_embed', 76 | temporal_embed_key='visual_encoder.model.time_embed', 77 | strict=False, 78 | remove_text_encoder_prefix=False 79 | ): 80 | """operated in-place, no need to return `model`, 81 | 82 | Used to load e2e model checkpoints. 83 | 84 | remove_text_encoder_prefix: set to True, when finetune downstream models from pre-trained checkpoints. 85 | """ 86 | 87 | if isinstance(loaded_state_dict_or_path, str): 88 | loaded_state_dict = torch.load( 89 | loaded_state_dict_or_path, map_location="cpu") 90 | 91 | else: 92 | loaded_state_dict = loaded_state_dict_or_path 93 | 94 | new_state_dict = loaded_state_dict.copy() 95 | 96 | for key in loaded_state_dict: 97 | if 'text_encoder.bert' in key and remove_text_encoder_prefix: 98 | new_key = key.replace('text_encoder.bert','text_encoder') 99 | new_state_dict[new_key] = new_state_dict.pop(key) 100 | 101 | loaded_state_dict = new_state_dict 102 | 103 | ## Resizing spatial embeddings in case they don't match 104 | if num_patches + 1 != loaded_state_dict[spatial_embed_key].size(1): 105 | loaded_state_dict[spatial_embed_key] = resize_spatial_embedding(loaded_state_dict, spatial_embed_key, num_patches) 106 | else: 107 | LOGGER.info('The length of spatial position embedding matches. 
No need to resize.') 108 | 109 | ## Resizing time embeddings in case they don't match 110 | if temporal_embed_key in loaded_state_dict and num_frames != loaded_state_dict[temporal_embed_key].size(1): 111 | loaded_state_dict[temporal_embed_key] = resize_temporal_embedding(loaded_state_dict, temporal_embed_key, num_frames) 112 | else: 113 | LOGGER.info('No temporal encoding found. Or the length of temporal position embedding matches. No need to resize.') 114 | 115 | model_keys = set([k for k in list(model.state_dict().keys())]) 116 | load_keys = set(loaded_state_dict.keys()) 117 | 118 | toload = {} 119 | mismatched_shape_keys = [] 120 | for k in model_keys: 121 | if k in load_keys: 122 | if model.state_dict()[k].shape != loaded_state_dict[k].shape: 123 | mismatched_shape_keys.append(k) 124 | else: 125 | toload[k] = loaded_state_dict[k] 126 | 127 | LOGGER.info("You can ignore the keys with `num_batches_tracked` or from task heads") 128 | LOGGER.info("Keys in loaded but not in model:") 129 | diff_keys = load_keys.difference(model_keys) 130 | LOGGER.info(f"In total {len(diff_keys)}, {sorted(diff_keys)}") 131 | LOGGER.info("Keys in model but not in loaded:") 132 | diff_keys = model_keys.difference(load_keys) 133 | LOGGER.info(f"In total {len(diff_keys)}, {sorted(diff_keys)}") 134 | LOGGER.info("Keys in model and loaded, but shape mismatched:") 135 | LOGGER.info(f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}") 136 | model.load_state_dict(toload, strict=strict) 137 | 138 | def compare_dict_difference(dict1, dict2, dict1_name="dict1", 139 | dict2_name="dict2", 140 | print_value_diff=True, verbose=False): 141 | """ 142 | Args: 143 | dict1: 144 | dict2: 145 | dict1_name: 146 | dict2_name: 147 | print_value_diff: bool, output dict value difference within shared keys 148 | for dict1 and dict2. 
In effect only when verbose == True 149 | verbose: 150 | """ 151 | keys1 = set(dict1.keys()) 152 | keys2 = set(dict2.keys()) 153 | shared_keys = keys1.intersection(keys2) 154 | keys1_unique = keys1.difference(shared_keys) 155 | keys2_unique = keys2.difference(shared_keys) 156 | key_diff_list = list(keys1_unique) + list(keys2_unique) 157 | 158 | # value difference in the shared keys in dict1 and dict2 159 | value_diff_dict = {} 160 | for k in shared_keys: 161 | if dict1[k] != dict2[k]: 162 | value_diff_dict[k] = [(dict1_name, dict1[k]), (dict2_name, dict2[k])] 163 | 164 | if verbose: 165 | LOGGER.info("=" * 30 + "key difference") 166 | LOGGER.info(f"keys in {dict1_name} but not in {dict2_name}: " 167 | f"total {len(keys1_unique)}, {sorted(keys1_unique)}") 168 | LOGGER.info(f"keys in {dict2_name} but not in {dict1_name}: " 169 | f"total {len(keys2_unique)}, {sorted(keys2_unique)}") 170 | 171 | if verbose and print_value_diff: 172 | 173 | LOGGER.info("=" * 30 + "value difference") 174 | LOGGER.info(f"{json.dumps(value_diff_dict, indent=4)}") 175 | 176 | return value_diff_dict, key_diff_list 177 | 178 | 179 | def _to_cuda(state): 180 | """ usually load from cpu checkpoint but need to load to cuda """ 181 | if isinstance(state, torch.Tensor): 182 | ret = state.cuda() # assume properly set by torch.cuda.set_device 183 | if 'Half' in state.type(): 184 | ret = ret.float() # apex O2 requires it 185 | return ret 186 | elif isinstance(state, list): 187 | new_state = [_to_cuda(t) for t in state] 188 | elif isinstance(state, tuple): 189 | new_state = tuple(_to_cuda(t) for t in state) 190 | elif isinstance(state, dict): 191 | new_state = {n: _to_cuda(t) for n, t in state.items()} 192 | else: 193 | return state 194 | return new_state 195 | 196 | 197 | def _to_cpu(state): 198 | """ store in cpu to avoid GPU0 device, fp16 to save space """ 199 | if isinstance(state, torch.Tensor): 200 | ret = state.cpu() 201 | if 'Float' in state.type(): 202 | ret = ret.half() 203 | return ret 204 | elif isinstance(state, list): 205 | new_state = [_to_cpu(t) for t in state] 206 | elif isinstance(state, tuple): 207 | new_state = tuple(_to_cpu(t) for t in state) 208 | elif isinstance(state, dict): 209 | new_state = {n: _to_cpu(t) for n, t in state.items()} 210 | else: 211 | return state 212 | return new_state 213 | 214 | 215 | class TrainingRestorer(object): 216 | """ckpt_dict: a dict containing all optimizers/models""" 217 | def __init__(self, opts, **ckpt_dict): 218 | if exists(opts.output_dir): 219 | restore_opts = json.load(open( 220 | f'{opts.output_dir}/log/args.json', 'r')) 221 | assert opts == edict(restore_opts) 222 | # keep 2 checkpoints in case of corrupted 223 | self.save_path = f'{opts.output_dir}/restore.pt' 224 | self.backup_path = f'{opts.output_dir}/restore_backup.pt' 225 | self.ckpt_dict = ckpt_dict 226 | self.save_steps = opts.save_steps 227 | self.amp = opts.fp16 228 | # since saving to or loading from azure blob fails sometimes 229 | self.max_save_load_trial = 10 230 | if exists(self.save_path) or exists(self.backup_path): 231 | LOGGER.info('found previous checkpoint. try to resume...') 232 | # with retrial, as azure blob fails occasionally. 233 | restore_trial = 0 234 | while restore_trial < self.max_save_load_trial: 235 | LOGGER.info(f"TrainingRestorer restore trial NO.
{restore_trial}") 236 | try: 237 | self.restore() 238 | break 239 | except Exception as e: 240 | restore_trial += 1 241 | else: 242 | self.global_step = 0 243 | 244 | def step(self): 245 | self.global_step += 1 246 | if self.global_step % self.save_steps == 0: 247 | # with retrial, as azure blob fails occasionally. 248 | save_trial = 0 249 | while save_trial < self.max_save_load_trial: 250 | LOGGER.info(f"TrainingRestorer save trial NO. {save_trial}") 251 | try: 252 | self.save() 253 | break 254 | except Exception as e: 255 | save_trial += 1 256 | 257 | def save(self): 258 | checkpoint_to_save = {'global_step': self.global_step} 259 | for k in self.ckpt_dict: 260 | checkpoint_to_save[k] = _to_cpu(self.ckpt_dict[k].state_dict()) 261 | if self.amp: 262 | checkpoint_to_save['amp_state_dict'] = amp.state_dict() 263 | if exists(self.save_path): 264 | os.rename(self.save_path, self.backup_path) 265 | torch.save(checkpoint_to_save, self.save_path) 266 | 267 | def restore(self): 268 | try: 269 | checkpoint = torch.load(self.save_path) 270 | except Exception: 271 | checkpoint = torch.load(self.backup_path) 272 | self.global_step = checkpoint['global_step'] 273 | for k in self.ckpt_dict: 274 | self.ckpt_dict[k].load_state_dict(_to_cuda(checkpoint[k])) 275 | if self.amp: 276 | amp.load_state_dict(checkpoint['amp_state_dict']) 277 | LOGGER.info(f'resume training from step {self.global_step}') 278 | 279 | 280 | class E2E_TrainingRestorer(object): 281 | def __init__(self, opts, model, optimizer): 282 | if exists(f"{opts.output_dir}/log/args.json"): 283 | restore_opts = json.load( 284 | open(f'{opts.output_dir}/log/args.json', 'r')) 285 | with open(join( 286 | opts.output_dir, 'log', 287 | 'restore_args.json'), 'w') as writer: 288 | json.dump(vars(opts), writer, indent=4) 289 | # assert opts == edict(restore_opts) 290 | # keep 2 checkpoints in case of corrupted 291 | self.save_path = f'{opts.output_dir}/restore.pt' 292 | self.backup_path = f'{opts.output_dir}/restore_backup.pt' 293 | self.model = model 294 | self.optimizer = optimizer 295 | self.save_steps = int(opts.save_steps_ratio * opts.num_train_steps) 296 | self.amp = opts.fp16 297 | # since saving to or loading from azure blob fails sometimes 298 | self.max_save_load_trial = 10 299 | if exists(self.save_path) or exists(self.backup_path): 300 | LOGGER.info('found previous checkpoint. try to resume...') 301 | # with retrial, as azure blob fails occasionally. 302 | restore_trial = 0 303 | while restore_trial < self.max_save_load_trial: 304 | LOGGER.info(f"TrainingRestorer restore trial NO. {restore_trial}") 305 | try: 306 | self.restore(opts) 307 | break 308 | except Exception as e: 309 | restore_trial += 1 310 | else: 311 | self.global_step = 0 312 | 313 | def step(self): 314 | self.global_step += 1 315 | if self.global_step % self.save_steps == 0: 316 | # with retrial, as azure blob fails occasionally. 317 | save_trial = 0 318 | while save_trial < self.max_save_load_trial: 319 | LOGGER.info(f"TrainingRestorer save trial NO. 
{save_trial}") 320 | try: 321 | self.save() 322 | break 323 | except Exception as e: 324 | save_trial += 1 325 | 326 | def save(self): 327 | checkpoint = {'global_step': self.global_step, 328 | 'model_state_dict': _to_cpu(self.model.state_dict()), 329 | 'optim_state_dict': _to_cpu(self.optimizer.state_dict())} 330 | if self.amp: 331 | checkpoint['amp_state_dict'] = amp.state_dict() 332 | if exists(self.save_path): 333 | os.rename(self.save_path, self.backup_path) 334 | torch.save(checkpoint, self.save_path) 335 | 336 | def restore(self, opts): 337 | try: 338 | checkpoint = torch.load(self.save_path) 339 | except Exception: 340 | checkpoint = torch.load(self.backup_path) 341 | self.global_step = checkpoint['global_step'] 342 | self.model.load_state_dict(_to_cuda(checkpoint['model_state_dict'])) 343 | self.optimizer.load_state_dict( 344 | _to_cuda(checkpoint['optim_state_dict'])) 345 | if self.amp: 346 | amp.load_state_dict(checkpoint['amp_state_dict']) 347 | LOGGER.info(f'resume training from step {self.global_step}') 348 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | references: UNITER 3 | """ 4 | 5 | import logging 6 | from tensorboardX import SummaryWriter 7 | 8 | 9 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 10 | _DATE_FMT = '%m/%d/%Y %H:%M:%S' 11 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) 12 | LOGGER = logging.getLogger('__main__') # this is the global logger 13 | 14 | 15 | def add_log_to_file(log_path): 16 | fh = logging.FileHandler(log_path) 17 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) 18 | fh.setFormatter(formatter) 19 | LOGGER.addHandler(fh) 20 | 21 | 22 | class TensorboardLogger(object): 23 | def __init__(self): 24 | self._logger = None 25 | self._global_step = 0 26 | 27 | def create(self, path): 28 | self._logger = SummaryWriter(path) 29 | 30 | def noop(self, *args, **kwargs): 31 | return 32 | 33 | def step(self): 34 | self._global_step += 1 35 | 36 | @property 37 | def global_step(self): 38 | return self._global_step 39 | 40 | @global_step.setter 41 | def global_step(self, step): 42 | self._global_step = step 43 | 44 | def log_scalar_dict(self, log_dict, prefix=''): 45 | """ log a dictionary of scalar values""" 46 | if self._logger is None: 47 | return 48 | if prefix: 49 | prefix = f'{prefix}_' 50 | for name, value in log_dict.items(): 51 | if isinstance(value, dict): 52 | self.log_scalar_dict(value, 53 | prefix=f'{prefix}{name}') 54 | else: 55 | self._logger.add_scalar(f'{prefix}{name}', value, 56 | self._global_step) 57 | 58 | def __getattr__(self, name): 59 | if self._logger is None: 60 | return self.noop 61 | return self._logger.__getattribute__(name) 62 | 63 | 64 | TB_LOGGER = TensorboardLogger() 65 | 66 | 67 | class RunningMeter(object): 68 | """ running meter of a scalar value 69 | (useful for monitoring training loss) 70 | """ 71 | def __init__(self, name, val=None, smooth=0.99): 72 | self._name = name 73 | self._sm = smooth 74 | self._val = val 75 | 76 | def __call__(self, value): 77 | self._val = (value if self._val is None 78 | else value*(1-self._sm) + self._val*self._sm) 79 | 80 | def __str__(self): 81 | return f'{self._name}: {self._val:.4f}' 82 | 83 | @property 84 | def val(self): 85 | return self._val 86 | 87 | @property 88 | def name(self): 89 | return self._name 90 |
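The logger utilities above are written so that only rank 0 actually writes anything: TB_LOGGER silently no-ops until create() is called, and RunningMeter keeps an exponentially smoothed value that the training loop flushes via log_scalar_dict(). A minimal usage sketch follows, assuming the repository root is on PYTHONPATH and tensorboardX is installed; the output paths and loss values are placeholders.

from src.utils.logger import TB_LOGGER, RunningMeter, add_log_to_file

TB_LOGGER.create('/tmp/alpro_output/log')           # rank 0 only in the real scripts
add_log_to_file('/tmp/alpro_output/log/log.txt')    # mirror console logs to a file

loss_meter = RunningMeter('train_loss/itc')         # exponentially smoothed (factor 0.99)
for loss in [0.93, 0.71, 0.68]:                     # stand-in loss values
    loss_meter(loss)                                # update the running value
    TB_LOGGER.log_scalar_dict({loss_meter.name: loss_meter.val})
    TB_LOGGER.step()                                # advance the tensorboard global step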
-------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified from UNITER 3 | """ 4 | import json 5 | import random 6 | import sys 7 | 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class NoOp(object): 13 | """ useful for distributed training No-Ops """ 14 | def __getattr__(self, name): 15 | return self.noop 16 | 17 | def noop(self, *args, **kwargs): 18 | return 19 | 20 | 21 | def set_random_seed(seed): 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | 27 | 28 | def zero_none_grad(model): 29 | for p in model.parameters(): 30 | if p.grad is None and p.requires_grad: 31 | p.grad = p.data.new(p.size()).zero_() 32 | --------------------------------------------------------------------------------
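misc.py backs the rank-0-only pattern used throughout the training script above: non-zero Horovod ranks receive NoOp stand-ins for the progress bar, model saver, and restorer, so the main loop never needs rank checks around those calls. A minimal sketch of that pattern, assuming the repository root is on PYTHONPATH; the literal rank value stands in for hvd.rank().

from tqdm import tqdm
from src.utils.misc import NoOp, set_random_seed

set_random_seed(42)          # seeds python, numpy and torch (CPU and all GPUs)

rank = 0                     # stand-in for hvd.rank() in the real scripts
pbar = tqdm(total=1000) if rank == 0 else NoOp()

pbar.update(1)               # real update on rank 0, silently ignored elsewhere
pbar.close()                 # NoOp absorbs any method call with any arguments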