├── .gitignore ├── LICENSE ├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── configs ├── anet │ ├── anet_few_shot.yaml │ ├── anet_k400_finetune.yaml │ ├── anet_k400_finetune_336.yaml │ └── anet_zero_shot.yaml ├── hmdb51 │ ├── hmdb_few_shot.yaml │ ├── hmdb_k400_finetune.yaml │ ├── hmdb_k400_finetune_336.yaml │ └── hmdb_zero_shot.yaml ├── k400 │ ├── k400_few_shot.yaml │ ├── k400_train_rgb_rn50.yaml │ ├── k400_train_rgb_vitb-16-f16.yaml │ ├── k400_train_rgb_vitb-16-f8.yaml │ ├── k400_train_rgb_vitb-32-f16.yaml │ ├── k400_train_rgb_vitb-32-f8.yaml │ ├── k400_train_rgb_vitl-14-336-f32.yaml │ ├── k400_train_rgb_vitl-14-336-f8.yaml │ └── k400_train_rgb_vitl-14-f8.yaml ├── k600 │ ├── k600_zero_shot_split1.yaml │ ├── k600_zero_shot_split2.yaml │ └── k600_zero_shot_split3.yaml └── ucf101 │ ├── ucf_few_shot.yaml │ ├── ucf_k400_finetune.yaml │ ├── ucf_k400_finetune_336.yaml │ └── ucf_zero_shot.yaml ├── datasets ├── __init__.py ├── datasets.py └── transforms.py ├── exps ├── anet │ └── ViT-L │ │ ├── 14 │ │ └── f16 │ │ │ └── log.txt │ │ └── 14-336px │ │ └── f16 │ │ └── log.txt ├── hmdb51 │ └── ViT-L │ │ └── 14 │ │ └── f16 │ │ └── log.txt ├── k400 │ ├── ViT-B │ │ ├── 16 │ │ │ ├── f16 │ │ │ │ └── log.txt │ │ │ └── f8 │ │ │ │ └── log.txt │ │ └── 32 │ │ │ ├── f16 │ │ │ └── log.txt │ │ │ └── f8 │ │ │ └── log.txt │ └── ViT-L │ │ ├── 14 │ │ └── f8 │ │ │ └── log.txt │ │ └── 14-336px │ │ ├── f32 │ │ └── log.txt │ │ └── f8 │ │ └── log.txt └── ucf101 │ └── ViT-L │ ├── 14 │ └── f16 │ │ └── log.txt │ └── 14-336px │ └── f16 │ └── log.txt ├── lists ├── anet │ ├── anet1.3_label2name.json │ ├── anet_full_for_zeroshot.txt │ ├── anet_train_instance_fps1.txt │ └── anet_val_video_fps1.txt ├── anet1.3_labels.csv ├── hmdb51 │ ├── hmdb_full_for_zeroshot.txt │ ├── train_rgb_split_1.txt │ ├── train_rgb_split_2.txt │ ├── train_rgb_split_3.txt │ ├── val_rgb_split_1.txt │ ├── val_rgb_split_2.txt │ └── val_rgb_split_3.txt ├── hmdb51_labels.csv ├── k400 │ ├── kinetics_rgb_train_se320.txt │ ├── kinetics_rgb_val_se320.txt │ ├── kinetics_video_train_se320.txt │ └── kinetics_video_val_se320.txt ├── k600 │ ├── k160_labels_split1.csv │ ├── k160_labels_split2.csv │ ├── k160_labels_split3.csv │ ├── test_split1_exist.txt │ ├── test_split2_exist.txt │ └── test_split3_exist.txt ├── kinetics_400_labels.csv ├── ucf101 │ ├── train_rgb_split_1.txt │ ├── train_rgb_split_2.txt │ ├── train_rgb_split_3.txt │ ├── ucf_full_for_zeroshot.txt │ ├── val_rgb_split_1.txt │ ├── val_rgb_split_2.txt │ └── val_rgb_split_3.txt └── ucf_labels.csv ├── modules ├── coop.py ├── temporal_modeling.py ├── text_prompt.py └── video_clip.py ├── scripts ├── run_test.sh ├── run_test_zeroshot.sh ├── run_train.sh ├── run_train_multinodes.sh └── run_train_nce.sh ├── teaser.png ├── test.py ├── test_anet.py ├── test_zeroshot.py ├── text4vis.png ├── train.py ├── train_nce.py └── utils ├── Augmentation.py ├── NCELoss.py ├── logger.py ├── lr_scheduler.py ├── solver.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # 
PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | .DS_Store 128 | .vscode/ 129 | exp/ 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Wenhao Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

🔥【AAAI'2023, IJCV'2023】Revisiting Classifier: Transferring Vision-Language Models for Video Recognition

4 | 5 | [![Conference](http://img.shields.io/badge/AAAI-2023-f9f107.svg)](https://ojs.aaai.org/index.php/AAAI/article/view/25386/25158) 6 | [![Journal](http://img.shields.io/badge/IJCV-2023-Bf107.svg)](https://link.springer.com/article/10.1007/s11263-023-01876-w) 7 | 8 | 9 | [Wenhao Wu](https://whwu95.github.io/)1,2, [Zhun Sun](https://scholar.google.co.jp/citations?user=Y-3iZ9EAAAAJ&hl=en)2, [Wanli Ouyang](https://wlouyang.github.io/)3,1 10 | 11 | 12 | 1[The University of Sydney](https://www.sydney.edu.au/), 2[Baidu](https://vis.baidu.com/#/), 3[Shanghai AI Lab](https://www.shlab.org.cn/) 13 | 14 | 15 |
16 | 17 | *** 18 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transferring-textual-knowledge-for-visual/action-classification-on-kinetics-400)](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=transferring-textual-knowledge-for-visual) 19 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transferring-textual-knowledge-for-visual/action-recognition-in-videos-on-activitynet)](https://paperswithcode.com/sota/action-recognition-in-videos-on-activitynet?p=transferring-textual-knowledge-for-visual) 20 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transferring-textual-knowledge-for-visual/action-recognition-in-videos-on-ucf101)](https://paperswithcode.com/sota/action-recognition-in-videos-on-ucf101?p=transferring-textual-knowledge-for-visual) 21 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transferring-textual-knowledge-for-visual/zero-shot-action-recognition-on-kinetics)](https://paperswithcode.com/sota/zero-shot-action-recognition-on-kinetics?p=transferring-textual-knowledge-for-visual) 22 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transferring-textual-knowledge-for-visual/zero-shot-action-recognition-on-activitynet)](https://paperswithcode.com/sota/zero-shot-action-recognition-on-activitynet?p=transferring-textual-knowledge-for-visual) 23 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transferring-textual-knowledge-for-visual/zero-shot-action-recognition-on-ucf101)](https://paperswithcode.com/sota/zero-shot-action-recognition-on-ucf101?p=transferring-textual-knowledge-for-visual) 24 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/transferring-textual-knowledge-for-visual/zero-shot-action-recognition-on-hmdb51)](https://paperswithcode.com/sota/zero-shot-action-recognition-on-hmdb51?p=transferring-textual-knowledge-for-visual) 25 | 26 | 27 | This is the official implementation of the **AAAI paper** [Revisiting Classifier: Transferring Vision-Language Models for Video Recognition](https://arxiv.org/abs/2207.01297), and **IJCV paper** [Transferring Vision-Language Models for Visual Recognition: A Classifier Perspective](https://link.springer.com/article/10.1007/s11263-023-01876-w). 28 | 29 |
🙋 I also have other cross-modal video projects that may interest you ✨.

30 | 31 | 32 | > [**Bidirectional Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models**](https://arxiv.org/abs/2301.00182)
33 | > Wenhao Wu, Xiaohan Wang, Haipeng Luo, Jingdong Wang, Yi Yang, Wanli Ouyang
34 | > [![Conference](http://img.shields.io/badge/CVPR-2023-f9f107.svg)](https://openaccess.thecvf.com/content/CVPR2023/html/Wu_Bidirectional_Cross-Modal_Knowledge_Exploration_for_Video_Recognition_With_Pre-Trained_Vision-Language_CVPR_2023_paper.html) [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/whwu95/BIKE) 35 | 36 | 37 | > [**Cap4Video: What Can Auxiliary Captions Do for Text-Video Retrieval?**](https://arxiv.org/abs/2301.00184)
38 | > Wenhao Wu, Haipeng Luo, Bo Fang, Jingdong Wang, Wanli Ouyang
39 | > Accepted by CVPR 2023 as 🌟Highlight🌟 | [![Conference](http://img.shields.io/badge/CVPR-2023-f9f107.svg)](https://openaccess.thecvf.com/content/CVPR2023/html/Wu_Cap4Video_What_Can_Auxiliary_Captions_Do_for_Text-Video_Retrieval_CVPR_2023_paper.html) [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/whwu95/Cap4Video)
40 | 41 | 42 |

43 | 44 | 45 | 46 | ## 📣 Updates 47 | - [x] **`Aug 07, 2023`** The extension of Text4Vis has been accepted by **International Journal of Computer Vision (IJCV)**. 48 | - [x] **`Dec 22, 2022`** Models: The pre-trained models & logs. 49 | - [x] **`Nov 30, 2022`** Config: All the configs (general/few-shot/zero-shot video recognition) on Kinetics-400 & 600, ActivityNet, UCF, and HMDB. 50 | - [x] **`Nov 30, 2022`** Code: Zero-shot Evaluation: Half-classes evaluation and Full-classes evaluation. 51 | - [x] **`Nov 28, 2022`** Code: Single-Machine/Multi-Machine Multi-GPU Distributed Training, Distributed testing. 52 | - [x] **`Nov 19, 2022`** 🎉Our paper has been accepted by **AAAI-2023**. 53 | - [x] **`Jul 1, 2022`** 💡Our [initial Arxiv paper](https://arxiv.org/abs/2207.01297v1) is released. 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | ## 🌈 Overview 62 | In Text4Vis, we revisit the role of the linear classifier and replace it with knowledge from the pre-trained language model: the well-pre-trained text encoder generates good semantic targets for the classifier, enabling efficient transfer learning. 63 | 64 | ![1](teaser.png) 65 | ![2](text4vis.png) 66 | 67 | ## Content 68 | - [Prerequisites](#prerequisites) 69 | - [Data Preparation](#data-preparation) 70 | - [Model Zoo](#model-zoo) 71 | - [Training](#training) 72 | - [Testing](#testing) 73 | - [BibTeX & Citation](#bibtex) 74 | - [Acknowledgment](#acknowledgment) 75 | 76 | 77 | 78 | ## 📕 Prerequisites 79 | The code is built with the following libraries: 80 | 81 | - [PyTorch](https://pytorch.org/) >= 1.8 82 | - RandAugment 83 | - pprint 84 | - tqdm 85 | - dotmap 86 | - yaml 87 | - csv 88 | - Optional: decord (for on-the-fly video training) 89 | - Optional: torchnet (for mAP evaluation on ActivityNet) 90 | 91 | 92 | ## 📚 Data Preparation 93 | 94 | #### Video Loader 95 | **(Recommended)** To train all of our models, we extract videos into frames for fast reading. Please refer to the [MVFNet](https://github.com/whwu95/MVFNet/blob/main/data_process/DATASETS.md) repo for a detailed guide to data processing. 96 | The annotation file is a text file with multiple lines; each line gives the directory of a video's frames, the total number of frames, and the video's label, separated by whitespace. Here is the format: 97 | ```sh 98 | abseiling/-7kbO0v4hag_000107_000117 300 0 99 | abseiling/-bwYZwnwb8E_000013_000023 300 0 100 | ``` 101 | 102 | **(Optional)** We can also decode the videos on the fly using [decord](https://github.com/dmlc/decord). This should work but has not been tested; all of the released models were trained on offline frames. Example annotation: 103 | ```sh 104 | abseiling/-7kbO0v4hag_000107_000117.mp4 0 105 | abseiling/-bwYZwnwb8E_000013_000023.mp4 0 106 | ``` 107 | 108 | #### Annotation 109 | Annotation information consists of two parts: the video labels and the category descriptions. 110 | 111 | - Video Label: As mentioned above, this part is the same as in traditional video recognition. Please refer to `lists/k400/kinetics_rgb_train_se320.txt` for the format. 112 | - Category Description: We also need a textual description for each video category. Please refer to `lists/kinetics_400_labels.csv` for the format. 113 | 114 | 115 | 116 | ## 📱 Model Zoo 117 | 118 | Here we provide some off-the-shelf pre-trained checkpoints of our models in the following tables.
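Before plugging a downloaded checkpoint into the training/testing scripts, it can help to sanity-check the file with plain PyTorch. This is only a minimal sketch: the file name below matches the released ViT-B/32 8-frame checkpoint, but the top-level key layout (e.g., whether the weights sit under a `model_state_dict` entry) is an assumption, so inspect the printed keys rather than relying on it.

```python
import torch

# Minimal sanity check for a downloaded checkpoint (sketch only).
ckpt_path = "k400-vitb-32-f8.pt"   # e.g., the ViT-B/32 8-frame release from the table below
ckpt = torch.load(ckpt_path, map_location="cpu")

print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])                    # top-level entries, e.g. epoch / state dict
    state = ckpt.get("model_state_dict", ckpt)       # assumed key; fall back to the dict itself
    if isinstance(state, dict):
        tensors = [v for v in state.values() if torch.is_tensor(v)]
        print(f"{len(tensors)} tensors, {sum(v.numel() for v in tensors) / 1e6:.1f}M parameters")
```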
119 | 120 | *#Frame = #input_frame x #spatial crops x #temporal clips* 121 | #### Kinetics-400 122 | 123 | | Architecture |#Frame | Top-1 Acc.(%) | checkpoint | Train log| config| 124 | |:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:| 125 | | ViT-B/32 | 8x3x4 | 80.0 | [Github](https://github.com/whwu95/Text4Vis/releases/download/v1/k400-vitb-32-f8.pt) | [log](exps/k400/ViT-B/32/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitb-32-f8.yaml) | 126 | | ViT-B/32 | 16x3x4 | 80.5 | [Github](https://github.com/whwu95/Text4Vis/releases/download/v1/k400-vitb-32-f16.pt) | [log](exps/k400/ViT-B/32/f16/log.txt) | [config](configs/k400/k400_train_rgb_vitb-32-f16.yaml) | 127 | | ViT-B/16 | 8x3x4 | 82.9 | [Github](https://github.com/whwu95/Text4Vis/releases/download/v1/k400-vitb-16-f8.pt) | [log](exps/k400/ViT-B/16/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitb-16-f8.yaml) | 128 | | ViT-B/16 | 16x3x4 | 83.6 | [Github](https://github.com/whwu95/Text4Vis/releases/download/v1/k400-vitb-16-f16.pt)| [log](exps/k400/ViT-B/16/f16/log.txt) | [config](configs/k400/k400_train_rgb_vitb-16-f16.yaml) | 129 | | ViT-L/14* | 8x3x4 | 86.4 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/k400/ViT-L/14/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-f8.yaml) | 130 | | ViT-L/14-336 | 8x3x4 | 87.1 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/k400/ViT-L/14-336px/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-336-f8.yaml) | 131 | | ViT-L/14-336 | 32x3x1 | 87.8 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/k400/ViT-L/14-336px/f32/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-336-f32.yaml) | 132 | 133 | *Note: * indicates that this ViT-L model is used for the zero-shot evaluation on UCF, HMDB, ActivityNet and Kinetics-600.* 134 | 135 | #### ActivityNet 136 | | Architecture |#Frame | mAP (%) | checkpoint | Train log| config| 137 | |:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:| 138 | | ViT-L/14 | 16x1x1 | 96.5 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [config](configs/anet/anet_k400_finetune.yaml) | 139 | | ViT-L/14-336 | 16x1x1 | 96.9 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/anet/ViT-L/14-336px/f16/log.txt) | [config](configs/anet/anet_k400_finetune_336.yaml) | 140 | 141 | #### UCF-101 142 | | Architecture |#Frame | Top-1 Acc. (%) | checkpoint | Train log| config| 143 | |:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:| 144 | | ViT-L/14 | 16x1x1 | 98.1 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/ucf101/ViT-L/14/f16/log.txt) | [config](configs/ucf101/ucf_k400_finetune.yaml) | 145 | 146 | 147 | #### HMDB-51 148 | | Architecture |#Frame | Top-1 Acc. 
(%) | checkpoint | Train log| config| 149 | |:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:| 150 | | ViT-L/14 | 16x1x1 | 81.3 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/hmdb51/ViT-L/14/f16/log.txt) | [config](configs/hmdb51/hmdb_k400_finetune.yaml) | 151 | 152 | 153 | 154 | ## 🚀 Training 155 | This implementation supports multi-GPU `DistributedDataParallel` training, which is faster and simpler than the `DataParallel` used in [ActionCLIP](https://github.com/sallymmx/actionclip). 156 | 157 | - **Single Machine**: To train our model on Kinetics-400 with 8 GPUs on a *single machine*, run: 158 | ```sh 159 | # For example, train the 8-frame ViT-B/32. 160 | sh scripts/run_train.sh configs/k400/k400_train_rgb_vitb-32-f8.yaml 161 | ``` 162 | 163 | - **Multiple Machines**: We also provide a script for training larger models on *multiple machines* (e.g., 2 machines with 16 GPUs in total): 164 | ```sh 165 | # For example, we train the 8-frame ViT-L/14 with 2 machines as follows. 166 | # On the first machine, set the IP of the first machine as --master_addr and set --nnodes to 2. 167 | # Compared with the single-machine training script, only the node_id argument needs to be added. 168 | sh scripts/run_train_multinodes.sh configs/k400/k400_train_rgb_vitl-14-f8.yaml 0 169 | 170 | # On the second machine, --master_addr is still the IP of the first machine. 171 | sh scripts/run_train_multinodes.sh configs/k400/k400_train_rgb_vitl-14-f8.yaml 1 172 | ``` 173 | 174 | - **Few-shot Recognition**: To train our model in the *few-shot* setting, you only need to add one line to the general config file: 175 | ```sh 176 | # You can refer to configs/k400/k400_few_shot.yaml 177 | data: 178 | ... # general configurations 179 | shot: 2 # i.e., 2-shot setting 180 | ``` 181 | 182 | 183 | ## ⚡ Testing 184 | We support single-view validation and multi-view (4 clips x 3 crops) validation. 185 | 186 | #### General/Few-shot Video Recognition 187 | ```sh 188 | # Single-view evaluation, e.g., ViT-B/32 8 frames on Kinetics-400 189 | sh scripts/run_test.sh configs/k400/k400_train_rgb_vitb-32-f8.yaml exp/k400/ViT-B/32/f8/last_model.pt 190 | 191 | # Multi-view evaluation (4 clips x 3 crops), e.g., ViT-B/32 8 frames on Kinetics-400 192 | sh scripts/run_test.sh configs/k400/k400_train_rgb_vitb-32-f8.yaml exp/k400/ViT-B/32/f8/last_model.pt --test_crops 3 --test_clips 4 193 | ``` 194 | 195 | 196 | #### Zero-shot Evaluation 197 | 198 | We use the Kinetics-400 pre-trained model (e.g., ViT-L/14 with 8 frames) for cross-dataset zero-shot evaluation on UCF101, HMDB51, ActivityNet, and Kinetics-600. 199 | 200 | - Half-classes Evaluation: Following the traditional protocol, we randomly select half of the test dataset's classes, repeat the selection ten times, and report the mean accuracy with its standard deviation over the ten runs (see the sketch after this list). 201 | 202 | 203 | - Full-classes Evaluation: Evaluate on the entire dataset.
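To make the half-classes protocol concrete, here is a minimal sketch of the sampling-and-averaging loop, written on top of precomputed per-video class scores. It only illustrates the protocol and is not the repository's `test_zeroshot.py` implementation; the inputs `video_scores` (num_videos x num_classes) and `video_labels` are hypothetical arrays collected from a forward pass.

```python
import numpy as np

def half_classes_eval(video_scores, video_labels, num_repeats=10, seed=0):
    """Sketch of the half-classes zero-shot protocol: randomly keep half of the
    classes, compute top-1 accuracy on videos of those classes only, and report
    the mean and standard deviation over the repeats."""
    rng = np.random.default_rng(seed)
    num_classes = video_scores.shape[1]
    accs = []
    for _ in range(num_repeats):
        # Sample half of the classes (sorted so label remapping stays simple).
        chosen = np.sort(rng.choice(num_classes, size=num_classes // 2, replace=False))
        keep = np.isin(video_labels, chosen)                  # videos whose ground truth is in the sampled half
        scores = video_scores[keep][:, chosen]                # restrict predictions to the sampled classes
        labels = np.searchsorted(chosen, video_labels[keep])  # remap labels to 0..len(chosen)-1
        accs.append(float((scores.argmax(axis=1) == labels).mean()))
    return float(np.mean(accs)), float(np.std(accs))
```

The full-classes number is simply the accuracy computed over all classes without any sampling, i.e., a single pass with `chosen = np.arange(num_classes)`.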
204 | 205 | ```sh 206 | # On ActivityNet: reporting the half-classes and full-classes results 207 | sh scripts/run_test_zeroshot.sh configs/anet/anet_zero_shot.yaml exp/k400/ViT-L/14/f8/last_model.pt 208 | 209 | # On UCF101: reporting the half-classes and full-classes results 210 | sh scripts/run_test_zeroshot.sh configs/ucf101/ucf_zero_shot.yaml exp/k400/ViT-L/14/f8/last_model.pt 211 | 212 | # On HMDB51: reporting the half-classes and full-classes results 213 | sh scripts/run_test_zeroshot.sh configs/hmdb51/hmdb_zero_shot.yaml exp/k400/ViT-L/14/f8/last_model.pt 214 | 215 | # On Kinetics-600: manually calculating the mean accuracy with standard deviation of three splits. 216 | sh scripts/run_test.sh configs/k600/k600_zero_shot_split1.yaml exp/k400/ViT-L/14/f8/last_model.pt 217 | sh scripts/run_test.sh configs/k600/k600_zero_shot_split2.yaml exp/k400/ViT-L/14/f8/last_model.pt 218 | sh scripts/run_test.sh configs/k600/k600_zero_shot_split3.yaml exp/k400/ViT-L/14/f8/last_model.pt 219 | ``` 220 | 221 | 222 | 223 | 224 | ## 📌 BibTeX & Citation 225 | If you find this repository useful, please star🌟 this repo and cite📑 our paper: 226 | 227 | ```bibtex 228 | @inproceedings{wu2023revisiting, 229 | title={Revisiting classifier: Transferring vision-language models for video recognition}, 230 | author={Wu, Wenhao and Sun, Zhun and Ouyang, Wanli}, 231 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 232 | volume={37}, 233 | number={3}, 234 | pages={2847--2855}, 235 | year={2023} 236 | } 237 | 238 | @article{wu2023transferring, 239 | title={Transferring vision-language models for visual recognition: A classifier perspective}, 240 | author={Wu, Wenhao and Sun, Zhun and Song, Yuxin and Wang, Jingdong and Ouyang, Wanli}, 241 | journal={International Journal of Computer Vision}, 242 | pages={1--18}, 243 | year={2023}, 244 | publisher={Springer} 245 | } 246 | ``` 247 | 248 | If you also find [BIKE](https://github.com/whwu95/BIKE) useful, please cite the paper: 249 | 250 | ```bibtex 251 | @inproceedings{bike, 252 | title={Bidirectional Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models}, 253 | author={Wu, Wenhao and Wang, Xiaohan and Luo, Haipeng and Wang, Jingdong and Yang, Yi and Ouyang, Wanli}, 254 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 255 | year={2023} 256 | } 257 | ``` 258 | 259 | 260 | 261 | ## 🎗️ Acknowledgement 262 | 263 | This repository is built based on [ActionCLIP](https://github.com/sallymmx/actionclip) and [CLIP](https://github.com/openai/CLIP). Sincere thanks to their wonderful works. 264 | 265 | 266 | ## 👫 Contact 267 | For any question, please file an issue. 
268 | 269 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/Text4Vis/d61b34d0208a03ce6146edcc51033b2a040cb249/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14-336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" 40 | } 41 | 42 | 43 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 
checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load( 95 | name: str, 96 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 97 | jit=True, 98 | internal_modeling=False, joint_st=False, T=8, dropout=0., 99 | emb_dropout=0., 100 | pretrain=True): 101 | """Load a CLIP model 102 | 103 | Parameters 104 | ---------- 105 | name : str 106 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 107 | 108 | device : Union[str, torch.device] 109 | The device to put the loaded model 110 | 111 | jit : bool 112 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 113 | 114 | download_root: str 115 | path to download the model files; by default, it uses "~/.cache/clip" 116 | 117 | Returns 118 | ------- 119 | model : torch.nn.Module 120 | The CLIP model 121 | 122 | preprocess : Callable[[PIL.Image], torch.Tensor] 123 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 124 | """ 125 | if name in _MODELS: 126 | model_path = _download(_MODELS[name]) 127 | elif os.path.isfile(name): 128 | model_path = name 129 | else: 130 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 131 | 132 | try: 133 | # loading JIT archive 134 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 135 | state_dict = None 136 | except RuntimeError: 137 | # loading saved state dict 138 | if jit: 139 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 140 | jit = False 141 | state_dict = torch.load(model_path, map_location="cpu") 142 | 143 | if not jit: 144 | model = build_model(state_dict or model.state_dict(), joint=joint_st, tm=internal_modeling, T=T, dropout=dropout, emb_dropout=emb_dropout, pretrain=pretrain).to(device) 145 | if str(device) == "cpu": 146 | model.float() 147 | return model, model.state_dict() 148 | 149 | # patch the device names 150 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 151 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 152 | 153 | def patch_device(module): 154 | try: 155 | graphs = [module.graph] if hasattr(module, "graph") else [] 156 | except RuntimeError: 157 | graphs = [] 158 | 159 | if hasattr(module, "forward1"): 160 | graphs.append(module.forward1.graph) 161 | 162 | for graph in graphs: 163 | for node in graph.findAllNodes("prim::Constant"): 164 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 165 | node.copyAttributes(device_node) 166 | 167 | model.apply(patch_device) 168 | 169 | if str(device) == "cpu": 170 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 171 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 172 | float_node = float_input.node() 173 | 174 | def patch_float(module): 175 | try: 176 | graphs = [module.graph] if hasattr(module, "graph") else [] 177 | except RuntimeError: 178 | graphs = [] 179 | 180 | if hasattr(module, "forward1"): 181 | graphs.append(module.forward1.graph) 182 | 183 | for graph in graphs: 184 | for node in graph.findAllNodes("aten::to"): 185 | inputs = list(node.inputs()) 186 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 187 | if inputs[i].node()["value"] == 5: 188 | inputs[i].node().copyAttributes(float_node) 189 | 190 | model.apply(patch_float) 191 | patch_float(model.encode_image) 192 | patch_float(model.encode_text) 193 | 194 | model.float() 195 | 196 | return model, _transform(model.input_resolution.item()) 197 | 198 | 199 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 200 | """ 201 | Returns the tokenized representation of given input string(s) 202 | 203 | Parameters 204 | ---------- 205 | texts : Union[str, List[str]] 206 | An input string or a list of input strings to tokenize 207 | 208 | context_length : int 209 | The context length to use; all CLIP models use 77 as the context length 210 | 211 | truncate: bool 212 | Whether to truncate the text in case its encoding is longer than the context length 213 | 214 | Returns 215 | ------- 216 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 217 | """ 218 | if isinstance(texts, str): 219 | texts = [texts] 220 | 221 | sot_token = _tokenizer.encoder["<|startoftext|>"] 222 | eot_token = _tokenizer.encoder["<|endoftext|>"] 223 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | 226 | for i, tokens in enumerate(all_tokens): 227 | if len(tokens) > context_length: 228 | if truncate: 229 | tokens = tokens[:context_length] 230 | tokens[-1] = eot_token 231 | else: 232 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 233 | result[i, :len(tokens)] = torch.tensor(tokens) 234 | 
235 | return result 236 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i 
= j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /configs/anet/anet_few_shot.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: anet 6 | modality: RGB 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 8 10 | workers: 4 11 | num_classes: 200 12 | image_tmpl: 'image_{:06d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/anet_instance_frames_v1.3_train_vids_fps1' 14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/activitynet_val_resize_img_256_340_fps1' 15 | train_list: 'lists/anet/anet_train_instance_fps1.txt' 16 | val_list: 'lists/anet/anet_val_video_fps1.txt' # 17 | label_list: 'lists/anet1.3_labels.csv' 18 | input_size: 224 19 | random_shift: True 20 | shot: 2 21 | network: 22 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 23 | init: True 24 | drop_out: 0.0 25 | emb_dropout: 0.0 26 | type: clip_anet 27 | sim_header: None 28 | drop: 0 29 | solver: 30 | type: cosine 31 | epochs: 30 32 | start_epoch: 0 33 | optim: adamw 34 | lr: 5.e-5 35 | lr_warmup_step: 5 36 | weight_decay: 0.2 37 | loss_type: CE 38 | evaluate: False 39 | clip_ratio: 0.1 40 | grad_accumulation_steps: 1 41 | logging: 42 | print_freq: 10 43 | eval_freq: 5 -------------------------------------------------------------------------------- /configs/anet/anet_k400_finetune.yaml: -------------------------------------------------------------------------------- 1 | pretrain: exp_sota/k400/ViT-L/14/f16/last_model.pt 2 | resume: 3 | seed: 1024 4 | data: 5 | dataset: anet 6 | modality: RGB 7 | num_segments: 16 8 | seg_length: 1 9 | batch_size: 4 10 | workers: 4 11 | num_classes: 200 12 | image_tmpl: 'image_{:06d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/anet_instance_frames_v1.3_train_vids_fps1' 14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/activitynet_val_resize_img_256_340_fps1' 15 | train_list: 'lists/anet/anet_train_instance_fps1.txt' 16 | val_list: 'lists/anet/anet_val_video_fps1.txt' # 17 | label_list: 'lists/anet1.3_labels.csv' 18 | input_size: 224 19 | random_shift: True 20 | network: 21 | arch: ViT-L/14 22 | init: True 23 | drop_out: 0.0 24 | emb_dropout: 0.0 25 | type: clip_anet 26 | sim_header: Transf 27 | drop: 0 28 | solver: 29 | type: cosine 30 | epochs: 30 31 | start_epoch: 0 32 | optim: adamw 33 | lr: 5.e-5 34 | lr_warmup_step: 5 35 | weight_decay: 0.2 36 | loss_type: CE 37 | evaluate: False 38 | clip_ratio: 0.1 39 | grad_accumulation_steps: 2 40 | 
logging: 41 | print_freq: 10 42 | eval_freq: 5 43 | 44 | -------------------------------------------------------------------------------- /configs/anet/anet_k400_finetune_336.yaml: -------------------------------------------------------------------------------- 1 | pretrain: exp_sota/k400/ViT-L/14-336px/f16/last_model.pt 2 | resume: 3 | seed: 1024 4 | data: 5 | dataset: anet 6 | modality: RGB 7 | num_segments: 16 8 | seg_length: 1 9 | batch_size: 1 10 | workers: 4 11 | num_classes: 200 12 | image_tmpl: 'image_{:06d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/anet_instance_frames_v1.3_train_vids_fps1' 14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/activitynet_val_resize_img_256_340_fps1' 15 | train_list: 'lists/anet/anet_train_instance_fps1.txt' 16 | val_list: 'lists/anet/anet_val_video_fps1.txt' # 17 | label_list: 'lists/anet1.3_labels.csv' 18 | input_size: 336 19 | random_shift: True 20 | network: 21 | arch: ViT-L/14-336px 22 | init: True 23 | drop_out: 0.0 24 | emb_dropout: 0.0 25 | type: clip_anet 26 | sim_header: Transf 27 | drop: 0 28 | solver: 29 | type: cosine 30 | epochs: 30 31 | start_epoch: 0 32 | optim: adamw 33 | lr: 5.e-5 34 | lr_warmup_step: 5 35 | weight_decay: 0.2 36 | loss_type: CE 37 | evaluate: False 38 | clip_ratio: 0.1 39 | grad_accumulation_steps: 8 40 | logging: 41 | print_freq: 10 42 | eval_freq: 5 43 | 44 | -------------------------------------------------------------------------------- /configs/anet/anet_zero_shot.yaml: -------------------------------------------------------------------------------- 1 | seed: 1024 2 | data: 3 | dataset: anet 4 | modality: RGB 5 | num_segments: 8 6 | seg_length: 1 7 | batch_size: 32 8 | workers: 4 9 | num_classes: 200 10 | image_tmpl: 'image_{:06d}.jpg' 11 | val_root: /bpfs/v2_mnt/VIS/wuwenhao/anet 12 | val_list: 'lists/anet/anet_full_for_zeroshot.txt' 13 | label_list: 'lists/anet1.3_labels.csv' 14 | index_bias: 1 15 | input_size: 224 16 | network: 17 | arch: ViT-L/14 #ViT-L/14 #ViT-B/32 ViT-B/16 18 | init: True 19 | drop_out: 0.0 20 | emb_dropout: 0.0 21 | type: clip_anet 22 | sim_header: Transf 23 | logging: 24 | print_freq: 10 25 | eval_freq: 5 -------------------------------------------------------------------------------- /configs/hmdb51/hmdb_few_shot.yaml: -------------------------------------------------------------------------------- 1 | pretrain: 2 | resume: 3 | seed: 1024 4 | data: 5 | dataset: hmdb51 6 | modality: RGB 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 8 10 | workers: 4 11 | num_classes: 51 12 | image_tmpl: 'image_{:06d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340' 14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340' 15 | train_list: 'lists/hmdb51/train_rgb_split_1.txt' 16 | val_list: 'lists/hmdb51/val_rgb_split_1.txt' 17 | label_list: 'lists/hmdb51_labels.csv' 18 | input_size: 224 19 | random_shift: True 20 | shot: 2 21 | network: 22 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 23 | init: True 24 | drop_out: 0.0 25 | emb_dropout: 0.0 26 | type: clip_hmdb 27 | sim_header: None 28 | drop: 0 29 | solver: 30 | type: cosine 31 | epochs: 30 32 | start_epoch: 0 33 | optim: adamw 34 | lr: 5.e-5 35 | lr_warmup_step: 5 36 | weight_decay: 0.2 37 | loss_type: CE 38 | evaluate: False 39 | clip_ratio: 0.1 40 | grad_accumulation_steps: 1 41 | logging: 42 | print_freq: 10 43 | eval_freq: 5 -------------------------------------------------------------------------------- /configs/hmdb51/hmdb_k400_finetune.yaml: 
-------------------------------------------------------------------------------- 1 | pretrain: exp_sota/k400/ViT-L/14/f16/last_model.pt 2 | resume: 3 | seed: 1024 4 | data: 5 | dataset: hmdb51 6 | modality: RGB 7 | num_segments: 16 8 | seg_length: 1 9 | batch_size: 4 10 | workers: 4 11 | num_classes: 51 12 | image_tmpl: 'image_{:06d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340' 14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340' 15 | train_list: 'lists/hmdb51/train_rgb_split_1.txt' 16 | val_list: 'lists/hmdb51/val_rgb_split_1.txt' 17 | label_list: 'lists/hmdb51_labels.csv' 18 | input_size: 224 19 | random_shift: True 20 | network: 21 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 22 | init: True 23 | drop_out: 0.0 24 | emb_dropout: 0.0 25 | type: clip_hmdb 26 | sim_header: Transf 27 | drop: 0 28 | solver: 29 | type: cosine 30 | epochs: 30 31 | start_epoch: 0 32 | optim: adamw 33 | lr: 5.e-5 34 | lr_warmup_step: 5 35 | weight_decay: 0.2 36 | loss_type: CE 37 | evaluate: False 38 | clip_ratio: 0.1 39 | grad_accumulation_steps: 2 40 | logging: 41 | print_freq: 10 42 | eval_freq: 5 -------------------------------------------------------------------------------- /configs/hmdb51/hmdb_k400_finetune_336.yaml: -------------------------------------------------------------------------------- 1 | pretrain: exp_sota/k400/ViT-L/14-336px/f16/last_model.pt 2 | resume: 3 | seed: 1024 4 | data: 5 | dataset: hmdb51 6 | modality: RGB 7 | num_segments: 16 8 | seg_length: 1 9 | batch_size: 1 10 | workers: 4 11 | num_classes: 51 12 | image_tmpl: 'image_{:06d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340' 14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340' 15 | train_list: 'lists/hmdb51/train_rgb_split_1.txt' 16 | val_list: 'lists/hmdb51/val_rgb_split_1.txt' 17 | label_list: 'lists/hmdb51_labels.csv' 18 | input_size: 336 19 | random_shift: True 20 | network: 21 | arch: ViT-L/14-336px #ViT-B/32 ViT-B/16 22 | init: True 23 | drop_out: 0.0 24 | emb_dropout: 0.0 25 | type: clip_hmdb 26 | sim_header: Transf 27 | drop: 0 28 | solver: 29 | type: cosine 30 | epochs: 30 31 | start_epoch: 0 32 | optim: adamw 33 | lr: 5.e-5 34 | lr_warmup_step: 5 35 | weight_decay: 0.2 36 | loss_type: CE 37 | evaluate: False 38 | clip_ratio: 0.1 39 | grad_accumulation_steps: 8 40 | logging: 41 | print_freq: 10 42 | eval_freq: 5 -------------------------------------------------------------------------------- /configs/hmdb51/hmdb_zero_shot.yaml: -------------------------------------------------------------------------------- 1 | seed: 1024 2 | data: 3 | dataset: hmdb51 4 | modality: RGB 5 | num_segments: 8 6 | seg_length: 1 7 | batch_size: 16 8 | workers: 8 9 | num_classes: 51 10 | image_tmpl: 'image_{:06d}.jpg' 11 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340' 12 | val_list: 'lists/hmdb51/hmdb_full_for_zeroshot.txt' 13 | label_list: 'lists/hmdb51_labels.csv' 14 | index_bias: 1 15 | input_size: 224 16 | network: 17 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 18 | init: True 19 | drop_out: 0.0 20 | emb_dropout: 0.0 21 | type: clip_hmdb 22 | sim_header: Transf 23 | logging: 24 | print_freq: 10 25 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k400/k400_few_shot.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k400 6 | modality: RGB 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 8 10 | workers: 4 11 | 
num_classes: 400 12 | image_tmpl: 'img_{:05d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt' 15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv 16 | val_list: lists/k400/kinetics_rgb_val_se320.txt 17 | label_list: 'lists/kinetics_400_labels.csv' 18 | index_bias: 1 19 | input_size: 224 20 | randaug: 21 | N: 2 #2 22 | M: 9 #9 23 | random_shift: True 24 | shot: 2 25 | network: 26 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 27 | init: True 28 | drop_out: 0.0 29 | emb_dropout: 0.0 30 | type: clip_k400 31 | sim_header: None 32 | drop: 0 33 | solver: 34 | type: cosine 35 | epochs: 30 36 | start_epoch: 0 37 | epoch_offset: 0 38 | optim: adamw 39 | lr: 5.e-5 40 | lr_warmup_step: 5 41 | weight_decay: 0.2 42 | loss_type: CE 43 | evaluate: False 44 | clip_ratio: 0.1 45 | grad_accumulation_steps: 1 46 | logging: 47 | print_freq: 10 48 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k400/k400_train_rgb_rn50.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k400 6 | modality: RGB 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 32 10 | workers: 4 11 | num_classes: 400 12 | image_tmpl: 'img_{:05d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt' 15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv 16 | val_list: lists/k400/kinetics_rgb_val_se320.txt 17 | label_list: 'lists/kinetics_400_labels.csv' 18 | input_size: 224 19 | randaug: 20 | N: 2 #2 21 | M: 9 #9 22 | random_shift: True 23 | network: 24 | arch: RN50 # RN50 RN101 RN50x4 RN50x16 RN50x64 25 | init: True 26 | tm: False # False tsm 27 | drop_out: 0.0 28 | emb_dropout: 0.0 29 | type: clip_k400 30 | sim_header: Transf 31 | joint_st: False 32 | drop: 0 33 | fix_text: True 34 | fix_video: False 35 | solver: 36 | type: cosine 37 | epochs: 30 38 | start_epoch: 0 39 | epoch_offset: 0 40 | optim: adamw 41 | lr: 5.e-5 42 | lr_warmup_step: 5 43 | weight_decay: 0.2 44 | loss_type: CE 45 | evaluate: False 46 | clip_ratio: 0.1 47 | grad_accumulation_steps: 1 48 | logging: 49 | print_freq: 10 50 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k400/k400_train_rgb_vitb-16-f16.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k400 6 | modality: RGB 7 | num_segments: 16 8 | seg_length: 1 9 | batch_size: 16 10 | workers: 4 11 | num_classes: 400 12 | image_tmpl: 'img_{:05d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt' 15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv 16 | val_list: lists/k400/kinetics_rgb_val_se320.txt 17 | label_list: 'lists/kinetics_400_labels.csv' 18 | input_size: 224 19 | randaug: 20 | N: 2 #2 21 | M: 9 #9 22 | random_shift: True 23 | network: 24 | arch: ViT-B/16 #ViT-B/32 ViT-B/16 25 | init: True 26 | tm: False # False tsm tokent1d tokenshift 27 | drop_out: 0.0 28 | emb_dropout: 0.0 29 | type: clip_k400 30 | sim_header: Transf # Transf None 31 | joint_st: False 32 | drop: 0 33 | fix_text: True 34 | fix_video: False 35 | solver: 36 | type: cosine 37 | epochs: 30 38 | start_epoch: 0 39 | epoch_offset: 0 40 | optim: adamw 41 | lr: 5.e-5 42 | 
lr_warmup_step: 5 43 | weight_decay: 0.2 44 | loss_type: CE 45 | evaluate: False 46 | clip_ratio: 0.1 47 | grad_accumulation_steps: 2 48 | logging: 49 | print_freq: 10 50 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k400/k400_train_rgb_vitb-16-f8.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k400 6 | modality: RGB 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 32 10 | workers: 4 11 | num_classes: 400 12 | image_tmpl: 'img_{:05d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt' 15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv 16 | val_list: lists/k400/kinetics_rgb_val_se320.txt 17 | label_list: 'lists/kinetics_400_labels.csv' 18 | input_size: 224 19 | randaug: 20 | N: 2 #2 21 | M: 9 #9 22 | random_shift: True 23 | network: 24 | arch: ViT-B/16 #ViT-B/32 ViT-B/16 25 | init: True 26 | tm: False # False tsm tokent1d tokenshift 27 | drop_out: 0.0 28 | emb_dropout: 0.0 29 | type: clip_k400 30 | sim_header: Transf # Transf None 31 | joint_st: False 32 | drop: 0 33 | fix_text: True 34 | fix_video: False 35 | solver: 36 | type: cosine 37 | epochs: 30 38 | start_epoch: 0 39 | epoch_offset: 0 40 | optim: adamw 41 | lr: 5.e-5 42 | lr_warmup_step: 5 43 | weight_decay: 0.2 44 | loss_type: CE 45 | evaluate: False 46 | clip_ratio: 0.1 47 | grad_accumulation_steps: 1 48 | logging: 49 | print_freq: 10 50 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k400/k400_train_rgb_vitb-32-f16.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k400 6 | modality: RGB 7 | num_segments: 16 8 | seg_length: 1 9 | batch_size: 32 10 | workers: 4 11 | num_classes: 400 12 | image_tmpl: 'img_{:05d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt' 15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv 16 | val_list: lists/k400/kinetics_rgb_val_se320.txt 17 | label_list: 'lists/kinetics_400_labels.csv' 18 | input_size: 224 19 | randaug: 20 | N: 2 #2 21 | M: 9 #9 22 | random_shift: True 23 | network: 24 | arch: ViT-B/32 #ViT-B/32 ViT-B/16 25 | init: True 26 | tm: False # False tsm tokent1d tokenshift 27 | drop_out: 0.0 28 | emb_dropout: 0.0 29 | type: clip_k400 30 | sim_header: Transf # Transf None 31 | joint_st: False 32 | drop: 0 33 | fix_text: True 34 | fix_video: False 35 | solver: 36 | type: cosine 37 | epochs: 30 38 | start_epoch: 0 39 | epoch_offset: 0 40 | optim: adamw 41 | lr: 5.e-5 42 | lr_warmup_step: 5 43 | weight_decay: 0.2 44 | loss_type: CE 45 | evaluate: False 46 | clip_ratio: 0.1 47 | grad_accumulation_steps: 1 48 | logging: 49 | print_freq: 10 50 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k400/k400_train_rgb_vitb-32-f8.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k400 6 | modality: RGB 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 32 10 | workers: 4 11 | num_classes: 400 12 | image_tmpl: 'img_{:05d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt' 15 | 
val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv 16 | val_list: lists/k400/kinetics_rgb_val_se320.txt 17 | label_list: 'lists/kinetics_400_labels.csv' 18 | input_size: 224 19 | randaug: 20 | N: 2 #2 21 | M: 9 #9 22 | random_shift: True 23 | network: 24 | arch: ViT-B/32 #ViT-B/32 ViT-B/16 25 | init: True 26 | tm: False # False tsm tokent1d tokenshift 27 | drop_out: 0.0 28 | emb_dropout: 0.0 29 | type: clip_k400 30 | sim_header: Transf # Transf None 31 | joint_st: False 32 | drop: 0 33 | fix_text: True 34 | fix_video: False 35 | solver: 36 | type: cosine 37 | epochs: 30 38 | start_epoch: 0 39 | epoch_offset: 0 40 | optim: adamw 41 | lr: 5.e-5 42 | lr_warmup_step: 5 43 | weight_decay: 0.2 44 | loss_type: CE 45 | evaluate: False 46 | clip_ratio: 0.1 47 | grad_accumulation_steps: 1 48 | logging: 49 | print_freq: 10 50 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k400/k400_train_rgb_vitl-14-336-f32.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k400 6 | modality: RGB 7 | num_segments: 32 8 | seg_length: 1 9 | batch_size: 1 10 | workers: 4 11 | num_classes: 400 12 | image_tmpl: 'img_{:05d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt' 15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv 16 | val_list: lists/k400/kinetics_rgb_val_se320.txt 17 | label_list: 'lists/kinetics_400_labels.csv' 18 | input_size: 336 19 | randaug: 20 | N: 2 #2 21 | M: 9 #9 22 | random_shift: True 23 | network: 24 | arch: ViT-L/14-336px #ViT-B/32 ViT-B/16 25 | init: True 26 | tm: False # False tsm tokent1d tokenshift 27 | drop_out: 0.0 28 | emb_dropout: 0.0 29 | type: clip_k400 30 | sim_header: Transf # Transf None 31 | joint_st: False 32 | drop: 0 33 | fix_text: True 34 | fix_video: False 35 | solver: 36 | type: cosine 37 | epochs: 30 38 | start_epoch: 0 39 | epoch_offset: 0 40 | optim: adamw 41 | lr: 5.e-5 42 | lr_warmup_step: 5 43 | weight_decay: 0.2 44 | loss_type: CE 45 | evaluate: False 46 | clip_ratio: 0.1 47 | grad_accumulation_steps: 8 48 | logging: 49 | print_freq: 10 50 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k400/k400_train_rgb_vitl-14-336-f8.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k400 6 | modality: RGB 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 2 10 | workers: 4 11 | num_classes: 400 12 | image_tmpl: 'img_{:05d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt' 15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv 16 | val_list: lists/k400/kinetics_rgb_val_se320.txt 17 | label_list: 'lists/kinetics_400_labels.csv' 18 | input_size: 336 19 | randaug: 20 | N: 2 #2 21 | M: 9 #9 22 | random_shift: True 23 | network: 24 | arch: ViT-L/14-336px #ViT-B/32 ViT-B/16 25 | init: True 26 | tm: False # False tsm tokent1d tokenshift 27 | drop_out: 0.0 28 | emb_dropout: 0.0 29 | type: clip_k400 30 | sim_header: Transf # Transf None 31 | joint_st: False 32 | drop: 0 33 | fix_text: True 34 | fix_video: False 35 | solver: 36 | type: cosine 37 | epochs: 30 38 | start_epoch: 0 39 | epoch_offset: 0 40 | optim: adamw 41 | lr: 5.e-5 42 | lr_warmup_step: 5 43 | weight_decay: 0.2 44 | 
loss_type: CE 45 | evaluate: False 46 | clip_ratio: 0.1 47 | grad_accumulation_steps: 4 48 | logging: 49 | print_freq: 10 50 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k400/k400_train_rgb_vitl-14-f8.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k400 6 | modality: RGB 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 8 10 | workers: 4 11 | num_classes: 400 12 | image_tmpl: 'img_{:05d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt' 15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv 16 | val_list: lists/k400/kinetics_rgb_val_se320.txt 17 | label_list: 'lists/kinetics_400_labels.csv' 18 | input_size: 224 19 | randaug: 20 | N: 2 #2 21 | M: 9 #9 22 | random_shift: True 23 | network: 24 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 25 | init: True 26 | tm: False # False tsm tokent1d tokenshift 27 | drop_out: 0.0 28 | emb_dropout: 0.0 29 | type: clip_k400 30 | sim_header: Transf # Transf None 31 | joint_st: False 32 | drop: 0 33 | fix_text: True 34 | fix_video: False 35 | solver: 36 | type: cosine 37 | epochs: 30 38 | start_epoch: 0 39 | epoch_offset: 0 40 | optim: adamw 41 | lr: 5.e-5 42 | lr_warmup_step: 5 43 | weight_decay: 0.2 44 | loss_type: CE 45 | evaluate: False 46 | clip_ratio: 0.1 47 | grad_accumulation_steps: 2 48 | logging: 49 | print_freq: 10 50 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k600/k600_zero_shot_split1.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k600 6 | modality: video 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 8 10 | workers: 4 11 | num_classes: 160 12 | image_tmpl: 'img_{:05d}.jpg' 13 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/k600_test_video' 14 | val_list: lists/k600/test_split1_exist.txt 15 | label_list: lists/k600/k160_labels_split1.csv 16 | index_bias: 1 17 | input_size: 224 18 | network: 19 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 20 | init: True 21 | drop_out: 0.0 22 | emb_dropout: 0.0 23 | type: clip_k600 24 | sim_header: Transf 25 | drop: 0 26 | logging: 27 | print_freq: 10 28 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k600/k600_zero_shot_split2.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k600 6 | modality: video 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 8 10 | workers: 4 11 | num_classes: 160 12 | image_tmpl: 'img_{:05d}.jpg' 13 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/k600_test_video' 14 | val_list: lists/k600/test_split2_exist.txt 15 | label_list: lists/k600/k160_labels_split2.csv 16 | index_bias: 1 17 | input_size: 224 18 | network: 19 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 20 | init: True 21 | drop_out: 0.0 22 | emb_dropout: 0.0 23 | type: clip_k600 24 | sim_header: Transf 25 | drop: 0 26 | logging: 27 | print_freq: 10 28 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/k600/k600_zero_shot_split3.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | pretrain: 3 | seed: 1024 4 | data: 5 | dataset: k600 6 | modality: video 7 | num_segments: 8 8 | seg_length: 1 
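# (editor's note, not part of the original config) num_segments x seg_length = 8 frames are
# sampled per video for zero-shot evaluation here; the three k600_zero_shot_split*.yaml files
# are otherwise identical and differ only in val_list and label_list.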
9 | batch_size: 8 10 | workers: 4 11 | num_classes: 160 12 | image_tmpl: 'img_{:05d}.jpg' 13 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/k600_test_video' 14 | val_list: lists/k600/test_split3_exist.txt 15 | label_list: lists/k600/k160_labels_split3.csv 16 | index_bias: 1 17 | input_size: 224 18 | network: 19 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 20 | init: True 21 | drop_out: 0.0 22 | emb_dropout: 0.0 23 | type: clip_k600 24 | sim_header: Transf 25 | drop: 0 26 | logging: 27 | print_freq: 10 28 | eval_freq: 1 -------------------------------------------------------------------------------- /configs/ucf101/ucf_few_shot.yaml: -------------------------------------------------------------------------------- 1 | pretrain: 2 | resume: 3 | seed: 1024 4 | data: 5 | dataset: ucf101 6 | modality: RGB 7 | num_segments: 8 8 | seg_length: 1 9 | batch_size: 8 10 | workers: 4 11 | num_classes: 101 12 | image_tmpl: 'image_{:04d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames' 14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames' 15 | train_list: 'lists/ucf101/train_rgb_split_1.txt' 16 | val_list: 'lists/ucf101/val_rgb_split_1.txt' 17 | label_list: 'lists/ucf_labels.csv' 18 | input_size: 224 19 | random_shift: True 20 | shot: 2 21 | network: 22 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 23 | init: True 24 | drop_out: 0.0 25 | emb_dropout: 0.0 26 | type: clip_hmdb 27 | sim_header: None 28 | drop: 0 29 | solver: 30 | type: cosine 31 | epochs: 30 32 | start_epoch: 0 33 | optim: adamw 34 | lr: 5.e-5 35 | lr_warmup_step: 5 36 | weight_decay: 0.2 37 | loss_type: CE 38 | evaluate: False 39 | clip_ratio: 0.1 40 | grad_accumulation_steps: 1 41 | logging: 42 | print_freq: 10 43 | eval_freq: 5 -------------------------------------------------------------------------------- /configs/ucf101/ucf_k400_finetune.yaml: -------------------------------------------------------------------------------- 1 | pretrain: exp_sota/k400/ViT-L/14/f16/last_model.pt 2 | resume: 3 | seed: 1024 4 | data: 5 | dataset: ucf101 6 | modality: RGB 7 | num_segments: 16 8 | seg_length: 1 9 | batch_size: 4 10 | workers: 4 11 | num_classes: 101 12 | image_tmpl: 'image_{:04d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames' 14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames' 15 | train_list: 'lists/ucf101/train_rgb_split_1.txt' 16 | val_list: 'lists/ucf101/val_rgb_split_1.txt' 17 | label_list: 'lists/ucf_labels.csv' 18 | input_size: 224 19 | random_shift: True 20 | network: 21 | arch: ViT-L/14 #ViT-B/32 ViT-B/16 22 | init: True 23 | drop_out: 0.0 24 | emb_dropout: 0.0 25 | type: clip_ucf 26 | sim_header: Transf 27 | drop: 0 28 | solver: 29 | type: cosine 30 | epochs: 30 31 | start_epoch: 0 32 | optim: adamw 33 | lr: 5.e-5 34 | lr_warmup_step: 5 35 | weight_decay: 0.2 36 | loss_type: CE 37 | evaluate: False 38 | clip_ratio: 0.1 39 | grad_accumulation_steps: 2 40 | logging: 41 | print_freq: 10 42 | eval_freq: 5 -------------------------------------------------------------------------------- /configs/ucf101/ucf_k400_finetune_336.yaml: -------------------------------------------------------------------------------- 1 | pretrain: exp_sota/k400/ViT-L/14-336px/f16/last_model.pt 2 | resume: 3 | seed: 1024 4 | data: 5 | dataset: ucf101 6 | modality: RGB 7 | num_segments: 16 8 | seg_length: 1 9 | batch_size: 1 10 | workers: 4 11 | num_classes: 101 12 | image_tmpl: 'image_{:04d}.jpg' 13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames' 14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames' 15 | train_list: 
'lists/ucf101/train_rgb_split_1.txt' 16 | val_list: 'lists/ucf101/val_rgb_split_1.txt' 17 | label_list: 'lists/ucf_labels.csv' 18 | input_size: 336 19 | random_shift: True 20 | network: 21 | arch: ViT-L/14-336px #ViT-B/32 ViT-B/16 22 | init: True 23 | drop_out: 0.0 24 | emb_dropout: 0.0 25 | type: clip_ucf 26 | sim_header: Transf 27 | drop: 0 28 | solver: 29 | type: cosine 30 | epochs: 30 31 | start_epoch: 0 32 | optim: adamw 33 | lr: 5.e-5 34 | lr_warmup_step: 5 35 | weight_decay: 0.2 36 | loss_type: CE 37 | evaluate: False 38 | clip_ratio: 0.1 39 | grad_accumulation_steps: 8 40 | logging: 41 | print_freq: 10 42 | eval_freq: 5 -------------------------------------------------------------------------------- /configs/ucf101/ucf_zero_shot.yaml: -------------------------------------------------------------------------------- 1 | seed: 1024 2 | data: 3 | dataset: ucf101 4 | modality: RGB 5 | num_segments: 8 6 | seg_length: 1 7 | batch_size: 16 8 | workers: 8 9 | num_classes: 101 10 | image_tmpl: 'image_{:04d}.jpg' 11 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames' 12 | val_list: 'lists/ucf101/ucf_full_for_zeroshot.txt' # 13 | label_list: 'lists/ucf_labels.csv' 14 | index_bias: 1 15 | input_size: 224 16 | network: 17 | arch: ViT-L/14 18 | init: True 19 | drop_out: 0.0 20 | emb_dropout: 0.0 21 | type: clip_ucf 22 | sim_header: Transf 23 | logging: 24 | print_freq: 10 25 | eval_freq: 1 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import decord 4 | import os 5 | import numpy as np 6 | from numpy.random import randint 7 | import io 8 | import pandas as pd 9 | import random 10 | from PIL import Image 11 | import math 12 | import copy 13 | 14 | 15 | class VideoRecord(object): 16 | def __init__(self, row): 17 | self._data = row 18 | 19 | @property 20 | def path(self): 21 | return self._data[0] 22 | 23 | @property 24 | def num_frames(self): 25 | return int(self._data[1]) 26 | 27 | @property 28 | def label(self): 29 | return int(self._data[-1]) 30 | 31 | 32 | class Video_dataset(data.Dataset): 33 | def __init__(self, root_path, list_file, labels_file, 34 | num_segments=1, modality='RGB', new_length=1, 35 | image_tmpl='img_{:05d}.jpg', transform=None, 36 | random_shift=True, test_mode=False, 37 | index_bias=1, dense_sample=False, test_clips=3): 38 | 39 | self.root_path = root_path 40 | self.list_file = list_file 41 | self.num_segments = num_segments 42 | self.modality = modality 43 | self.seg_length = new_length 44 | self.image_tmpl = image_tmpl 45 | self.transform = transform 46 | self.random_shift = random_shift 47 | self.test_mode = test_mode 48 | self.loop=False 49 | self.index_bias = index_bias 50 | self.labels_file = labels_file 51 | self.sample_range = 128 52 | self.dense_sample = dense_sample # using dense sample as I3D 53 | self.test_clips = test_clips 54 | if self.dense_sample: 55 | print('=> Using dense sample for the dataset...') 56 | 57 | if self.index_bias is None: 58 | if self.image_tmpl == "frame{:d}.jpg": 59 | self.index_bias = 0 60 | else: 61 | self.index_bias = 1 62 | self._parse_list() 63 | self.initialized = False 64 | 65 | @property 66 | def total_length(self): 67 | return 
self.num_segments * self.seg_length 68 | 69 | @property 70 | def classes(self): 71 | classes_all = pd.read_csv(self.labels_file) 72 | return classes_all.values.tolist() 73 | 74 | def _parse_list(self): 75 | # check the frame number is large >3: 76 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 77 | if len(tmp[0]) == 3: # skip remove_missin for decording "raw_video label" type dataset_config 78 | if not self.test_mode: 79 | tmp = [item for item in tmp if int(item[1]) >= 8] 80 | self.video_list = [VideoRecord(item) for item in tmp] 81 | print('video number:%d' % (len(self.video_list))) 82 | 83 | def _sample_indices(self, video_list): 84 | if self.dense_sample: 85 | sample_pos = max(1, 1 + len(video_list) - self.sample_range) 86 | interval = self.sample_range // self.num_segments 87 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 88 | base_offsets = np.arange(self.num_segments) * interval 89 | offsets = (base_offsets + start_idx) % len(video_list) 90 | return np.array(offsets) + self.index_bias 91 | else: 92 | if len(video_list) <= self.total_length: 93 | if self.loop: 94 | return np.mod(np.arange( 95 | self.total_length) + randint(len(video_list) // 2), 96 | len(video_list)) + self.index_bias 97 | offsets = np.concatenate(( 98 | np.arange(len(video_list)), 99 | randint(len(video_list), 100 | size=self.total_length - len(video_list)))) 101 | return np.sort(offsets) + self.index_bias 102 | offsets = list() 103 | ticks = [i * len(video_list) // self.num_segments 104 | for i in range(self.num_segments + 1)] 105 | 106 | for i in range(self.num_segments): 107 | tick_len = ticks[i + 1] - ticks[i] 108 | tick = ticks[i] 109 | if tick_len >= self.seg_length: 110 | tick += randint(tick_len - self.seg_length + 1) 111 | offsets.extend([j for j in range(tick, tick + self.seg_length)]) 112 | return np.array(offsets) + self.index_bias 113 | 114 | def _get_val_indices(self, video_list): 115 | if self.dense_sample: 116 | sample_pos = max(1, 1 + len(video_list) - self.sample_range) 117 | t_stride = self.sample_range // self.num_segments 118 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 119 | offsets = [(idx * t_stride + start_idx) % len(video_list) for idx in range(self.num_segments)] 120 | return np.array(offsets) + self.index_bias 121 | else: 122 | tick = len(video_list) / float(self.num_segments) 123 | offsets = [int(tick * x) % len(video_list) for x in range(self.num_segments)] 124 | return np.array(offsets) + self.index_bias 125 | 126 | 127 | def _get_test_indices(self, video_list): 128 | if self.dense_sample: 129 | # multi-clip for dense sampling 130 | num_clips = self.test_clips 131 | sample_pos = max(0, len(video_list) - self.sample_range) 132 | interval = self.sample_range // self.num_segments 133 | start_list = [clip_idx * math.floor(sample_pos / (num_clips -1)) for clip_idx in range(num_clips)] 134 | base_offsets = np.arange(self.num_segments) * interval 135 | offsets = [] 136 | for start_idx in start_list: 137 | offsets.extend((base_offsets + start_idx) % len(video_list)) 138 | return np.array(offsets) + self.index_bias 139 | else: 140 | # multi-clip for uniform sampling 141 | num_clips = self.test_clips 142 | tick = len(video_list) / float(self.num_segments) 143 | start_list = np.linspace(0, tick - 1, num=num_clips, dtype=int) 144 | offsets = [] 145 | for start_idx in start_list.tolist(): 146 | offsets += [ 147 | int(start_idx + tick * x) % len(video_list) 148 | for x in range(self.num_segments) 149 | ] 150 | return 
np.array(offsets) + self.index_bias 151 | 152 | 153 | 154 | def _decord_decode(self, video_path): 155 | try: 156 | container = decord.VideoReader(video_path) 157 | except Exception as e: 158 | print("Failed to decode {} with exception: {}".format( 159 | video_path, e)) 160 | return None 161 | 162 | return container 163 | 164 | def __getitem__(self, index): 165 | # decode frames to video_list 166 | if self.modality == 'video': 167 | _num_retries = 10 168 | for i_try in range(_num_retries): 169 | record = copy.deepcopy(self.video_list[index]) 170 | directory = os.path.join(self.root_path, record.path) 171 | video_list = self._decord_decode(directory) 172 | # video_list = self._decord_pyav(directory) 173 | if video_list is None: 174 | print("Failed to decode video idx {} from {}; trial {}".format( 175 | index, directory, i_try) 176 | ) 177 | index = random.randint(0, len(self.video_list)) 178 | continue 179 | break 180 | else: 181 | record = self.video_list[index] 182 | video_list = os.listdir(os.path.join(self.root_path, record.path)) 183 | 184 | if not self.test_mode: # train/val 185 | segment_indices = self._sample_indices(video_list) if self.random_shift else self._get_val_indices(video_list) 186 | else: # test 187 | segment_indices = self._get_test_indices(video_list) 188 | 189 | return self.get(record, video_list, segment_indices) 190 | 191 | 192 | def _load_image(self, directory, idx): 193 | if self.modality == 'RGB': 194 | try: 195 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')] 196 | except Exception: 197 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) 198 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')] 199 | 200 | 201 | def get(self, record, video_list, indices): 202 | images = list() 203 | for seg_ind in indices: 204 | p = int(seg_ind) 205 | if self.modality == 'video': 206 | seg_imgs = [Image.fromarray(video_list[p-1].asnumpy()).convert('RGB')] 207 | else: 208 | seg_imgs = self._load_image(record.path,p) 209 | images.extend(seg_imgs) 210 | if p < len(video_list): 211 | p += 1 212 | process_data, record_label = self.transform((images,record.label)) 213 | return process_data, record_label 214 | 215 | def __len__(self): 216 | return len(self.video_list) 217 | -------------------------------------------------------------------------------- /lists/anet/anet1.3_label2name.json: -------------------------------------------------------------------------------- 1 | {"0": "Applying sunscreen", "1": "Archery", "2": "Arm wrestling", "3": "Assembling bicycle", "4": "BMX", "5": "Baking cookies", "6": "Ballet", "7": "Bathing dog", "8": "Baton twirling", "9": "Beach soccer", "10": "Beer pong", "11": "Belly dance", "12": "Blow-drying hair", "13": "Blowing leaves", "14": "Braiding hair", "15": "Breakdancing", "16": "Brushing hair", "17": "Brushing teeth", "18": "Building sandcastles", "19": "Bullfighting", "20": "Bungee jumping", "21": "Calf roping", "22": "Camel ride", "23": "Canoeing", "24": "Capoeira", "25": "Carving jack-o-lanterns", "26": "Changing car wheel", "27": "Cheerleading", "28": "Chopping wood", "29": "Clean and jerk", "30": "Cleaning shoes", "31": "Cleaning sink", "32": "Cleaning windows", "33": "Clipping cat claws", "34": "Cricket", "35": "Croquet", "36": "Cumbia", "37": "Curling", "38": "Cutting the grass", "39": "Decorating the Christmas tree", "40": "Disc dog", "41": "Discus throw", "42": "Dodgeball", "43": "Doing 
a powerbomb", "44": "Doing crunches", "45": "Doing fencing", "46": "Doing karate", "47": "Doing kickboxing", "48": "Doing motocross", "49": "Doing nails", "50": "Doing step aerobics", "51": "Drinking beer", "52": "Drinking coffee", "53": "Drum corps", "54": "Elliptical trainer", "55": "Fixing bicycle", "56": "Fixing the roof", "57": "Fun sliding down", "58": "Futsal", "59": "Gargling mouthwash", "60": "Getting a haircut", "61": "Getting a piercing", "62": "Getting a tattoo", "63": "Grooming dog", "64": "Grooming horse", "65": "Hammer throw", "66": "Hand car wash", "67": "Hand washing clothes", "68": "Hanging wallpaper", "69": "Having an ice cream", "70": "High jump", "71": "Hitting a pinata", "72": "Hopscotch", "73": "Horseback riding", "74": "Hula hoop", "75": "Hurling", "76": "Ice fishing", "77": "Installing carpet", "78": "Ironing clothes", "79": "Javelin throw", "80": "Kayaking", "81": "Kite flying", "82": "Kneeling", "83": "Knitting", "84": "Laying tile", "85": "Layup drill in basketball", "86": "Long jump", "87": "Longboarding", "88": "Making a cake", "89": "Making a lemonade", "90": "Making a sandwich", "91": "Making an omelette", "92": "Mixing drinks", "93": "Mooping floor", "94": "Mowing the lawn", "95": "Paintball", "96": "Painting", "97": "Painting fence", "98": "Painting furniture", "99": "Peeling potatoes", "100": "Ping-pong", "101": "Plastering", "102": "Plataform diving", "103": "Playing accordion", "104": "Playing badminton", "105": "Playing bagpipes", "106": "Playing beach volleyball", "107": "Playing blackjack", "108": "Playing congas", "109": "Playing drums", "110": "Playing field hockey", "111": "Playing flauta", "112": "Playing guitarra", "113": "Playing harmonica", "114": "Playing ice hockey", "115": "Playing kickball", "116": "Playing lacrosse", "117": "Playing piano", "118": "Playing polo", "119": "Playing pool", "120": "Playing racquetball", "121": "Playing rubik cube", "122": "Playing saxophone", "123": "Playing squash", "124": "Playing ten pins", "125": "Playing violin", "126": "Playing water polo", "127": "Pole vault", "128": "Polishing forniture", "129": "Polishing shoes", "130": "Powerbocking", "131": "Preparing pasta", "132": "Preparing salad", "133": "Putting in contact lenses", "134": "Putting on makeup", "135": "Putting on shoes", "136": "Rafting", "137": "Raking leaves", "138": "Removing curlers", "139": "Removing ice from car", "140": "Riding bumper cars", "141": "River tubing", "142": "Rock climbing", "143": "Rock-paper-scissors", "144": "Rollerblading", "145": "Roof shingle removal", "146": "Rope skipping", "147": "Running a marathon", "148": "Sailing", "149": "Scuba diving", "150": "Sharpening knives", "151": "Shaving", "152": "Shaving legs", "153": "Shot put", "154": "Shoveling snow", "155": "Shuffleboard", "156": "Skateboarding", "157": "Skiing", "158": "Slacklining", "159": "Smoking a cigarette", "160": "Smoking hookah", "161": "Snatch", "162": "Snow tubing", "163": "Snowboarding", "164": "Spinning", "165": "Spread mulch", "166": "Springboard diving", "167": "Starting a campfire", "168": "Sumo", "169": "Surfing", "170": "Swimming", "171": "Swinging at the playground", "172": "Table soccer", "173": "Tai chi", "174": "Tango", "175": "Tennis serve with ball bouncing", "176": "Throwing darts", "177": "Trimming branches or hedges", "178": "Triple jump", "179": "Tug of war", "180": "Tumbling", "181": "Using parallel bars", "182": "Using the balance beam", "183": "Using the monkey bar", "184": "Using the pommel horse", "185": "Using the rowing machine", 
"186": "Using uneven bars", "187": "Vacuuming floor", "188": "Volleyball", "189": "Wakeboarding", "190": "Walking the dog", "191": "Washing dishes", "192": "Washing face", "193": "Washing hands", "194": "Waterskiing", "195": "Waxing skis", "196": "Welding", "197": "Windsurfing", "198": "Wrapping presents", "199": "Zumba"} -------------------------------------------------------------------------------- /lists/anet1.3_labels.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,Applying sunscreen 3 | 1,Archery 4 | 2,Arm wrestling 5 | 3,Assembling bicycle 6 | 4,BMX 7 | 5,Baking cookies 8 | 6,Ballet 9 | 7,Bathing dog 10 | 8,Baton twirling 11 | 9,Beach soccer 12 | 10,Beer pong 13 | 11,Belly dance 14 | 12,Blow-drying hair 15 | 13,Blowing leaves 16 | 14,Braiding hair 17 | 15,Breakdancing 18 | 16,Brushing hair 19 | 17,Brushing teeth 20 | 18,Building sandcastles 21 | 19,Bullfighting 22 | 20,Bungee jumping 23 | 21,Calf roping 24 | 22,Camel ride 25 | 23,Canoeing 26 | 24,Capoeira 27 | 25,Carving jack-o-lanterns 28 | 26,Changing car wheel 29 | 27,Cheerleading 30 | 28,Chopping wood 31 | 29,Clean and jerk 32 | 30,Cleaning shoes 33 | 31,Cleaning sink 34 | 32,Cleaning windows 35 | 33,Clipping cat claws 36 | 34,Cricket 37 | 35,Croquet 38 | 36,Cumbia 39 | 37,Curling 40 | 38,Cutting the grass 41 | 39,Decorating the Christmas tree 42 | 40,Disc dog 43 | 41,Discus throw 44 | 42,Dodgeball 45 | 43,Doing a powerbomb 46 | 44,Doing crunches 47 | 45,Doing fencing 48 | 46,Doing karate 49 | 47,Doing kickboxing 50 | 48,Doing motocross 51 | 49,Doing nails 52 | 50,Doing step aerobics 53 | 51,Drinking beer 54 | 52,Drinking coffee 55 | 53,Drum corps 56 | 54,Elliptical trainer 57 | 55,Fixing bicycle 58 | 56,Fixing the roof 59 | 57,Fun sliding down 60 | 58,Futsal 61 | 59,Gargling mouthwash 62 | 60,Getting a haircut 63 | 61,Getting a piercing 64 | 62,Getting a tattoo 65 | 63,Grooming dog 66 | 64,Grooming horse 67 | 65,Hammer throw 68 | 66,Hand car wash 69 | 67,Hand washing clothes 70 | 68,Hanging wallpaper 71 | 69,Having an ice cream 72 | 70,High jump 73 | 71,Hitting a pinata 74 | 72,Hopscotch 75 | 73,Horseback riding 76 | 74,Hula hoop 77 | 75,Hurling 78 | 76,Ice fishing 79 | 77,Installing carpet 80 | 78,Ironing clothes 81 | 79,Javelin throw 82 | 80,Kayaking 83 | 81,Kite flying 84 | 82,Kneeling 85 | 83,Knitting 86 | 84,Laying tile 87 | 85,Layup drill in basketball 88 | 86,Long jump 89 | 87,Longboarding 90 | 88,Making a cake 91 | 89,Making a lemonade 92 | 90,Making a sandwich 93 | 91,Making an omelette 94 | 92,Mixing drinks 95 | 93,Mooping floor 96 | 94,Mowing the lawn 97 | 95,Paintball 98 | 96,Painting 99 | 97,Painting fence 100 | 98,Painting furniture 101 | 99,Peeling potatoes 102 | 100,Ping-pong 103 | 101,Plastering 104 | 102,Plataform diving 105 | 103,Playing accordion 106 | 104,Playing badminton 107 | 105,Playing bagpipes 108 | 106,Playing beach volleyball 109 | 107,Playing blackjack 110 | 108,Playing congas 111 | 109,Playing drums 112 | 110,Playing field hockey 113 | 111,Playing flauta 114 | 112,Playing guitarra 115 | 113,Playing harmonica 116 | 114,Playing ice hockey 117 | 115,Playing kickball 118 | 116,Playing lacrosse 119 | 117,Playing piano 120 | 118,Playing polo 121 | 119,Playing pool 122 | 120,Playing racquetball 123 | 121,Playing rubik cube 124 | 122,Playing saxophone 125 | 123,Playing squash 126 | 124,Playing ten pins 127 | 125,Playing violin 128 | 126,Playing water polo 129 | 127,Pole vault 130 | 128,Polishing forniture 131 | 129,Polishing shoes 132 | 
130,Powerbocking 133 | 131,Preparing pasta 134 | 132,Preparing salad 135 | 133,Putting in contact lenses 136 | 134,Putting on makeup 137 | 135,Putting on shoes 138 | 136,Rafting 139 | 137,Raking leaves 140 | 138,Removing curlers 141 | 139,Removing ice from car 142 | 140,Riding bumper cars 143 | 141,River tubing 144 | 142,Rock climbing 145 | 143,Rock-paper-scissors 146 | 144,Rollerblading 147 | 145,Roof shingle removal 148 | 146,Rope skipping 149 | 147,Running a marathon 150 | 148,Sailing 151 | 149,Scuba diving 152 | 150,Sharpening knives 153 | 151,Shaving 154 | 152,Shaving legs 155 | 153,Shot put 156 | 154,Shoveling snow 157 | 155,Shuffleboard 158 | 156,Skateboarding 159 | 157,Skiing 160 | 158,Slacklining 161 | 159,Smoking a cigarette 162 | 160,Smoking hookah 163 | 161,Snatch 164 | 162,Snow tubing 165 | 163,Snowboarding 166 | 164,Spinning 167 | 165,Spread mulch 168 | 166,Springboard diving 169 | 167,Starting a campfire 170 | 168,Sumo 171 | 169,Surfing 172 | 170,Swimming 173 | 171,Swinging at the playground 174 | 172,Table soccer 175 | 173,Tai chi 176 | 174,Tango 177 | 175,Tennis serve with ball bouncing 178 | 176,Throwing darts 179 | 177,Trimming branches or hedges 180 | 178,Triple jump 181 | 179,Tug of war 182 | 180,Tumbling 183 | 181,Using parallel bars 184 | 182,Using the balance beam 185 | 183,Using the monkey bar 186 | 184,Using the pommel horse 187 | 185,Using the rowing machine 188 | 186,Using uneven bars 189 | 187,Vacuuming floor 190 | 188,Volleyball 191 | 189,Wakeboarding 192 | 190,Walking the dog 193 | 191,Washing dishes 194 | 192,Washing face 195 | 193,Washing hands 196 | 194,Waterskiing 197 | 195,Waxing skis 198 | 196,Welding 199 | 197,Windsurfing 200 | 198,Wrapping presents 201 | 199,Zumba -------------------------------------------------------------------------------- /lists/hmdb51_labels.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,brush hair 3 | 1,cartwheel 4 | 2,catch 5 | 3,chew 6 | 4,clap 7 | 5,climb 8 | 6,climb stairs 9 | 7,dive 10 | 8,draw sword 11 | 9,dribble 12 | 10,drink 13 | 11,eat 14 | 12,fall floor 15 | 13,fencing 16 | 14,flic flac 17 | 15,golf 18 | 16,handstand 19 | 17,hit 20 | 18,hug 21 | 19,jump 22 | 20,kick 23 | 21,kick ball 24 | 22,kiss 25 | 23,laugh 26 | 24,pick 27 | 25,pour 28 | 26,pullup 29 | 27,punch 30 | 28,push 31 | 29,pushup 32 | 30,ride bike 33 | 31,ride horse 34 | 32,run 35 | 33,shake hands 36 | 34,shoot ball 37 | 35,shoot bow 38 | 36,shoot gun 39 | 37,sit 40 | 38,situp 41 | 39,smile 42 | 40,smoke 43 | 41,somersault 44 | 42,stand 45 | 43,swing baseball 46 | 44,sword 47 | 45,sword exercise 48 | 46,talk 49 | 47,throw 50 | 48,turn 51 | 49,walk 52 | 50,wave 53 | -------------------------------------------------------------------------------- /lists/k600/k160_labels_split1.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,acting in play 3 | 1,adjusting glasses 4 | 2,arguing 5 | 3,attending conference 6 | 4,backflip (human) 7 | 5,base jumping 8 | 6,bathing dog 9 | 7,battle rope training 10 | 8,blowdrying hair 11 | 9,blowing bubble gum 12 | 10,bodysurfing 13 | 11,bottling 14 | 12,breaking boards 15 | 13,breathing fire 16 | 14,building sandcastle 17 | 15,bull fighting 18 | 16,bulldozing 19 | 17,burping 20 | 18,calligraphy 21 | 19,capsizing 22 | 20,casting fishing line 23 | 21,changing gear in car 24 | 22,chiseling stone 25 | 23,chiseling wood 26 | 24,chopping meat 27 | 25,coloring in 28 | 26,combing hair 29 | 27,cooking scallops 30 | 28,cracking 
knuckles 31 | 29,crossing eyes 32 | 30,cumbia 33 | 31,curling (sport) 34 | 32,cutting apple 35 | 33,cutting orange 36 | 34,delivering mail 37 | 35,docking boat 38 | 36,doing jigsaw puzzle 39 | 37,drooling 40 | 38,dumpster diving 41 | 39,dyeing eyebrows 42 | 40,embroidering 43 | 41,fencing (sport) 44 | 42,fixing bicycle 45 | 43,flint knapping 46 | 44,fly tying 47 | 45,geocaching 48 | 46,getting a piercing 49 | 47,gold panning 50 | 48,gospel singing in church 51 | 49,head stand 52 | 50,home roasting coffee 53 | 51,hugging baby 54 | 52,ice swimming 55 | 53,ironing hair 56 | 54,jaywalking 57 | 55,jumping bicycle 58 | 56,jumping jacks 59 | 57,karaoke 60 | 58,lawn mower racing 61 | 59,laying concrete 62 | 60,laying stone 63 | 61,lifting hat 64 | 62,lighting fire 65 | 63,lock picking 66 | 64,longboarding 67 | 65,luge 68 | 66,making cheese 69 | 67,making paper aeroplanes 70 | 68,marriage proposal 71 | 69,massaging neck 72 | 70,moon walking 73 | 71,mosh pit dancing 74 | 72,mountain climber (exercise) 75 | 73,mushroom foraging 76 | 74,needle felting 77 | 75,opening wine bottle 78 | 76,packing 79 | 77,passing soccer ball 80 | 78,photobombing 81 | 79,photocopying 82 | 80,pinching 83 | 81,pirouetting 84 | 82,planing wood 85 | 83,playing beer pong 86 | 84,playing blackjack 87 | 85,playing darts 88 | 86,playing field hockey 89 | 87,playing gong 90 | 88,playing hand clapping games 91 | 89,playing laser tag 92 | 90,playing lute 93 | 91,playing maracas 94 | 92,playing marbles 95 | 93,playing ocarina 96 | 94,playing pan pipes 97 | 95,playing pinball 98 | 96,playing polo 99 | 97,playing rubiks cube 100 | 98,playing with trains 101 | 99,poking bellybutton 102 | 100,popping balloons 103 | 101,preparing salad 104 | 102,pushing wheelbarrow 105 | 103,putting in contact lenses 106 | 104,putting on eyeliner 107 | 105,putting on foundation 108 | 106,putting on lipstick 109 | 107,putting on mascara 110 | 108,putting on sari 111 | 109,putting on shoes 112 | 110,raising eyebrows 113 | 111,repairing puncture 114 | 112,riding snow blower 115 | 113,roasting pig 116 | 114,rolling pastry 117 | 115,rope pushdown 118 | 116,sausage making 119 | 117,sawing wood 120 | 118,scrubbing face 121 | 119,separating eggs 122 | 120,sewing 123 | 121,shaping bread dough 124 | 122,shining flashlight 125 | 123,shucking oysters 126 | 124,sipping cup 127 | 125,skiing mono 128 | 126,sleeping 129 | 127,smelling feet 130 | 128,smoking pipe 131 | 129,square dancing 132 | 130,standing on hands 133 | 131,steer roping 134 | 132,sucking lolly 135 | 133,swinging baseball bat 136 | 134,tackling 137 | 135,tagging graffiti 138 | 136,talking on cell phone 139 | 137,tasting wine 140 | 138,threading needle 141 | 139,throwing knife 142 | 140,throwing snowballs 143 | 141,throwing tantrum 144 | 142,tie dying 145 | 143,tiptoeing 146 | 144,trimming shrubs 147 | 145,tying shoe laces 148 | 146,using a microscope 149 | 147,using a power drill 150 | 148,using a sledge hammer 151 | 149,using a wrench 152 | 150,using atm 153 | 151,using puppets 154 | 152,vacuuming floor 155 | 153,visiting the zoo 156 | 154,wading through water 157 | 155,watching tv 158 | 156,waving hand 159 | 157,winking 160 | 158,wood burning (art) 161 | 159,yarn spinning 162 | -------------------------------------------------------------------------------- /lists/k600/k160_labels_split2.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,acting in play 3 | 1,arguing 4 | 2,assembling bicycle 5 | 3,backflip (human) 6 | 4,base jumping 7 | 5,bathing dog 8 
| 6,battle rope training 9 | 7,blowing bubble gum 10 | 8,bottling 11 | 9,breathing fire 12 | 10,building lego 13 | 11,building sandcastle 14 | 12,bull fighting 15 | 13,calculating 16 | 14,calligraphy 17 | 15,card stacking 18 | 16,card throwing 19 | 17,carving ice 20 | 18,chewing gum 21 | 19,chiseling stone 22 | 20,chopping meat 23 | 21,chopping vegetables 24 | 22,coloring in 25 | 23,contorting 26 | 24,cooking scallops 27 | 25,cosplaying 28 | 26,cracking back 29 | 27,cracking knuckles 30 | 28,crossing eyes 31 | 29,cumbia 32 | 30,curling (sport) 33 | 31,cutting apple 34 | 32,delivering mail 35 | 33,directing traffic 36 | 34,docking boat 37 | 35,doing jigsaw puzzle 38 | 36,drooling 39 | 37,dumpster diving 40 | 38,dyeing eyebrows 41 | 39,embroidering 42 | 40,falling off bike 43 | 41,fidgeting 44 | 42,fixing bicycle 45 | 43,geocaching 46 | 44,getting a piercing 47 | 45,gold panning 48 | 46,hand washing clothes 49 | 47,head stand 50 | 48,historical reenactment 51 | 49,huddling 52 | 50,ice swimming 53 | 51,installing carpet 54 | 52,ironing hair 55 | 53,jumping bicycle 56 | 54,jumping jacks 57 | 55,land sailing 58 | 56,lawn mower racing 59 | 57,laying stone 60 | 58,laying tiles 61 | 59,leatherworking 62 | 60,licking 63 | 61,lifting hat 64 | 62,lighting fire 65 | 63,lock picking 66 | 64,longboarding 67 | 65,looking at phone 68 | 66,luge 69 | 67,making balloon shapes 70 | 68,making bubbles 71 | 69,marriage proposal 72 | 70,massaging neck 73 | 71,moon walking 74 | 72,mosh pit dancing 75 | 73,mushroom foraging 76 | 74,opening door 77 | 75,opening refrigerator 78 | 76,photobombing 79 | 77,pinching 80 | 78,pirouetting 81 | 79,planing wood 82 | 80,playing beer pong 83 | 81,playing blackjack 84 | 82,playing darts 85 | 83,playing dominoes 86 | 84,playing field hockey 87 | 85,playing gong 88 | 86,playing hand clapping games 89 | 87,playing laser tag 90 | 88,playing lute 91 | 89,playing maracas 92 | 90,playing netball 93 | 91,playing ocarina 94 | 92,playing pan pipes 95 | 93,playing ping pong 96 | 94,playing polo 97 | 95,playing rubiks cube 98 | 96,playing scrabble 99 | 97,playing with trains 100 | 98,poking bellybutton 101 | 99,polishing metal 102 | 100,popping balloons 103 | 101,pouring beer 104 | 102,preparing salad 105 | 103,putting in contact lenses 106 | 104,putting on eyeliner 107 | 105,putting on foundation 108 | 106,putting on mascara 109 | 107,putting on sari 110 | 108,repairing puncture 111 | 109,riding snow blower 112 | 110,roasting marshmallows 113 | 111,roasting pig 114 | 112,rolling pastry 115 | 113,rope pushdown 116 | 114,scrapbooking 117 | 115,scrubbing face 118 | 116,separating eggs 119 | 117,sewing 120 | 118,shaping bread dough 121 | 119,shining flashlight 122 | 120,shopping 123 | 121,shuffling feet 124 | 122,sipping cup 125 | 123,skiing mono 126 | 124,skipping stone 127 | 125,smashing 128 | 126,smelling feet 129 | 127,smoking pipe 130 | 128,square dancing 131 | 129,steer roping 132 | 130,sucking lolly 133 | 131,swimming front crawl 134 | 132,swinging baseball bat 135 | 133,sword swallowing 136 | 134,tagging graffiti 137 | 135,talking on cell phone 138 | 136,tasting wine 139 | 137,threading needle 140 | 138,throwing knife 141 | 139,throwing snowballs 142 | 140,throwing tantrum 143 | 141,throwing water balloon 144 | 142,tie dying 145 | 143,tightrope walking 146 | 144,trimming shrubs 147 | 145,twiddling fingers 148 | 146,tying shoe laces 149 | 147,using a paint roller 150 | 148,using a power drill 151 | 149,using a wrench 152 | 150,using atm 153 | 151,using bagging machine 154 | 152,using 
circular saw 155 | 153,using inhaler 156 | 154,visiting the zoo 157 | 155,wading through mud 158 | 156,wading through water 159 | 157,waving hand 160 | 158,weaving fabric 161 | 159,winking 162 | -------------------------------------------------------------------------------- /lists/k600/k160_labels_split3.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,adjusting glasses 3 | 1,alligator wrestling 4 | 2,archaeological excavation 5 | 3,arguing 6 | 4,assembling bicycle 7 | 5,attending conference 8 | 6,backflip (human) 9 | 7,base jumping 10 | 8,battle rope training 11 | 9,blowdrying hair 12 | 10,blowing bubble gum 13 | 11,bouncing on bouncy castle 14 | 12,breaking boards 15 | 13,breathing fire 16 | 14,building lego 17 | 15,bull fighting 18 | 16,bulldozing 19 | 17,burping 20 | 18,calculating 21 | 19,calligraphy 22 | 20,capsizing 23 | 21,card stacking 24 | 22,card throwing 25 | 23,carving ice 26 | 24,casting fishing line 27 | 25,chiseling stone 28 | 26,chiseling wood 29 | 27,chopping meat 30 | 28,clam digging 31 | 29,coloring in 32 | 30,combing hair 33 | 31,contorting 34 | 32,cooking scallops 35 | 33,cosplaying 36 | 34,cracking knuckles 37 | 35,crossing eyes 38 | 36,cumbia 39 | 37,curling (sport) 40 | 38,delivering mail 41 | 39,directing traffic 42 | 40,doing jigsaw puzzle 43 | 41,dyeing eyebrows 44 | 42,falling off bike 45 | 43,falling off chair 46 | 44,fencing (sport) 47 | 45,fixing bicycle 48 | 46,flint knapping 49 | 47,fly tying 50 | 48,geocaching 51 | 49,hand washing clothes 52 | 50,head stand 53 | 51,historical reenactment 54 | 52,home roasting coffee 55 | 53,huddling 56 | 54,ice swimming 57 | 55,installing carpet 58 | 56,ironing hair 59 | 57,jaywalking 60 | 58,jumping bicycle 61 | 59,karaoke 62 | 60,land sailing 63 | 61,lawn mower racing 64 | 62,laying concrete 65 | 63,laying stone 66 | 64,laying tiles 67 | 65,leatherworking 68 | 66,licking 69 | 67,lifting hat 70 | 68,lighting fire 71 | 69,lock picking 72 | 70,longboarding 73 | 71,looking at phone 74 | 72,luge 75 | 73,making balloon shapes 76 | 74,making bubbles 77 | 75,making cheese 78 | 76,making paper aeroplanes 79 | 77,marriage proposal 80 | 78,mosh pit dancing 81 | 79,mushroom foraging 82 | 80,needle felting 83 | 81,opening door 84 | 82,opening refrigerator 85 | 83,opening wine bottle 86 | 84,packing 87 | 85,passing soccer ball 88 | 86,photobombing 89 | 87,photocopying 90 | 88,pinching 91 | 89,planing wood 92 | 90,playing blackjack 93 | 91,playing darts 94 | 92,playing field hockey 95 | 93,playing gong 96 | 94,playing laser tag 97 | 95,playing lute 98 | 96,playing maracas 99 | 97,playing ocarina 100 | 98,playing ping pong 101 | 99,playing polo 102 | 100,poking bellybutton 103 | 101,popping balloons 104 | 102,pouring beer 105 | 103,preparing salad 106 | 104,pushing wheelbarrow 107 | 105,putting in contact lenses 108 | 106,putting on eyeliner 109 | 107,putting on lipstick 110 | 108,putting on mascara 111 | 109,putting on sari 112 | 110,putting on shoes 113 | 111,raising eyebrows 114 | 112,repairing puncture 115 | 113,riding snow blower 116 | 114,roasting marshmallows 117 | 115,rolling pastry 118 | 116,rope pushdown 119 | 117,sawing wood 120 | 118,scrubbing face 121 | 119,shaping bread dough 122 | 120,shopping 123 | 121,shuffling feet 124 | 122,sipping cup 125 | 123,skiing mono 126 | 124,skipping stone 127 | 125,sleeping 128 | 126,smashing 129 | 127,square dancing 130 | 128,staring 131 | 129,steer roping 132 | 130,sucking lolly 133 | 131,swimming front crawl 134 | 132,swinging 
baseball bat 135 | 133,sword swallowing 136 | 134,tasting wine 137 | 135,threading needle 138 | 136,throwing knife 139 | 137,throwing snowballs 140 | 138,throwing tantrum 141 | 139,tie dying 142 | 140,tightrope walking 143 | 141,tiptoeing 144 | 142,trimming shrubs 145 | 143,using a microscope 146 | 144,using a paint roller 147 | 145,using a power drill 148 | 146,using a sledge hammer 149 | 147,using a wrench 150 | 148,using atm 151 | 149,using bagging machine 152 | 150,using inhaler 153 | 151,using puppets 154 | 152,visiting the zoo 155 | 153,wading through water 156 | 154,walking through snow 157 | 155,watching tv 158 | 156,waving hand 159 | 157,weaving fabric 160 | 158,winking 161 | 159,wood burning (art) 162 | -------------------------------------------------------------------------------- /lists/kinetics_400_labels.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,abseiling 3 | 1,air drumming 4 | 2,answering questions 5 | 3,applauding 6 | 4,applying cream 7 | 5,archery 8 | 6,arm wrestling 9 | 7,arranging flowers 10 | 8,assembling computer 11 | 9,auctioning 12 | 10,baby waking up 13 | 11,baking cookies 14 | 12,balloon blowing 15 | 13,bandaging 16 | 14,barbequing 17 | 15,bartending 18 | 16,beatboxing 19 | 17,bee keeping 20 | 18,belly dancing 21 | 19,bench pressing 22 | 20,bending back 23 | 21,bending metal 24 | 22,biking through snow 25 | 23,blasting sand 26 | 24,blowing glass 27 | 25,blowing leaves 28 | 26,blowing nose 29 | 27,blowing out candles 30 | 28,bobsledding 31 | 29,bookbinding 32 | 30,bouncing on trampoline 33 | 31,bowling 34 | 32,braiding hair 35 | 33,breading or breadcrumbing 36 | 34,breakdancing 37 | 35,brush painting 38 | 36,brushing hair 39 | 37,brushing teeth 40 | 38,building cabinet 41 | 39,building shed 42 | 40,bungee jumping 43 | 41,busking 44 | 42,canoeing or kayaking 45 | 43,capoeira 46 | 44,carrying baby 47 | 45,cartwheeling 48 | 46,carving pumpkin 49 | 47,catching fish 50 | 48,catching or throwing baseball 51 | 49,catching or throwing frisbee 52 | 50,catching or throwing softball 53 | 51,celebrating 54 | 52,changing oil 55 | 53,changing wheel 56 | 54,checking tires 57 | 55,cheerleading 58 | 56,chopping wood 59 | 57,clapping 60 | 58,clay pottery making 61 | 59,clean and jerk 62 | 60,cleaning floor 63 | 61,cleaning gutters 64 | 62,cleaning pool 65 | 63,cleaning shoes 66 | 64,cleaning toilet 67 | 65,cleaning windows 68 | 66,climbing a rope 69 | 67,climbing ladder 70 | 68,climbing tree 71 | 69,contact juggling 72 | 70,cooking chicken 73 | 71,cooking egg 74 | 72,cooking on campfire 75 | 73,cooking sausages 76 | 74,counting money 77 | 75,country line dancing 78 | 76,cracking neck 79 | 77,crawling baby 80 | 78,crossing river 81 | 79,crying 82 | 80,curling hair 83 | 81,cutting nails 84 | 82,cutting pineapple 85 | 83,cutting watermelon 86 | 84,dancing ballet 87 | 85,dancing charleston 88 | 86,dancing gangnam style 89 | 87,dancing macarena 90 | 88,deadlifting 91 | 89,decorating the christmas tree 92 | 90,digging 93 | 91,dining 94 | 92,disc golfing 95 | 93,diving cliff 96 | 94,dodgeball 97 | 95,doing aerobics 98 | 96,doing laundry 99 | 97,doing nails 100 | 98,drawing 101 | 99,dribbling basketball 102 | 100,drinking 103 | 101,drinking beer 104 | 102,drinking shots 105 | 103,driving car 106 | 104,driving tractor 107 | 105,drop kicking 108 | 106,drumming fingers 109 | 107,dunking basketball 110 | 108,dying hair 111 | 109,eating burger 112 | 110,eating cake 113 | 111,eating carrots 114 | 112,eating chips 115 | 113,eating 
doughnuts 116 | 114,eating hotdog 117 | 115,eating ice cream 118 | 116,eating spaghetti 119 | 117,eating watermelon 120 | 118,egg hunting 121 | 119,exercising arm 122 | 120,exercising with an exercise ball 123 | 121,extinguishing fire 124 | 122,faceplanting 125 | 123,feeding birds 126 | 124,feeding fish 127 | 125,feeding goats 128 | 126,filling eyebrows 129 | 127,finger snapping 130 | 128,fixing hair 131 | 129,flipping pancake 132 | 130,flying kite 133 | 131,folding clothes 134 | 132,folding napkins 135 | 133,folding paper 136 | 134,front raises 137 | 135,frying vegetables 138 | 136,garbage collecting 139 | 137,gargling 140 | 138,getting a haircut 141 | 139,getting a tattoo 142 | 140,giving or receiving award 143 | 141,golf chipping 144 | 142,golf driving 145 | 143,golf putting 146 | 144,grinding meat 147 | 145,grooming dog 148 | 146,grooming horse 149 | 147,gymnastics tumbling 150 | 148,hammer throw 151 | 149,headbanging 152 | 150,headbutting 153 | 151,high jump 154 | 152,high kick 155 | 153,hitting baseball 156 | 154,hockey stop 157 | 155,holding snake 158 | 156,hopscotch 159 | 157,hoverboarding 160 | 158,hugging 161 | 159,hula hooping 162 | 160,hurdling 163 | 161,hurling (sport) 164 | 162,ice climbing 165 | 163,ice fishing 166 | 164,ice skating 167 | 165,ironing 168 | 166,javelin throw 169 | 167,jetskiing 170 | 168,jogging 171 | 169,juggling balls 172 | 170,juggling fire 173 | 171,juggling soccer ball 174 | 172,jumping into pool 175 | 173,jumpstyle dancing 176 | 174,kicking field goal 177 | 175,kicking soccer ball 178 | 176,kissing 179 | 177,kitesurfing 180 | 178,knitting 181 | 179,krumping 182 | 180,laughing 183 | 181,laying bricks 184 | 182,long jump 185 | 183,lunge 186 | 184,making a cake 187 | 185,making a sandwich 188 | 186,making bed 189 | 187,making jewelry 190 | 188,making pizza 191 | 189,making snowman 192 | 190,making sushi 193 | 191,making tea 194 | 192,marching 195 | 193,massaging back 196 | 194,massaging feet 197 | 195,massaging legs 198 | 196,massaging person's head 199 | 197,milking cow 200 | 198,mopping floor 201 | 199,motorcycling 202 | 200,moving furniture 203 | 201,mowing lawn 204 | 202,news anchoring 205 | 203,opening bottle 206 | 204,opening present 207 | 205,paragliding 208 | 206,parasailing 209 | 207,parkour 210 | 208,passing American football (in game) 211 | 209,passing American football (not in game) 212 | 210,peeling apples 213 | 211,peeling potatoes 214 | 212,petting animal (not cat) 215 | 213,petting cat 216 | 214,picking fruit 217 | 215,planting trees 218 | 216,plastering 219 | 217,playing accordion 220 | 218,playing badminton 221 | 219,playing bagpipes 222 | 220,playing basketball 223 | 221,playing bass guitar 224 | 222,playing cards 225 | 223,playing cello 226 | 224,playing chess 227 | 225,playing clarinet 228 | 226,playing controller 229 | 227,playing cricket 230 | 228,playing cymbals 231 | 229,playing didgeridoo 232 | 230,playing drums 233 | 231,playing flute 234 | 232,playing guitar 235 | 233,playing harmonica 236 | 234,playing harp 237 | 235,playing ice hockey 238 | 236,playing keyboard 239 | 237,playing kickball 240 | 238,playing monopoly 241 | 239,playing organ 242 | 240,playing paintball 243 | 241,playing piano 244 | 242,playing poker 245 | 243,playing recorder 246 | 244,playing saxophone 247 | 245,playing squash or racquetball 248 | 246,playing tennis 249 | 247,playing trombone 250 | 248,playing trumpet 251 | 249,playing ukulele 252 | 250,playing violin 253 | 251,playing volleyball 254 | 252,playing xylophone 255 | 253,pole vault 256 | 
254,presenting weather forecast 257 | 255,pull ups 258 | 256,pumping fist 259 | 257,pumping gas 260 | 258,punching bag 261 | 259,punching person (boxing) 262 | 260,push up 263 | 261,pushing car 264 | 262,pushing cart 265 | 263,pushing wheelchair 266 | 264,reading book 267 | 265,reading newspaper 268 | 266,recording music 269 | 267,riding a bike 270 | 268,riding camel 271 | 269,riding elephant 272 | 270,riding mechanical bull 273 | 271,riding mountain bike 274 | 272,riding mule 275 | 273,riding or walking with horse 276 | 274,riding scooter 277 | 275,riding unicycle 278 | 276,ripping paper 279 | 277,robot dancing 280 | 278,rock climbing 281 | 279,rock scissors paper 282 | 280,roller skating 283 | 281,running on treadmill 284 | 282,sailing 285 | 283,salsa dancing 286 | 284,sanding floor 287 | 285,scrambling eggs 288 | 286,scuba diving 289 | 287,setting table 290 | 288,shaking hands 291 | 289,shaking head 292 | 290,sharpening knives 293 | 291,sharpening pencil 294 | 292,shaving head 295 | 293,shaving legs 296 | 294,shearing sheep 297 | 295,shining shoes 298 | 296,shooting basketball 299 | 297,shooting goal (soccer) 300 | 298,shot put 301 | 299,shoveling snow 302 | 300,shredding paper 303 | 301,shuffling cards 304 | 302,side kick 305 | 303,sign language interpreting 306 | 304,singing 307 | 305,situp 308 | 306,skateboarding 309 | 307,ski jumping 310 | 308,skiing (not slalom or crosscountry) 311 | 309,skiing crosscountry 312 | 310,skiing slalom 313 | 311,skipping rope 314 | 312,skydiving 315 | 313,slacklining 316 | 314,slapping 317 | 315,sled dog racing 318 | 316,smoking 319 | 317,smoking hookah 320 | 318,snatch weight lifting 321 | 319,sneezing 322 | 320,sniffing 323 | 321,snorkeling 324 | 322,snowboarding 325 | 323,snowkiting 326 | 324,snowmobiling 327 | 325,somersaulting 328 | 326,spinning poi 329 | 327,spray painting 330 | 328,spraying 331 | 329,springboard diving 332 | 330,squat 333 | 331,sticking tongue out 334 | 332,stomping grapes 335 | 333,stretching arm 336 | 334,stretching leg 337 | 335,strumming guitar 338 | 336,surfing crowd 339 | 337,surfing water 340 | 338,sweeping floor 341 | 339,swimming backstroke 342 | 340,swimming breast stroke 343 | 341,swimming butterfly stroke 344 | 342,swing dancing 345 | 343,swinging legs 346 | 344,swinging on something 347 | 345,sword fighting 348 | 346,tai chi 349 | 347,taking a shower 350 | 348,tango dancing 351 | 349,tap dancing 352 | 350,tapping guitar 353 | 351,tapping pen 354 | 352,tasting beer 355 | 353,tasting food 356 | 354,testifying 357 | 355,texting 358 | 356,throwing axe 359 | 357,throwing ball 360 | 358,throwing discus 361 | 359,tickling 362 | 360,tobogganing 363 | 361,tossing coin 364 | 362,tossing salad 365 | 363,training dog 366 | 364,trapezing 367 | 365,trimming or shaving beard 368 | 366,trimming trees 369 | 367,triple jump 370 | 368,tying bow tie 371 | 369,tying knot (not on a tie) 372 | 370,tying tie 373 | 371,unboxing 374 | 372,unloading truck 375 | 373,using computer 376 | 374,using remote controller (not gaming) 377 | 375,using segway 378 | 376,vault 379 | 377,waiting in line 380 | 378,walking the dog 381 | 379,washing dishes 382 | 380,washing feet 383 | 381,washing hair 384 | 382,washing hands 385 | 383,water skiing 386 | 384,water sliding 387 | 385,watering plants 388 | 386,waxing back 389 | 387,waxing chest 390 | 388,waxing eyebrows 391 | 389,waxing legs 392 | 390,weaving basket 393 | 391,welding 394 | 392,whistling 395 | 393,windsurfing 396 | 394,wrapping present 397 | 395,wrestling 398 | 396,writing 399 | 397,yawning 400 | 
398,yoga 401 | 399,zumba 402 | -------------------------------------------------------------------------------- /lists/ucf_labels.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,ApplyEyeMakeup 3 | 1,ApplyLipstick 4 | 2,Archery 5 | 3,BabyCrawling 6 | 4,BalanceBeam 7 | 5,BandMarching 8 | 6,BaseballPitch 9 | 7,Basketball 10 | 8,BasketballDunk 11 | 9,BenchPress 12 | 10,Biking 13 | 11,Billiards 14 | 12,BlowDryHair 15 | 13,BlowingCandles 16 | 14,BodyWeightSquats 17 | 15,Bowling 18 | 16,BoxingPunchingBag 19 | 17,BoxingSpeedBag 20 | 18,BreastStroke 21 | 19,BrushingTeeth 22 | 20,CleanAndJerk 23 | 21,CliffDiving 24 | 22,CricketBowling 25 | 23,CricketShot 26 | 24,CuttingInKitchen 27 | 25,Diving 28 | 26,Drumming 29 | 27,Fencing 30 | 28,FieldHockeyPenalty 31 | 29,FloorGymnastics 32 | 30,FrisbeeCatch 33 | 31,FrontCrawl 34 | 32,GolfSwing 35 | 33,Haircut 36 | 34,Hammering 37 | 35,HammerThrow 38 | 36,HandstandPushups 39 | 37,HandstandWalking 40 | 38,HeadMassage 41 | 39,HighJump 42 | 40,HorseRace 43 | 41,HorseRiding 44 | 42,HulaHoop 45 | 43,IceDancing 46 | 44,JavelinThrow 47 | 45,JugglingBalls 48 | 46,JumpingJack 49 | 47,JumpRope 50 | 48,Kayaking 51 | 49,Knitting 52 | 50,LongJump 53 | 51,Lunges 54 | 52,MilitaryParade 55 | 53,Mixing 56 | 54,MoppingFloor 57 | 55,Nunchucks 58 | 56,ParallelBars 59 | 57,PizzaTossing 60 | 58,PlayingCello 61 | 59,PlayingDaf 62 | 60,PlayingDhol 63 | 61,PlayingFlute 64 | 62,PlayingGuitar 65 | 63,PlayingPiano 66 | 64,PlayingSitar 67 | 65,PlayingTabla 68 | 66,PlayingViolin 69 | 67,PoleVault 70 | 68,PommelHorse 71 | 69,PullUps 72 | 70,Punch 73 | 71,PushUps 74 | 72,Rafting 75 | 73,RockClimbingIndoor 76 | 74,RopeClimbing 77 | 75,Rowing 78 | 76,SalsaSpin 79 | 77,ShavingBeard 80 | 78,Shotput 81 | 79,SkateBoarding 82 | 80,Skiing 83 | 81,Skijet 84 | 82,SkyDiving 85 | 83,SoccerJuggling 86 | 84,SoccerPenalty 87 | 85,StillRings 88 | 86,SumoWrestling 89 | 87,Surfing 90 | 88,Swing 91 | 89,TableTennisShot 92 | 90,TaiChi 93 | 91,TennisSwing 94 | 92,ThrowDiscus 95 | 93,TrampolineJumping 96 | 94,Typing 97 | 95,UnevenBars 98 | 96,VolleyballSpiking 99 | 97,WalkingWithDog 100 | 98,WallPushups 101 | 99,WritingOnBoard 102 | 100,YoYo 103 | -------------------------------------------------------------------------------- /modules/coop.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch.cuda.amp import GradScaler, autocast 7 | 8 | 9 | from clip import clip 10 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 11 | 12 | _tokenizer = _Tokenizer() 13 | 14 | 15 | 16 | 17 | class TextEncoder(nn.Module): 18 | def __init__(self, clip_model): 19 | super().__init__() 20 | self.transformer = clip_model.transformer 21 | self.positional_embedding = clip_model.positional_embedding 22 | self.ln_final = clip_model.ln_final 23 | self.text_projection = clip_model.text_projection 24 | self.dtype = clip_model.dtype 25 | 26 | def forward(self, prompts, tokenized_prompts): 27 | x = prompts + self.positional_embedding.type(self.dtype) 28 | x = x.permute(1, 0, 2) # NLD -> LND 29 | x = self.transformer(x) 30 | x = x.permute(1, 0, 2) # LND -> NLD 31 | x = self.ln_final(x).type(self.dtype) 32 | 33 | # x.shape = [batch_size, n_ctx, transformer.width] 34 | # take features from the eot embedding (eot_token is the highest number in each sequence) 35 | x = x[torch.arange(x.shape[0]), 
tokenized_prompts.argmax(dim=-1)] @ self.text_projection 36 | 37 | return x 38 | 39 | 40 | class PromptLearner(nn.Module): 41 | def __init__(self, cfg, classnames, clip_model): 42 | super().__init__() 43 | n_cls = len(classnames) 44 | n_ctx = cfg.COOP.N_CTX 45 | ctx_init = cfg.COOP.CTX_INIT 46 | dtype = clip_model.dtype 47 | ctx_dim = clip_model.ln_final.weight.shape[0] 48 | clip_imsize = clip_model.visual.input_resolution 49 | cfg_imsize = cfg.data.input_size 50 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 51 | 52 | if ctx_init: 53 | # use given words to initialize context vectors 54 | ctx_init = ctx_init.replace("_", " ") 55 | n_ctx = len(ctx_init.split(" ")) 56 | prompt = clip.tokenize(ctx_init) 57 | with torch.no_grad(): 58 | embedding = clip_model.token_embedding(prompt).type(dtype) 59 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] 60 | prompt_prefix = ctx_init 61 | 62 | else: 63 | # random initialization 64 | if cfg.COOP.CSC: 65 | print("Initializing class-specific contexts") 66 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 67 | else: 68 | print("Initializing a generic context") 69 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 70 | nn.init.normal_(ctx_vectors, std=0.02) 71 | prompt_prefix = " ".join(["X"] * n_ctx) 72 | 73 | print(f'Initial context: "{prompt_prefix}"') 74 | print(f"Number of context words (tokens): {n_ctx}") 75 | 76 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 77 | 78 | classnames = [name.replace("_", " ") for name in classnames] 79 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 80 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 81 | 82 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 83 | with torch.no_grad(): 84 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 85 | 86 | # These token vectors will be saved when in save_model(), 87 | # but they should be ignored in load_model() as we want to use 88 | # those computed using the current class names 89 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 90 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 91 | 92 | self.n_cls = n_cls 93 | self.n_ctx = n_ctx 94 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 95 | self.name_lens = name_lens 96 | self.class_token_position = cfg.COOP.CLASS_TOKEN_POSITION 97 | 98 | def forward(self): 99 | ctx = self.ctx 100 | if ctx.dim() == 2: 101 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 102 | 103 | prefix = self.token_prefix 104 | suffix = self.token_suffix 105 | 106 | if self.class_token_position == "end": 107 | prompts = torch.cat( 108 | [ 109 | prefix, # (n_cls, 1, dim) 110 | ctx, # (n_cls, n_ctx, dim) 111 | suffix, # (n_cls, *, dim) 112 | ], 113 | dim=1, 114 | ) 115 | 116 | elif self.class_token_position == "middle": 117 | half_n_ctx = self.n_ctx // 2 118 | prompts = [] 119 | for i in range(self.n_cls): 120 | name_len = self.name_lens[i] 121 | prefix_i = prefix[i : i + 1, :, :] 122 | class_i = suffix[i : i + 1, :name_len, :] 123 | suffix_i = suffix[i : i + 1, name_len:, :] 124 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :] 125 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :] 126 | prompt = torch.cat( 127 | [ 128 | prefix_i, # (1, 1, dim) 129 | ctx_i_half1, # (1, n_ctx//2, dim) 130 | class_i, # (1, name_len, dim) 131 | ctx_i_half2, # (1, n_ctx//2, dim) 132 | suffix_i, # (1, *, dim) 133 | ], 134 | dim=1, 135 | ) 136 | 
prompts.append(prompt) 137 | prompts = torch.cat(prompts, dim=0) 138 | 139 | elif self.class_token_position == "front": 140 | prompts = [] 141 | for i in range(self.n_cls): 142 | name_len = self.name_lens[i] 143 | prefix_i = prefix[i : i + 1, :, :] 144 | class_i = suffix[i : i + 1, :name_len, :] 145 | suffix_i = suffix[i : i + 1, name_len:, :] 146 | ctx_i = ctx[i : i + 1, :, :] 147 | prompt = torch.cat( 148 | [ 149 | prefix_i, # (1, 1, dim) 150 | class_i, # (1, name_len, dim) 151 | ctx_i, # (1, n_ctx, dim) 152 | suffix_i, # (1, *, dim) 153 | ], 154 | dim=1, 155 | ) 156 | prompts.append(prompt) 157 | prompts = torch.cat(prompts, dim=0) 158 | 159 | else: 160 | raise ValueError 161 | 162 | return prompts 163 | 164 | 165 | class CoopCLIP(nn.Module): 166 | def __init__(self, cfg, classnames, clip_model, fusion_model, n_seg): 167 | super().__init__() 168 | self.clip_model = clip_model 169 | self.fusion_model = fusion_model 170 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 171 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 172 | self.text_encoder = TextEncoder(clip_model) 173 | self.logit_scale = clip_model.logit_scale 174 | self.dtype = clip_model.dtype 175 | self.n_seg = n_seg 176 | self.cfg = cfg 177 | 178 | def encode_image(self, image): 179 | bt = image.size(0) 180 | b = bt // self.n_seg 181 | 182 | image_emb = self.clip_model.encode_image(image).view(b, self.n_seg, -1) 183 | image_emb = self.fusion_model(image_emb) 184 | return image_emb 185 | 186 | 187 | def forward(self, image, class_id, train=False): 188 | if train: 189 | return self.forward_train(image, class_id) 190 | else: 191 | return self.forward_test(image, class_id) 192 | 193 | def forward_train(self, image, class_id): 194 | 195 | image_features = self.encode_image(image.type(self.dtype)) # B dim 196 | 197 | prompts = self.prompt_learner() # [400 77 512] 198 | tokenized_prompts = self.tokenized_prompts # [400 77] 199 | 200 | if self.cfg.solver.loss_type == 'NCE': 201 | prompts = prompts[class_id] # [bs 77 512] 202 | tokenized_prompts = tokenized_prompts[class_id] # [bs 77] 203 | text_features = self.text_encoder(prompts, tokenized_prompts) # bs Dim 204 | return image_features, text_features, self.logit_scale.exp() 205 | elif self.cfg.solver.loss_type == 'CE': 206 | text_features = self.text_encoder(prompts, tokenized_prompts) # 400 Dim 207 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 208 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 209 | 210 | logit_scale = self.logit_scale.exp() 211 | logits = logit_scale * image_features @ text_features.t() # B 400 212 | return logits 213 | elif self.cfg.solver.loss_type == 'NCE,CE': 214 | bs_image_features = image_features 215 | text_features = self.text_encoder(prompts, tokenized_prompts) # 400 Dim 216 | bs_text_features = text_features[class_id] # bs dim 217 | 218 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 219 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 220 | 221 | logit_scale = self.logit_scale.exp() 222 | logits = logit_scale * image_features @ text_features.t() # B 400 223 | return bs_image_features, bs_text_features, logit_scale, logits 224 | 225 | def forward_test(self, image, class_id): 226 | image_features = self.encode_image(image.type(self.dtype)) # B dim 227 | 228 | prompts = self.prompt_learner() # [400 77 512] 229 | tokenized_prompts = self.tokenized_prompts # [400 77] 230 | 231 | text_features = 
self.text_encoder(prompts, tokenized_prompts) # 400 Dim 232 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 233 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 234 | 235 | logit_scale = self.logit_scale.exp() 236 | logits = logit_scale * image_features @ text_features.t() # B 400 237 | return logits -------------------------------------------------------------------------------- /modules/temporal_modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from clip.model import VisualTransformer, ModifiedResNet 5 | import numpy as np 6 | 7 | 8 | class TemporalShift(nn.Module): 9 | def __init__(self, net, n_segment=3, n_div=8, inplace=False): 10 | super(TemporalShift, self).__init__() 11 | self.net = net 12 | self.n_segment = n_segment 13 | self.fold_div = n_div 14 | self.inplace = inplace 15 | if inplace: 16 | print('=> Using in-place shift...') 17 | print('=> Using fold div: {}'.format(self.fold_div)) 18 | 19 | def forward(self, x): 20 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace) 21 | x = self.net(x) 22 | return x 23 | 24 | @staticmethod 25 | def shift(x, n_segment, fold_div=3, inplace=False): 26 | nt, c, h, w = x.size() 27 | n_batch = nt // n_segment 28 | x = x.view(n_batch, n_segment, c, h, w) 29 | 30 | fold = c // fold_div 31 | if inplace: 32 | # Due to some out of order error when performing parallel computing. 33 | # May need to write a CUDA kernel. 34 | raise NotImplementedError 35 | # out = InplaceShift.apply(x, fold) 36 | else: 37 | out = torch.zeros_like(x) 38 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left 39 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right 40 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift 41 | 42 | return out.view(nt, c, h, w) 43 | 44 | 45 | class TemporalShift_VIT(nn.Module): 46 | def __init__(self, net, n_segment=3, n_div=8, inplace=False): 47 | super(TemporalShift_VIT, self).__init__() 48 | self.net = net 49 | self.n_segment = n_segment 50 | self.fold_div = n_div 51 | self.inplace = inplace 52 | if inplace: 53 | print('=> Using in-place shift...') 54 | print('=> Using fold div: {}'.format(self.fold_div)) 55 | 56 | def forward(self, x): 57 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace) 58 | x = self.net(x) 59 | return x 60 | 61 | @staticmethod 62 | def shift(x, n_segment, fold_div=3, inplace=False): 63 | hw, nt, c = x.size() 64 | cls_ = x[0,:,:].unsqueeze(0) 65 | x = x[1:,:,:] 66 | # print(cls_.size()) 67 | x = x.permute(1,2,0) # nt,c,hw 68 | n_batch = nt // n_segment 69 | h = int(np.sqrt(hw-1)) 70 | w = h 71 | x = x.contiguous().view(n_batch, n_segment, c, h, w) 72 | 73 | fold = c // fold_div 74 | if inplace: 75 | # Due to some out of order error when performing parallel computing. 76 | # May need to write a CUDA kernel. 
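# Added comment (not in the original source): the in-place path is intentionally
# left unimplemented here; it would mutate `x` directly, like the InplaceShift
# autograd Function defined further down in this file for the CNN variant, but the
# original comment above notes it misbehaves under parallel execution, so only the
# copy-based shift in the else-branch below is used.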
77 | raise NotImplementedError 78 | # out = InplaceShift.apply(x, fold) 79 | else: 80 | out = torch.zeros_like(x) 81 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left 82 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right 83 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift 84 | out = out.contiguous().view(nt, c, h*w) 85 | out = out.permute(2,0,1) #hw, nt, c 86 | out = torch.cat((cls_,out),dim=0) 87 | # print(out.size()) 88 | return out 89 | 90 | 91 | 92 | class TokenShift(nn.Module): 93 | def __init__(self, n_segment=3, n_div=4): 94 | super(TokenShift, self).__init__() 95 | self.n_segment = n_segment 96 | self.fold_div = n_div 97 | 98 | def forward(self, x): 99 | # n bt c 100 | n, bt, c = x.size() 101 | b = bt // self.n_segment 102 | x = x.permute(1, 0, 2).contiguous().view(b, self.n_segment, n, c) 103 | 104 | fold = c // self.fold_div 105 | out = torch.zeros_like(x) 106 | out[:, :-1, 0, :fold] = x[:, 1:, 0, :fold] # shift left 107 | out[:, 1:, 0, fold:2*fold] = x[:,:-1:, 0, fold:2*fold] # shift right 108 | 109 | out[:, :, 1:, :2*fold] = x[:, :, 1:, :2*fold] # not shift 110 | out[:, :, :, 2*fold:] = x[:, :, :, 2*fold:] # not shift 111 | 112 | out = out.view(bt, n, c).permute(1, 0, 2).contiguous() 113 | 114 | return out 115 | 116 | def make_tokenshift(net, n_segment, n_div=4, locations_list=[]): 117 | for idx, block in enumerate(net.transformer.resblocks): 118 | if idx in locations_list: 119 | net.transformer.resblocks[idx].control_point = TokenShift( 120 | n_segment=n_segment, 121 | n_div=n_div, 122 | ) 123 | 124 | 125 | class TokenT1D(nn.Module): 126 | def __init__(self, in_channels, n_segment=3, n_div=4, mode='shift'): 127 | super(TokenT1D, self).__init__() 128 | self.input_channels = in_channels 129 | self.n_segment = n_segment 130 | self.fold_div = n_div 131 | self.fold = self.input_channels // self.fold_div 132 | self.conv = nn.Conv1d(self.fold_div*self.fold, self.fold_div*self.fold, 133 | kernel_size=3, padding=1, groups=self.fold_div*self.fold, 134 | bias=False) 135 | 136 | if mode == 'shift': 137 | self.conv.weight.requires_grad = True 138 | self.conv.weight.data.zero_() 139 | self.conv.weight.data[:self.fold, 0, 2] = 1 # shift left 140 | self.conv.weight.data[self.fold: 2 * self.fold, 0, 0] = 1 # shift right 141 | if 2*self.fold < self.input_channels: 142 | self.conv.weight.data[2 * self.fold:, 0, 1] = 1 # fixed 143 | elif mode == 'fixed': 144 | self.conv.weight.requires_grad = True 145 | self.conv.weight.data.zero_() 146 | self.conv.weight.data[:, 0, 1] = 1 # fixed 147 | elif mode == 'norm': 148 | self.conv.weight.requires_grad = True 149 | 150 | def forward(self, x): 151 | # n bt c 152 | n, bt, c = x.size() 153 | b = bt // self.n_segment 154 | x = x.permute(1, 0, 2).contiguous().view(b, self.n_segment, n, c) 155 | x = x.permute(0, 2, 3, 1).contiguous() # b, n, c, t 156 | out = torch.zeros_like(x) 157 | out[:, 0] = self.conv(x[:, 0]) 158 | out[:, 1:] = x[:, 1:] 159 | out = out.permute(1, 0, 3, 2).contiguous().view(n, bt, c) 160 | return out 161 | 162 | def make_tokenT1D(net, n_segment, n_div=4, locations_list=[]): 163 | for idx, block in enumerate(net.transformer.resblocks): 164 | if idx in locations_list: 165 | block.control_point = TokenT1D( 166 | in_channels=block.control_point.inplanes, 167 | n_segment=n_segment, 168 | n_div=n_div, 169 | ) 170 | 171 | class InplaceShift(torch.autograd.Function): 172 | # Special thanks to @raoyongming for the help to this function 173 | @staticmethod 174 | def forward(ctx, input, fold): 175 | # not support higher 
order gradient 176 | # input = input.detach_() 177 | ctx.fold_ = fold 178 | n, t, c, h, w = input.size() 179 | buffer = input.data.new(n, t, fold, h, w).zero_() 180 | buffer[:, :-1] = input.data[:, 1:, :fold] 181 | input.data[:, :, :fold] = buffer 182 | buffer.zero_() 183 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold] 184 | input.data[:, :, fold: 2 * fold] = buffer 185 | return input 186 | 187 | @staticmethod 188 | def backward(ctx, grad_output): 189 | # grad_output = grad_output.detach_() 190 | fold = ctx.fold_ 191 | n, t, c, h, w = grad_output.size() 192 | buffer = grad_output.data.new(n, t, fold, h, w).zero_() 193 | buffer[:, 1:] = grad_output.data[:, :-1, :fold] 194 | grad_output.data[:, :, :fold] = buffer 195 | buffer.zero_() 196 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold] 197 | grad_output.data[:, :, fold: 2 * fold] = buffer 198 | return grad_output, None 199 | 200 | 201 | class TemporalPool(nn.Module): 202 | def __init__(self, net, n_segment): 203 | super(TemporalPool, self).__init__() 204 | self.net = net 205 | self.n_segment = n_segment 206 | 207 | def forward(self, x): 208 | x = self.temporal_pool(x, n_segment=self.n_segment) 209 | return self.net(x) 210 | 211 | @staticmethod 212 | def temporal_pool(x, n_segment): 213 | nt, c, h, w = x.size() 214 | n_batch = nt // n_segment 215 | x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w 216 | x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0)) 217 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w) 218 | return x 219 | 220 | def make_temporal_shift_vit(net, n_segment, n_div=8, place='block', temporal_pool=False): 221 | if temporal_pool: 222 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2] 223 | else: 224 | n_segment_list = [n_segment] * 4 225 | assert n_segment_list[-1] > 0 226 | print('=> n_segment per stage: {}'.format(n_segment_list)) 227 | 228 | if isinstance(net, VisualTransformer): 229 | if place == 'block': 230 | def make_block_temporal(stage, this_segment): 231 | blocks = list(stage.children()) 232 | print('=> Processing stage with {} blocks'.format(len(blocks))) 233 | for i, b in enumerate(blocks): 234 | blocks[i] = TemporalShift_VIT(b, n_segment=this_segment, n_div=n_div) 235 | return nn.Sequential(*(blocks)) 236 | 237 | net.transformer.resblocks = make_block_temporal(net.transformer.resblocks, n_segment_list[0]) 238 | 239 | # net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) 240 | # net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) 241 | # net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) 242 | 243 | 244 | else: 245 | raise NotImplementedError(place) 246 | 247 | 248 | 249 | 250 | def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False): 251 | if temporal_pool: 252 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2] 253 | else: 254 | n_segment_list = [n_segment] * 4 255 | assert n_segment_list[-1] > 0 256 | print('=> n_segment per stage: {}'.format(n_segment_list)) 257 | 258 | if isinstance(net, ModifiedResNet): 259 | if place == 'block': 260 | def make_block_temporal(stage, this_segment): 261 | blocks = list(stage.children()) 262 | print('=> Processing stage with {} blocks'.format(len(blocks))) 263 | for i, b in enumerate(blocks): 264 | blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div) 265 | return nn.Sequential(*(blocks)) 266 | 267 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) 268 | 
net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) 269 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) 270 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) 271 | 272 | elif 'blockres' in place: 273 | n_round = 1 274 | if len(list(net.layer3.children())) >= 23: 275 | n_round = 2 276 | print('=> Using n_round {} to insert temporal shift'.format(n_round)) 277 | 278 | def make_block_temporal(stage, this_segment): 279 | blocks = list(stage.children()) 280 | print('=> Processing stage with {} blocks residual'.format(len(blocks))) 281 | for i, b in enumerate(blocks): 282 | if i % n_round == 0: 283 | blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div) 284 | return nn.Sequential(*blocks) 285 | 286 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) 287 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) 288 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) 289 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) 290 | else: 291 | raise NotImplementedError(place) 292 | 293 | 294 | 295 | def make_temporal_pool(net, n_segment): 296 | import torchvision 297 | if isinstance(net, torchvision.models.ResNet): 298 | print('=> Injecting nonlocal pooling') 299 | net.layer2 = TemporalPool(net.layer2, n_segment) 300 | else: 301 | raise NotImplementedError 302 | 303 | 304 | if __name__ == '__main__': 305 | # test inplace shift v.s. vanilla shift 306 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False) 307 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True) 308 | 309 | print('=> Testing CPU...') 310 | # test forward 311 | with torch.no_grad(): 312 | for i in range(10): 313 | x = torch.rand(2 * 8, 3, 224, 224) 314 | y1 = tsm1(x) 315 | y2 = tsm2(x) 316 | assert torch.norm(y1 - y2).item() < 1e-5 317 | 318 | # test backward 319 | with torch.enable_grad(): 320 | for i in range(10): 321 | x1 = torch.rand(2 * 8, 3, 224, 224) 322 | x1.requires_grad_() 323 | x2 = x1.clone() 324 | y1 = tsm1(x1) 325 | y2 = tsm2(x2) 326 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 327 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 328 | assert torch.norm(grad1 - grad2).item() < 1e-5 329 | 330 | print('=> Testing GPU...') 331 | tsm1.cuda() 332 | tsm2.cuda() 333 | # test forward 334 | with torch.no_grad(): 335 | for i in range(10): 336 | x = torch.rand(2 * 8, 3, 224, 224).cuda() 337 | y1 = tsm1(x) 338 | y2 = tsm2(x) 339 | assert torch.norm(y1 - y2).item() < 1e-5 340 | 341 | # test backward 342 | with torch.enable_grad(): 343 | for i in range(10): 344 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda() 345 | x1.requires_grad_() 346 | x2 = x1.clone() 347 | y1 = tsm1(x1) 348 | y2 = tsm2(x2) 349 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] 350 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] 351 | assert torch.norm(grad1 - grad2).item() < 1e-5 352 | print('Test passed.') -------------------------------------------------------------------------------- /modules/text_prompt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import clip 3 | 4 | def text_prompt(data): 5 | # text_aug = ['{}'] 6 | text_aug = ['a video of a person {}.'] 7 | 8 | # Kinetics 9 | # text_aug = [ 10 | # 'a photo of a person {}.', 11 | # 'a photo of {}.', 12 | # 'a photo of a person using {}.', 13 | # 'a photo of a person doing {}.', 14 | # 'a photo of a person during {}.', 15 | # 'a photo of a person performing 
{}.', 16 | # 'a photo of a person practicing {}.', 17 | # 'a video of {}.', 18 | # 'a video of a person {}.', 19 | # 'a video of a person using {}.', 20 | # 'a video of a person doing {}.', 21 | # 'a video of a person during {}.', 22 | # 'a video of a person performing {}.', 23 | # 'a video of a person practicing {}.', 24 | # 'a example of {}.', 25 | # 'a example of a person {}.', 26 | # 'a example of a person using {}.', 27 | # 'a example of a person doing {}.', 28 | # 'a example of a person during {}.', 29 | # 'a example of a person performing {}.', 30 | # 'a example of a person practicing {}.', 31 | # 'a demonstration of {}.', 32 | # 'a demonstration of a person {}.', 33 | # 'a demonstration of a person using {}.', 34 | # 'a demonstration of a person doing {}.', 35 | # 'a demonstration of a person during {}.', 36 | # 'a demonstration of a person performing {}.', 37 | # 'a demonstration of a person practicing {}.', 38 | # ] 39 | 40 | text_dict = {} 41 | num_text_aug = len(text_aug) 42 | 43 | for ii, txt in enumerate(text_aug): 44 | text_dict[ii] = torch.cat([clip.tokenize(txt.format(c)) for i, c in data.classes]) 45 | 46 | classes = text_dict[0] 47 | 48 | return classes, num_text_aug, text_dict -------------------------------------------------------------------------------- /modules/video_clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from collections import OrderedDict 4 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 5 | 6 | 7 | class LayerNorm(nn.Module): 8 | def __init__(self, hidden_size, eps=1e-12): 9 | """Construct a layernorm module in the TF style (epsilon inside the square root). 10 | """ 11 | super(LayerNorm, self).__init__() 12 | self.weight = nn.Parameter(torch.ones(hidden_size)) 13 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 14 | self.variance_epsilon = eps 15 | 16 | def forward(self, x): 17 | u = x.mean(-1, keepdim=True) 18 | s = (x - u).pow(2).mean(-1, keepdim=True) 19 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 20 | return self.weight * x + self.bias 21 | 22 | class QuickGELU(nn.Module): 23 | def forward(self, x: torch.Tensor): 24 | return x * torch.sigmoid(1.702 * x) 25 | 26 | 27 | class ResidualAttentionBlock(nn.Module): 28 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 29 | super().__init__() 30 | 31 | self.attn = nn.MultiheadAttention(d_model, n_head) 32 | self.ln_1 = LayerNorm(d_model) 33 | self.mlp = nn.Sequential(OrderedDict([ 34 | ("c_fc", nn.Linear(d_model, d_model * 4)), 35 | ("gelu", QuickGELU()), 36 | ("c_proj", nn.Linear(d_model * 4, d_model)) 37 | ])) 38 | self.ln_2 = LayerNorm(d_model) 39 | self.attn_mask = attn_mask 40 | 41 | def attention(self, x: torch.Tensor): 42 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 43 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 44 | 45 | def forward(self, x: torch.Tensor): 46 | x = x + self.attention(self.ln_1(x)) 47 | x = x + self.mlp(self.ln_2(x)) 48 | return x 49 | 50 | 51 | 52 | class TemporalTransformer(nn.Module): 53 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 54 | super().__init__() 55 | self.width = width 56 | self.layers = layers 57 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 58 | 59 | def forward(self, x: torch.Tensor): 60 | return 
self.resblocks((x)) 61 | 62 | 63 | class video_header(nn.Module): 64 | def __init__(self, vid_head, clip_state_dict): 65 | super().__init__() 66 | self.vid_header = vid_head 67 | assert vid_head in ["None", "Transf"] 68 | 69 | if self.vid_header == "Transf": 70 | embed_dim = clip_state_dict["text_projection"].shape[1] 71 | 72 | context_length = clip_state_dict["positional_embedding"].shape[0] 73 | vocab_size = clip_state_dict["token_embedding.weight"].shape[0] 74 | transformer_width = clip_state_dict["ln_final.weight"].shape[0] 75 | transformer_heads = transformer_width // 64 76 | 77 | transformer_layers = len( 78 | set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"transformer.resblocks"))) 79 | 80 | self.frame_position_embeddings = nn.Embedding(context_length, embed_dim) 81 | 82 | self.transformer = TemporalTransformer(width=embed_dim, layers=6, heads=transformer_heads) 83 | print('layer=6') 84 | 85 | self.apply(self.init_weights) 86 | 87 | def init_weights(self, module): 88 | """ Initialize the weights. 89 | """ 90 | if isinstance(module, (nn.Linear, nn.Embedding)): 91 | # Slightly different from the TF version which uses truncated_normal for initialization 92 | # cf https://github.com/pytorch/pytorch/pull/5617 93 | module.weight.data.normal_(mean=0.0, std=0.02) 94 | elif isinstance(module, LayerNorm): 95 | if 'beta' in dir(module) and 'gamma' in dir(module): 96 | module.beta.data.zero_() 97 | module.gamma.data.fill_(1.0) 98 | else: 99 | module.bias.data.zero_() 100 | module.weight.data.fill_(1.0) 101 | if isinstance(module, nn.Linear) and module.bias is not None: 102 | module.bias.data.zero_() 103 | 104 | def forward(self, x): 105 | b, t, c = x.size() 106 | x = x.contiguous() 107 | if self.vid_header == "None": 108 | pass 109 | 110 | elif self.vid_header == "Transf": 111 | x_original = x 112 | seq_length = t 113 | position_ids = torch.arange(seq_length, dtype=torch.long, device=x.device) 114 | position_ids = position_ids.unsqueeze(0).expand(x.size(0), -1) 115 | frame_position_embeddings = self.frame_position_embeddings(position_ids) 116 | x = x + frame_position_embeddings 117 | 118 | x = x.permute(1, 0, 2) # NLD -> LND 119 | x = self.transformer(x) 120 | x = x.permute(1, 0, 2) # LND -> NLD 121 | x = x.type(x_original.dtype) + x_original 122 | 123 | else: 124 | raise ValueError('Unknown temporal modeling header: {}'.format(self.vid_header)) 125 | return x.mean(dim=1, keepdim=False) 126 | 127 | 128 | class VideoCLIP(nn.Module): 129 | def __init__(self, clip_model, video_header, n_seg) : 130 | super(VideoCLIP, self).__init__() 131 | self.visual = clip_model.visual 132 | self.fusion_model = video_header 133 | self.n_seg = n_seg 134 | self.logit_scale = clip_model.logit_scale 135 | 136 | def forward(self, image, text_emb): 137 | image_emb = self.encode_image(image) 138 | image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True) 139 | text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True) 140 | logit_scale = self.logit_scale.exp() 141 | logits = logit_scale * image_emb @ text_emb.t() 142 | return logits 143 | 144 | def encode_image(self, image): 145 | bt = image.size(0) 146 | b = bt // self.n_seg 147 | image_emb = self.visual(image) 148 | if image_emb.size(0) == b: # joint 149 | return image_emb 150 | else: 151 | image_emb = image_emb.view(b, self.n_seg, -1) 152 | image_emb = self.fusion_model(image_emb) 153 | return image_emb -------------------------------------------------------------------------------- /scripts/run_test.sh: 
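# Added note on modules/video_clip.py above (illustrative sketch, not part of the
# original repository): the shape flow of VideoCLIP at test time, using random
# stand-in tensors. B, T and D are hypothetical sizes; a plain mean over frames
# stands in for the "Transf" fusion head, and the fixed scale of 100.0 mirrors the
# test scripts rather than the learned logit_scale.
import torch

B, T, D = 2, 8, 512                                   # batch, sampled frames, CLIP embedding dim
frame_feats = torch.randn(B * T, D)                   # clip_model.visual output, one row per frame
text_feats = torch.randn(400, D)                      # clip_model.encode_text output, one row per class

video_feats = frame_feats.view(B, T, D).mean(dim=1)   # fusion head collapses T -> 1 (stand-in)
video_feats = video_feats / video_feats.norm(dim=-1, keepdim=True)
text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
logits = 100.0 * video_feats @ text_feats.t()         # (B, 400) video-to-class similarities
pred = logits.argmax(dim=-1)                          # predicted class index per video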
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | if [ -f $1 ]; then 3 | config=$1 4 | else 5 | echo "need a config file" 6 | exit 7 | fi 8 | 9 | weight=$2 10 | 11 | python -m torch.distributed.launch --master_port 1238 --nproc_per_node=8 \ 12 | test.py --config ${config} --weights ${weight} ${@:3} -------------------------------------------------------------------------------- /scripts/run_test_zeroshot.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | if [ -f $1 ]; then 3 | config=$1 4 | else 5 | echo "need a config file" 6 | exit 7 | fi 8 | 9 | weight=$2 10 | 11 | python -m torch.distributed.launch --master_port 1238 --nproc_per_node=8 \ 12 | test_zeroshot.py --config ${config} --weights ${weight} ${@:3} -------------------------------------------------------------------------------- /scripts/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | if [ -f $1 ]; then 3 | config=$1 4 | else 5 | echo "need a config file" 6 | exit 7 | fi 8 | 9 | now=$(date +"%Y%m%d_%H%M%S") 10 | python -m torch.distributed.launch --master_port 1237 --nproc_per_node=8 \ 11 | train.py --config ${config} --log_time $now -------------------------------------------------------------------------------- /scripts/run_train_multinodes.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | #--- Multi-nodes training hyperparams --- 4 | nnodes=2 5 | master_addr="10.67.212.21" 6 | 7 | # Note: 8 | # 0. You need to set the master ip address according to your own machines. 9 | # 1. You'd better to scale the learning rate when you use more gpus. 10 | # 2. 
Command: sh scripts/run_train_multinodes.sh node_rank 11 | ############################################# 12 | if [ -f $1 ]; then 13 | config=$1 14 | else 15 | echo "need a config file" 16 | exit 17 | fi 18 | 19 | python -m torch.distributed.launch --master_port 12355 --nproc_per_node=8 \ 20 | --nnodes=${nnodes} --node_rank=$2 \ 21 | --master_addr=${master_addr} \ 22 | train.py --config ${config} ${@:3} 23 | -------------------------------------------------------------------------------- /scripts/run_train_nce.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | if [ -f $1 ]; then 3 | config=$1 4 | else 5 | echo "need a config file" 6 | exit 7 | fi 8 | 9 | now=$(date +"%Y%m%d_%H%M%S") 10 | python -m torch.distributed.launch --master_port 1237 --nproc_per_node=4 \ 11 | train_nce.py --config ${config} --log_time $now -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/Text4Vis/d61b34d0208a03ce6146edcc51033b2a040cb249/teaser.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader 7 | from torch.nn.parallel import DistributedDataParallel 8 | import torch.distributed as dist 9 | import torch.backends.cudnn as cudnn 10 | import torchvision 11 | import time 12 | from utils.utils import init_distributed_mode, AverageMeter, reduce_tensor, accuracy 13 | import clip 14 | 15 | import yaml 16 | from dotmap import DotMap 17 | 18 | from datasets import Video_dataset 19 | from datasets.transforms import GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor, GroupNormalize, GroupOverSample, GroupFullResSample 20 | from modules.video_clip import video_header, VideoCLIP 21 | from modules.text_prompt import text_prompt 22 | 23 | 24 | 25 | def get_parser(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--config', type=str, help='global config file') 28 | parser.add_argument('--weights', type=str, default=None) 29 | parser.add_argument('--dist_url', default='env://', 30 | help='url used to set up distributed training') 31 | parser.add_argument('--world_size', default=1, type=int, 32 | help='number of distributed processes') 33 | parser.add_argument("--local_rank", type=int, 34 | help='local rank for DistributedDataParallel') 35 | parser.add_argument( 36 | "--precision", 37 | choices=["amp", "fp16", "fp32"], 38 | default="amp", 39 | help="Floating point precition." 
40 | ) 41 | parser.add_argument('--test_crops', type=int, default=1) 42 | parser.add_argument('--test_clips', type=int, default=1) 43 | parser.add_argument('--dense', default=False, action="store_true", 44 | help='use multiple clips for test') 45 | args = parser.parse_args() 46 | return args 47 | 48 | def update_dict(dict): 49 | new_dict = {} 50 | for k, v in dict.items(): 51 | new_dict[k.replace('module.', '')] = v 52 | return new_dict 53 | 54 | 55 | def main(args): 56 | init_distributed_mode(args) 57 | 58 | with open(args.config, 'r') as f: 59 | config = yaml.load(f, Loader=yaml.FullLoader) 60 | 61 | config = DotMap(config) 62 | 63 | device = "cpu" 64 | if torch.cuda.is_available(): 65 | device = "cuda" 66 | cudnn.benchmark = True 67 | 68 | # get fp16 model and weight 69 | model, clip_state_dict = clip.load( 70 | config.network.arch, 71 | device='cpu',jit=False, 72 | internal_modeling=config.network.tm, 73 | T=config.data.num_segments, 74 | dropout=config.network.drop_out, 75 | emb_dropout=config.network.emb_dropout, 76 | pretrain=config.network.init, 77 | joint_st= config.network.joint_st) # Must set jit=False for training ViT-B/32 78 | 79 | video_head = video_header( 80 | config.network.sim_header, 81 | clip_state_dict) 82 | 83 | if args.precision == "amp" or args.precision == "fp32": 84 | model = model.float() 85 | 86 | 87 | input_mean = [0.48145466, 0.4578275, 0.40821073] 88 | input_std = [0.26862954, 0.26130258, 0.27577711] 89 | 90 | # rescale size 91 | if 'something' in config.data.dataset: 92 | scale_size = (240, 320) 93 | else: 94 | scale_size = 256 if config.data.input_size == 224 else config.data.input_size 95 | 96 | # crop size 97 | input_size = config.data.input_size 98 | 99 | # control the spatial crop 100 | if args.test_crops == 1: # one crop 101 | cropping = torchvision.transforms.Compose([ 102 | GroupScale(scale_size), 103 | GroupCenterCrop(input_size), 104 | ]) 105 | elif args.test_crops == 3: # do not flip, so only 3 crops (left right center) 106 | cropping = torchvision.transforms.Compose([ 107 | GroupFullResSample( 108 | crop_size=input_size, 109 | scale_size=scale_size, 110 | flip=False) 111 | ]) 112 | elif args.test_crops == 5: # do not flip, so only 5 crops 113 | cropping = torchvision.transforms.Compose([ 114 | GroupOverSample( 115 | crop_size=input_size, 116 | scale_size=scale_size, 117 | flip=False) 118 | ]) 119 | elif args.test_crops == 10: 120 | cropping = torchvision.transforms.Compose([ 121 | GroupOverSample( 122 | crop_size=input_size, 123 | scale_size=scale_size, 124 | ) 125 | ]) 126 | else: 127 | raise ValueError("Only 1, 3, 5, 10 crops are supported while we got {}".format(args.test_crops)) 128 | 129 | 130 | val_data = Video_dataset( 131 | config.data.val_root, config.data.val_list, config.data.label_list, 132 | random_shift=False, num_segments=config.data.num_segments, 133 | modality=config.data.modality, 134 | image_tmpl=config.data.image_tmpl, 135 | test_mode=True, 136 | transform=torchvision.transforms.Compose([ 137 | cropping, 138 | Stack(roll=False), 139 | ToTorchFormatTensor(div=True), 140 | GroupNormalize(input_mean,input_std), 141 | ]), 142 | dense_sample=args.dense, 143 | test_clips=args.test_clips) 144 | 145 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_data) 146 | val_loader = DataLoader(val_data, 147 | batch_size=config.data.batch_size,num_workers=config.data.workers, 148 | sampler=val_sampler, pin_memory=True, drop_last=False) 149 | 150 | 151 | model_full = VideoCLIP(model, video_head, config.data.num_segments) 152 | 
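# Added comments (not in the original source): the checkpoint loaded below is
# expected to be a dict with 'epoch' and 'model_state_dict' entries, and
# update_dict() strips the 'module.' prefix that DistributedDataParallel adds to
# parameter names, so weights saved from a DDP-wrapped model load into the plain
# VideoCLIP module before it is wrapped in DDP again.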
153 | if os.path.isfile(args.weights): 154 | checkpoint = torch.load(args.weights, map_location='cpu') 155 | if dist.get_rank() == 0: 156 | print('load model: epoch {}'.format(checkpoint['epoch'])) 157 | 158 | model_full.load_state_dict(update_dict(checkpoint['model_state_dict'])) 159 | del checkpoint 160 | 161 | if args.distributed: 162 | model_full = DistributedDataParallel(model_full.cuda(), device_ids=[args.gpu], find_unused_parameters=True) 163 | 164 | 165 | classes, num_text_aug, text_dict = text_prompt(val_data) 166 | n_class = text_dict[0].size(0) 167 | #### generate classes feature ###### 168 | class_feats_file = 'text_feats_{}_{}.pt'.format(config['data']['dataset'], config['network']['arch']).replace('/','') 169 | if os.path.isfile(class_feats_file): 170 | print('=> load classes features from {}'.format(class_feats_file)) 171 | classes_features = torch.load(class_feats_file) 172 | else: 173 | model.eval() 174 | with torch.no_grad(): 175 | classes_features = model.encode_text(classes) # 400 512 176 | # if dist.get_rank() == 0: 177 | # torch.save(classes_features.cpu(), class_feats_file) 178 | 179 | 180 | prec1 = validate( 181 | val_loader, device, 182 | model_full, config, classes_features, args.test_crops, args.test_clips) 183 | 184 | return 185 | 186 | 187 | def validate_rumtime(val_loader, device, model, config, text_features, test_crops, test_clips): 188 | 189 | model.eval() 190 | with torch.no_grad(): 191 | batch_size = config.data.batch_size 192 | image = torch.rand(batch_size, config.data.num_segments, 3, config.data.input_size, config.data.input_size) 193 | b, t, c, h, w = image.size() 194 | proc_start_time = time.time() 195 | 196 | for i in range(2000): 197 | image_input = image.to(device).view(-1, c, h, w) 198 | image_features = model.module.encode_image(image_input) 199 | cnt_time = time.time() - proc_start_time 200 | 201 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0: 202 | runtime = float(cnt_time) / (i+1) / (batch_size) 203 | print( 204 | ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t'.format( 205 | i, 2000, runtime=runtime))) 206 | return cnt_time 207 | 208 | 209 | def validate(val_loader, device, model, config, text_features, test_crops, test_clips): 210 | 211 | top1 = AverageMeter() 212 | top5 = AverageMeter() 213 | model.eval() 214 | proc_start_time = time.time() 215 | 216 | with torch.no_grad(): 217 | n_class = text_features.size(0) 218 | 219 | for i, (image, class_id) in enumerate(val_loader): 220 | batch_size = class_id.numel() 221 | num_crop = test_crops 222 | 223 | num_crop *= test_clips # 4 clips for testing when using dense sample 224 | 225 | class_id = class_id.to(device) 226 | text_features = text_features.to(device) 227 | n_seg = config.data.num_segments 228 | image = image.view((-1, n_seg, 3) + image.size()[-2:]) 229 | b, t, c, h, w = image.size() 230 | image_input = image.to(device).view(-1, c, h, w) 231 | image_features = model.module.encode_image(image_input) 232 | cnt_time = time.time() - proc_start_time 233 | image_features = image_features.reshape(batch_size, num_crop, -1).mean(1) # bs dim 234 | 235 | image_features /= image_features.norm(dim=-1, keepdim=True) 236 | text_features /= text_features.norm(dim=-1, keepdim=True) 237 | similarity = (100.0 * image_features @ text_features.T) 238 | similarity = similarity.view(batch_size, -1, n_class).softmax(dim=-1) 239 | similarity = similarity.mean(dim=1, keepdim=False) # bs 200 240 | 241 | prec = accuracy(similarity, class_id, topk=(1, 5)) 242 | prec1 = reduce_tensor(prec[0]) 
243 | prec5 = reduce_tensor(prec[1]) 244 | 245 | top1.update(prec1.item(), class_id.size(0)) 246 | top5.update(prec5.item(), class_id.size(0)) 247 | 248 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0: 249 | runtime = float(cnt_time) / (i+1) / (batch_size * dist.get_world_size()) 250 | print( 251 | ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t' 252 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 253 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 254 | i, len(val_loader), runtime=runtime, top1=top1, top5=top5))) 255 | 256 | if dist.get_rank() == 0: 257 | print('-----Evaluation is finished------') 258 | print('Overall Prec@1 {:.03f}% Prec@5 {:.03f}%'.format(top1.avg, top5.avg)) 259 | 260 | return top1.avg 261 | 262 | 263 | 264 | if __name__ == '__main__': 265 | args = get_parser() 266 | main(args) 267 | 268 | -------------------------------------------------------------------------------- /test_anet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader 7 | from torch.nn.parallel import DistributedDataParallel 8 | import torch.distributed as dist 9 | import torch.backends.cudnn as cudnn 10 | import torchvision 11 | import time 12 | from utils.utils import init_distributed_mode, AverageMeter, reduce_tensor, accuracy, mean_average_precision 13 | import clip 14 | 15 | import yaml 16 | from dotmap import DotMap 17 | 18 | from datasets.transforms import GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor, GroupNormalize, GroupOverSample, GroupFullResSample 19 | from modules.video_clip import video_header, VideoCLIP 20 | from modules.text_prompt import text_prompt 21 | 22 | 23 | 24 | def get_parser(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--config', type=str, help='global config file') 27 | parser.add_argument('--weights', type=str, default=None) 28 | parser.add_argument('--dist_url', default='env://', 29 | help='url used to set up distributed training') 30 | parser.add_argument('--world_size', default=1, type=int, 31 | help='number of distributed processes') 32 | parser.add_argument("--local_rank", type=int, 33 | help='local rank for DistributedDataParallel') 34 | parser.add_argument( 35 | "--precision", 36 | choices=["amp", "fp16", "fp32"], 37 | default="amp", 38 | help="Floating point precition." 
39 | ) 40 | parser.add_argument('--test_crops', type=int, default=1) 41 | parser.add_argument('--test_clips', type=int, default=1) 42 | parser.add_argument('--dense', default=False, action="store_true", 43 | help='use multiple clips for test') 44 | args = parser.parse_args() 45 | return args 46 | 47 | def update_dict(dict): 48 | new_dict = {} 49 | for k, v in dict.items(): 50 | new_dict[k.replace('module.', '')] = v 51 | return new_dict 52 | 53 | 54 | def main(args): 55 | init_distributed_mode(args) 56 | 57 | with open(args.config, 'r') as f: 58 | config = yaml.load(f, Loader=yaml.FullLoader) 59 | 60 | if 'something' in config['data']['dataset']: 61 | from datasets.sth import Video_dataset 62 | else: 63 | from datasets.kinetics import Video_dataset 64 | 65 | config = DotMap(config) 66 | 67 | device = "cpu" 68 | if torch.cuda.is_available(): 69 | device = "cuda" 70 | cudnn.benchmark = True 71 | 72 | # get fp16 model and weight 73 | model, clip_state_dict = clip.load( 74 | config.network.arch, 75 | device='cpu',jit=False, 76 | internal_modeling=config.network.tm, 77 | T=config.data.num_segments, 78 | dropout=config.network.drop_out, 79 | emb_dropout=config.network.emb_dropout, 80 | pretrain=config.network.init, 81 | joint_st= config.network.joint_st) # Must set jit=False for training ViT-B/32 82 | 83 | video_head = video_header( 84 | config.network.sim_header, 85 | clip_state_dict) 86 | 87 | if args.precision == "amp" or args.precision == "fp32": 88 | model = model.float() 89 | 90 | 91 | input_mean = [0.48145466, 0.4578275, 0.40821073] 92 | input_std = [0.26862954, 0.26130258, 0.27577711] 93 | 94 | # rescale size 95 | if 'something' in config.data.dataset: 96 | scale_size = (240, 320) 97 | else: 98 | scale_size = 256 if config.data.input_size == 224 else config.data.input_size 99 | 100 | # crop size 101 | input_size = config.data.input_size 102 | 103 | # control the spatial crop 104 | if args.test_crops == 1: # one crop 105 | cropping = torchvision.transforms.Compose([ 106 | GroupScale(scale_size), 107 | GroupCenterCrop(input_size), 108 | ]) 109 | elif args.test_crops == 3: # do not flip, so only 3 crops (left right center) 110 | cropping = torchvision.transforms.Compose([ 111 | GroupFullResSample( 112 | crop_size=input_size, 113 | scale_size=scale_size, 114 | flip=False) 115 | ]) 116 | elif args.test_crops == 5: # do not flip, so only 5 crops 117 | cropping = torchvision.transforms.Compose([ 118 | GroupOverSample( 119 | crop_size=input_size, 120 | scale_size=scale_size, 121 | flip=False) 122 | ]) 123 | elif args.test_crops == 10: 124 | cropping = torchvision.transforms.Compose([ 125 | GroupOverSample( 126 | crop_size=input_size, 127 | scale_size=scale_size, 128 | ) 129 | ]) 130 | else: 131 | raise ValueError("Only 1, 3, 5, 10 crops are supported while we got {}".format(args.test_crops)) 132 | 133 | 134 | val_data = Video_dataset( 135 | config.data.val_root, config.data.val_list, config.data.label_list, 136 | random_shift=False, num_segments=config.data.num_segments, 137 | modality=config.data.modality, 138 | image_tmpl=config.data.image_tmpl, 139 | test_mode=True, 140 | transform=torchvision.transforms.Compose([ 141 | cropping, 142 | Stack(roll=False), 143 | ToTorchFormatTensor(div=True), 144 | GroupNormalize(input_mean,input_std), 145 | ]), 146 | dense_sample=args.dense, 147 | test_clips=args.test_clips) 148 | 149 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_data) 150 | val_loader = DataLoader(val_data, 151 | 
batch_size=config.data.batch_size,num_workers=config.data.workers, 152 | sampler=val_sampler, pin_memory=True, drop_last=False) 153 | 154 | 155 | model_full = VideoCLIP(model, video_head, config.data.num_segments) 156 | 157 | if os.path.isfile(args.weights): 158 | checkpoint = torch.load(args.weights, map_location='cpu') 159 | if dist.get_rank() == 0: 160 | print('load model: epoch {}'.format(checkpoint['epoch'])) 161 | 162 | model_full.load_state_dict(update_dict(checkpoint['model_state_dict'])) 163 | del checkpoint 164 | 165 | if args.distributed: 166 | model_full = DistributedDataParallel(model_full.cuda(), device_ids=[args.gpu], find_unused_parameters=True) 167 | 168 | 169 | classes, num_text_aug, text_dict = text_prompt(val_data) 170 | n_class = text_dict[0].size(0) 171 | #### generate classes feature ###### 172 | class_feats_file = 'text_feats_{}_{}.pt'.format(config['data']['dataset'], config['network']['arch']).replace('/','') 173 | if os.path.isfile(class_feats_file): 174 | print('=> load classes features from {}'.format(class_feats_file)) 175 | classes_features = torch.load(class_feats_file) 176 | else: 177 | model.eval() 178 | with torch.no_grad(): 179 | classes_features = model.encode_text(classes) # 400 512 180 | # if dist.get_rank() == 0: 181 | # torch.save(classes_features.cpu(), class_feats_file) 182 | 183 | 184 | prec1 = validate( 185 | val_loader, device, 186 | model_full, config, classes_features, args.test_crops, args.test_clips) 187 | 188 | return 189 | 190 | 191 | def validate_rumtime(val_loader, device, model, config, text_features, test_crops, test_clips): 192 | 193 | model.eval() 194 | with torch.no_grad(): 195 | batch_size = config.data.batch_size 196 | image = torch.rand(batch_size, config.data.num_segments, 3, config.data.input_size, config.data.input_size) 197 | b, t, c, h, w = image.size() 198 | proc_start_time = time.time() 199 | 200 | for i in range(2000): 201 | image_input = image.to(device).view(-1, c, h, w) 202 | image_features = model.module.encode_image(image_input) 203 | cnt_time = time.time() - proc_start_time 204 | 205 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0: 206 | runtime = float(cnt_time) / (i+1) / (batch_size) 207 | print( 208 | ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t'.format( 209 | i, 2000, runtime=runtime))) 210 | return cnt_time 211 | 212 | 213 | def validate(val_loader, device, model, config, text_features, test_crops, test_clips): 214 | 215 | top1 = AverageMeter() 216 | top5 = AverageMeter() 217 | model.eval() 218 | proc_start_time = time.time() 219 | 220 | sim_logits = [] 221 | labels = [] 222 | 223 | with torch.no_grad(): 224 | n_class = text_features.size(0) 225 | 226 | for i, (image, class_id) in enumerate(val_loader): 227 | batch_size = class_id.numel() 228 | num_crop = test_crops 229 | 230 | num_crop *= test_clips # 4 clips for testing when using dense sample 231 | 232 | class_id = class_id.to(device) 233 | text_features = text_features.to(device) 234 | n_seg = config.data.num_segments 235 | image = image.view((-1, n_seg, 3) + image.size()[-2:]) 236 | b, t, c, h, w = image.size() 237 | image_input = image.to(device).view(-1, c, h, w) 238 | image_features = model.module.encode_image(image_input) 239 | cnt_time = time.time() - proc_start_time 240 | image_features = image_features.reshape(batch_size, num_crop, -1).mean(1) # bs dim 241 | 242 | image_features /= image_features.norm(dim=-1, keepdim=True) 243 | text_features /= text_features.norm(dim=-1, keepdim=True) 244 | similarity = (100.0 * 
image_features @ text_features.T) 245 | similarity = similarity.view(batch_size, -1, n_class).softmax(dim=-1) 246 | similarity = similarity.mean(dim=1, keepdim=False) # bs 200 247 | 248 | ########## for saving 249 | sim_logits.append(concat_all_gather(similarity)) 250 | labels.append(concat_all_gather(class_id)) 251 | ########## 252 | 253 | prec = accuracy(similarity, class_id, topk=(1, 5)) 254 | prec1 = reduce_tensor(prec[0]) 255 | prec5 = reduce_tensor(prec[1]) 256 | 257 | top1.update(prec1.item(), class_id.size(0)) 258 | top5.update(prec5.item(), class_id.size(0)) 259 | 260 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0: 261 | runtime = float(cnt_time) / (i+1) / (batch_size * dist.get_world_size()) 262 | print( 263 | ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t' 264 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 265 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 266 | i, len(val_loader), runtime=runtime, top1=top1, top5=top5))) 267 | 268 | if dist.get_rank() == 0: 269 | print('-----Evaluation is finished------') 270 | print('Overall Prec@1 {:.03f}% Prec@5 {:.03f}%'.format(top1.avg, top5.avg)) 271 | 272 | sim, gt = sim_logits[0], labels[0] 273 | for i in range(1, len(sim_logits)): 274 | sim = torch.cat((sim, sim_logits[i]), 0) 275 | gt = torch.cat((gt, labels[i]), 0) 276 | 277 | if dist.get_rank() == 0: 278 | mAP = mean_average_precision(sim, gt) 279 | print('Overall mAP: {:.03f}%'.format(mAP[1].item())) 280 | 281 | return top1.avg 282 | 283 | 284 | # utils 285 | @torch.no_grad() 286 | def concat_all_gather(tensor): 287 | """ 288 | Performs all_gather operation on the provided tensors. 289 | *** Warning ***: torch.distributed.all_gather has no gradient. 290 | """ 291 | tensors_gather = [torch.ones_like(tensor) 292 | for _ in range(torch.distributed.get_world_size())] 293 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 294 | 295 | output = torch.cat(tensors_gather, dim=0) 296 | return output.cpu() 297 | 298 | 299 | 300 | if __name__ == '__main__': 301 | args = get_parser() 302 | main(args) 303 | 304 | -------------------------------------------------------------------------------- /test_zeroshot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader 7 | from torch.nn.parallel import DistributedDataParallel 8 | import torch.distributed as dist 9 | import torch.backends.cudnn as cudnn 10 | import torchvision 11 | import time 12 | from utils.utils import init_distributed_mode, AverageMeter, reduce_tensor, accuracy 13 | import clip 14 | 15 | import yaml 16 | from dotmap import DotMap 17 | 18 | from datasets import Video_dataset 19 | from datasets.transforms import GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor, GroupNormalize, GroupOverSample, GroupFullResSample 20 | from modules.video_clip import video_header, VideoCLIP 21 | from modules.text_prompt import text_prompt 22 | 23 | 24 | 25 | def get_parser(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--config', type=str, help='global config file') 28 | parser.add_argument('--weights', type=str, default=None) 29 | parser.add_argument('--dist_url', default='env://', 30 | help='url used to set up distributed training') 31 | parser.add_argument('--world_size', default=1, type=int, 32 | help='number of distributed processes') 33 | parser.add_argument("--local_rank", type=int, 34 | help='local 
rank for DistributedDataParallel') 35 | parser.add_argument( 36 | "--precision", 37 | choices=["amp", "fp16", "fp32"], 38 | default="amp", 39 | help="Floating point precition." 40 | ) 41 | parser.add_argument('--test_crops', type=int, default=1) 42 | parser.add_argument('--test_clips', type=int, default=1) 43 | parser.add_argument('--dense', default=False, action="store_true", 44 | help='use multiple clips for test') 45 | args = parser.parse_args() 46 | return args 47 | 48 | def update_dict(dict): 49 | new_dict = {} 50 | for k, v in dict.items(): 51 | new_dict[k.replace('module.', '')] = v 52 | return new_dict 53 | 54 | 55 | def main(args): 56 | init_distributed_mode(args) 57 | 58 | with open(args.config, 'r') as f: 59 | config = yaml.load(f, Loader=yaml.FullLoader) 60 | 61 | config = DotMap(config) 62 | 63 | device = "cpu" 64 | if torch.cuda.is_available(): 65 | device = "cuda" 66 | cudnn.benchmark = True 67 | 68 | # get fp16 model and weight 69 | model, clip_state_dict = clip.load( 70 | config.network.arch, 71 | device='cpu',jit=False, 72 | internal_modeling=config.network.tm, 73 | T=config.data.num_segments, 74 | dropout=config.network.drop_out, 75 | emb_dropout=config.network.emb_dropout, 76 | pretrain=config.network.init, 77 | joint_st= config.network.joint_st) # Must set jit=False for training ViT-B/32 78 | 79 | video_head = video_header( 80 | config.network.sim_header, 81 | clip_state_dict) 82 | 83 | if args.precision == "amp" or args.precision == "fp32": 84 | model = model.float() 85 | 86 | 87 | input_mean = [0.48145466, 0.4578275, 0.40821073] 88 | input_std = [0.26862954, 0.26130258, 0.27577711] 89 | 90 | # rescale size 91 | if 'something' in config.data.dataset: 92 | scale_size = (240, 320) 93 | else: 94 | scale_size = 256 if config.data.input_size == 224 else config.data.input_size 95 | 96 | # crop size 97 | input_size = config.data.input_size 98 | 99 | # control the spatial crop 100 | if args.test_crops == 1: # one crop 101 | cropping = torchvision.transforms.Compose([ 102 | GroupScale(scale_size), 103 | GroupCenterCrop(input_size), 104 | ]) 105 | elif args.test_crops == 3: # do not flip, so only 3 crops (left right center) 106 | cropping = torchvision.transforms.Compose([ 107 | GroupFullResSample( 108 | crop_size=input_size, 109 | scale_size=scale_size, 110 | flip=False) 111 | ]) 112 | elif args.test_crops == 5: # do not flip, so only 5 crops 113 | cropping = torchvision.transforms.Compose([ 114 | GroupOverSample( 115 | crop_size=input_size, 116 | scale_size=scale_size, 117 | flip=False) 118 | ]) 119 | elif args.test_crops == 10: 120 | cropping = torchvision.transforms.Compose([ 121 | GroupOverSample( 122 | crop_size=input_size, 123 | scale_size=scale_size, 124 | ) 125 | ]) 126 | else: 127 | raise ValueError("Only 1, 3, 5, 10 crops are supported while we got {}".format(args.test_crops)) 128 | 129 | 130 | val_data = Video_dataset( 131 | config.data.val_root, config.data.val_list, config.data.label_list, 132 | random_shift=False, num_segments=config.data.num_segments, 133 | modality=config.data.modality, 134 | image_tmpl=config.data.image_tmpl, 135 | test_mode=True, 136 | transform=torchvision.transforms.Compose([ 137 | cropping, 138 | Stack(roll=False), 139 | ToTorchFormatTensor(div=True), 140 | GroupNormalize(input_mean,input_std), 141 | ]), 142 | dense_sample=args.dense, 143 | test_clips=args.test_clips) 144 | 145 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_data) 146 | val_loader = DataLoader(val_data, 147 | 
batch_size=config.data.batch_size,num_workers=config.data.workers, 148 | sampler=val_sampler, pin_memory=True, drop_last=False) 149 | 150 | 151 | model_full = VideoCLIP(model, video_head, config.data.num_segments) 152 | 153 | if os.path.isfile(args.weights): 154 | checkpoint = torch.load(args.weights, map_location='cpu') 155 | if dist.get_rank() == 0: 156 | print('load model: epoch {}'.format(checkpoint['epoch'])) 157 | 158 | model_full.load_state_dict(update_dict(checkpoint['model_state_dict'])) 159 | del checkpoint 160 | 161 | if args.distributed: 162 | model_full = DistributedDataParallel(model_full.cuda(), device_ids=[args.gpu], find_unused_parameters=True) 163 | 164 | 165 | classes, num_text_aug, text_dict = text_prompt(val_data) 166 | n_class = text_dict[0].size(0) 167 | #### generate classes feature ###### 168 | class_feats_file = 'text_feats_{}_{}.pt'.format(config['data']['dataset'], config['network']['arch']).replace('/','') 169 | if os.path.isfile(class_feats_file): 170 | print('=> load classes features from {}'.format(class_feats_file)) 171 | classes_features = torch.load(class_feats_file) 172 | else: 173 | model.eval() 174 | with torch.no_grad(): 175 | classes_features = model.encode_text(classes) # 400 512 176 | # if dist.get_rank() == 0: 177 | # torch.save(classes_features.cpu(), class_feats_file) 178 | 179 | 180 | prec1 = validate( 181 | val_loader, device, 182 | model_full, config, classes_features, args.test_crops, args.test_clips) 183 | 184 | return 185 | 186 | 187 | 188 | def validate(val_loader, device, model, config, text_features, test_crops, test_clips): 189 | 190 | top1 = AverageMeter() 191 | top5 = AverageMeter() 192 | model.eval() 193 | proc_start_time = time.time() 194 | 195 | sim_logits = [] # 196 | labels = [] # 197 | i_features = [] 198 | 199 | with torch.no_grad(): 200 | n_class = text_features.size(0) 201 | 202 | for i, (image, class_id) in enumerate(val_loader): 203 | batch_size = class_id.numel() 204 | num_crop = test_crops 205 | 206 | num_crop *= test_clips # 4 clips for testing when using dense sample 207 | 208 | class_id = class_id.to(device) 209 | text_features = text_features.to(device) 210 | n_seg = config.data.num_segments 211 | image = image.view((-1, n_seg, 3) + image.size()[-2:]) 212 | b, t, c, h, w = image.size() 213 | image_input = image.to(device).view(-1, c, h, w) 214 | image_features = model.module.encode_image(image_input) 215 | cnt_time = time.time() - proc_start_time 216 | image_features = image_features.reshape(batch_size, num_crop, -1).mean(1) # bs dim 217 | 218 | image_features /= image_features.norm(dim=-1, keepdim=True) 219 | text_features /= text_features.norm(dim=-1, keepdim=True) 220 | similarity = (100.0 * image_features @ text_features.T) 221 | similarity = similarity.view(batch_size, -1, n_class).softmax(dim=-1) 222 | similarity = similarity.mean(dim=1, keepdim=False) # bs 200 223 | 224 | ########## gathering 225 | i_features.append(concat_all_gather(image_features)) 226 | sim_logits.append(concat_all_gather(similarity)) 227 | labels.append(concat_all_gather(class_id)) 228 | ########## 229 | 230 | 231 | prec = accuracy(similarity, class_id, topk=(1, 5)) 232 | prec1 = reduce_tensor(prec[0]) 233 | prec5 = reduce_tensor(prec[1]) 234 | 235 | top1.update(prec1.item(), class_id.size(0)) 236 | top5.update(prec5.item(), class_id.size(0)) 237 | 238 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0: 239 | runtime = float(cnt_time) / (i+1) / (batch_size * dist.get_world_size()) 240 | print( 241 | ('Test: [{0}/{1}], average 
{runtime:.4f} sec/video \t' 242 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 243 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 244 | i, len(val_loader), runtime=runtime, top1=top1, top5=top5))) 245 | 246 | if dist.get_rank() == 0: 247 | ## half-classes evaluation 248 | sim, la = sim_logits[0], labels[0] 249 | vid_feat = i_features[0] 250 | for i in range(1, len(sim_logits)): 251 | sim = torch.cat((sim, sim_logits[i]), 0) 252 | la = torch.cat((la, labels[i]), 0) 253 | vid_feat = torch.cat((vid_feat, i_features[i]), 0) 254 | 255 | acc_split, acc_split_top5 = multi_split_test(vid_feat.cpu(), text_features.cpu(), la.cpu()) 256 | accuracy_split, accuracy_split_std = np.mean(acc_split), np.std(acc_split) 257 | accuracy_split_top5, accuracy_split_top5_std = np.mean(acc_split_top5), np.std(acc_split_top5) 258 | print('-----Half-classes Evaluation-----') 259 | print('Top1: mean {:.03f}%, std {:.03f}%'.format(accuracy_split, accuracy_split_std)) 260 | print('Top5: mean {:.03f}%, std {:.03f}%'.format(accuracy_split_top5, accuracy_split_top5_std)) 261 | 262 | return top1.avg 263 | 264 | # utils 265 | @torch.no_grad() 266 | def concat_all_gather(tensor): 267 | """ 268 | Performs all_gather operation on the provided tensors. 269 | *** Warning ***: torch.distributed.all_gather has no gradient. 270 | """ 271 | tensors_gather = [torch.ones_like(tensor) 272 | for _ in range(torch.distributed.get_world_size())] 273 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 274 | 275 | output = torch.cat(tensors_gather, dim=0) 276 | return output.cpu() 277 | 278 | def compute_accuracy(vis_emb, text_emb, label): 279 | n_class = len(text_emb) 280 | n_samples = len(vis_emb) 281 | similarity=(100.0 * vis_emb @ text_emb.T) 282 | similarity=similarity.view(n_samples, -1, n_class).softmax(dim = -1) 283 | similarity=similarity.mean(dim = 1, keepdim = False) # b 101 284 | prec=accuracy(similarity, label, topk = (1, 5)) 285 | return prec[0], prec[1] 286 | 287 | 288 | def multi_split_test(vis_embs, text_embs, true_label): 289 | # vis_embs: [10000, 768] 290 | # text_embs: [101, 768] 291 | # true_label: [10000,] 292 | full_acc1, full_acc5 = compute_accuracy(vis_embs, text_embs, true_label) 293 | print('-----Full-classes Evaluation------') 294 | print('Overall Top1 {:.03f}% Top5 {:.03f}%'.format(full_acc1.item(), full_acc5.item())) 295 | 296 | # Calculate accuracy per split 297 | # Only when the model has been trained on a different dataset 298 | true_label = true_label.numpy() 299 | accuracy_split, accuracy_split_top5 = np.zeros(10), np.zeros(10) 300 | for split in range(len(accuracy_split)): 301 | np.random.seed(split) 302 | sel_classes = np.random.permutation(len(text_embs))[:len(text_embs) // 2] # [50, ] 303 | sel = [l in sel_classes for l in true_label] # len = 10000 304 | subclasses = np.unique(true_label[sel]) # [50, ] 305 | # label_map = {} 306 | # for i in range(len(subclasses)): 307 | # label_map[subclasses[i]] = i 308 | # new_label = np.array([label_map[l] for l in true_label[sel]]) 309 | # new_label = torch.from_numpy(new_label) 310 | # label mapping: [4900, ], new label 311 | tl = np.array([int(np.where(l == subclasses)[0]) for l in true_label[sel]]) 312 | tl = torch.from_numpy(tl) 313 | acc, acc5 = compute_accuracy(vis_embs[sel], text_embs[subclasses], tl) 314 | accuracy_split[split] = acc 315 | accuracy_split_top5[split] = acc5 316 | 317 | return accuracy_split, accuracy_split_top5 318 | 319 | if __name__ == '__main__': 320 | args = get_parser() 321 | main(args) 322 | 323 | 
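# Added note on test_zeroshot.py above (illustrative sketch, not part of the
# original repository): a minimal, self-contained version of the half-classes
# protocol in multi_split_test(). The embeddings and labels are random stand-ins,
# and np.searchsorted replaces the original np.where remapping loop (equivalent,
# since np.unique returns the selected class ids sorted).
import numpy as np
import torch
import torch.nn.functional as F

def one_half_split_top1(vis_embs, text_embs, labels, seed):
    np.random.seed(seed)
    sel_classes = np.random.permutation(len(text_embs))[: len(text_embs) // 2]   # random half of the classes
    keep = np.isin(labels, sel_classes)                                           # samples belonging to that half
    subclasses = np.unique(labels[keep])                                          # sorted subset of class ids
    remapped = torch.from_numpy(np.searchsorted(subclasses, labels[keep])).long() # old id -> 0..K-1
    sims = vis_embs[torch.from_numpy(keep)] @ text_embs[torch.from_numpy(subclasses).long()].T
    return (sims.argmax(dim=-1) == remapped).float().mean().item()

vis = F.normalize(torch.randn(1000, 512), dim=-1)   # toy "video" embeddings
txt = F.normalize(torch.randn(101, 512), dim=-1)    # toy "text" embeddings, 101 classes
lab = np.random.randint(0, 101, size=1000)          # toy ground-truth labels
accs = [one_half_split_top1(vis, txt, lab, seed) for seed in range(10)]
print('half-classes top-1: mean {:.3f}, std {:.3f}'.format(np.mean(accs), np.std(accs)))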
-------------------------------------------------------------------------------- /text4vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/Text4Vis/d61b34d0208a03ce6146edcc51033b2a040cb249/text4vis.png -------------------------------------------------------------------------------- /utils/Augmentation.py: -------------------------------------------------------------------------------- 1 | from datasets.transforms import * 2 | 3 | def train_augmentation(input_size, flip=True): 4 | if flip: 5 | return torchvision.transforms.Compose([ 6 | GroupRandomSizedCrop(input_size), 7 | GroupRandomHorizontalFlip(is_flow=False)]) 8 | else: 9 | return torchvision.transforms.Compose([ 10 | GroupMultiScaleCrop(input_size, [1, .875, .75, .66]), 11 | GroupRandomHorizontalFlip_sth()]) 12 | 13 | 14 | def get_augmentation(training, config): 15 | input_mean = [0.48145466, 0.4578275, 0.40821073] 16 | input_std = [0.26862954, 0.26130258, 0.27577711] 17 | scale_size = 256 if config.data.input_size == 224 else config.data.input_size 18 | 19 | normalize = GroupNormalize(input_mean, input_std) 20 | if 'something' in config.data.dataset: 21 | groupscale = GroupScale((240, 320)) 22 | else: 23 | groupscale = GroupScale(int(scale_size)) 24 | 25 | if training: 26 | train_aug = train_augmentation( 27 | config.data.input_size, 28 | flip=False if 'something' in config.data.dataset else True) 29 | 30 | unique = torchvision.transforms.Compose([ 31 | groupscale, 32 | train_aug, 33 | GroupRandomGrayscale(p=0.2), 34 | ]) 35 | else: 36 | unique = torchvision.transforms.Compose([ 37 | groupscale, 38 | GroupCenterCrop(config.data.input_size)]) 39 | 40 | common = torchvision.transforms.Compose([ 41 | Stack(roll=False), 42 | ToTorchFormatTensor(div=True), 43 | normalize]) 44 | return torchvision.transforms.Compose([unique, common]) 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /utils/NCELoss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | 4 | class NCELoss(nn.Module): 5 | """Soft-target matching loss based on the KL divergence. 6 | The prediction logits are turned into log-probabilities and compared, via the given divergence, 7 | with a soft target distribution obtained by softmax over the labels scaled by a factor of 10. 8 | args: 9 | error_metric: What base loss to use (KL divergence by default). 10 | Note: forward() multiplies the mean divergence by the batch size, 11 | so the loss is effectively summed over the batch.
12 | """ 13 | 14 | def __init__(self, error_metric=nn.KLDivLoss(reduction='mean')): 15 | super().__init__() 16 | print('=========using NCE Loss==========') 17 | self.error_metric = error_metric 18 | 19 | def forward(self, prediction, label): 20 | batch_size = len(prediction) 21 | probs1 = F.log_softmax(prediction, 1) 22 | probs2 = F.softmax(label * 10, 1) 23 | loss = self.error_metric(probs1, probs2) * batch_size 24 | return loss 25 | 26 | 27 | class DualLoss(nn.Module): 28 | def __init__(self, error_metric=nn.KLDivLoss(reduction='mean')): 29 | super().__init__() 30 | print('=========using DS Loss==========') 31 | self.error_metric = error_metric 32 | 33 | def forward(self, prediction, label, temp=1000): 34 | batch_size = len(prediction) 35 | prediction = prediction * F.softmax(prediction/temp, dim=0) * batch_size 36 | probs1 = F.log_softmax(prediction, 1) 37 | probs2 = F.softmax(label * 10, 1) 38 | loss = self.error_metric(probs1, probs2) * batch_size 39 | return loss -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import os 4 | import sys 5 | from termcolor import colored 6 | 7 | 8 | class _ColorfulFormatter(logging.Formatter): 9 | def __init__(self, *args, **kwargs): 10 | self._root_name = kwargs.pop("root_name") + "." 11 | self._abbrev_name = kwargs.pop("abbrev_name", "") 12 | if len(self._abbrev_name): 13 | self._abbrev_name = self._abbrev_name + "." 14 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 15 | 16 | def formatMessage(self, record): 17 | record.name = record.name.replace(self._root_name, self._abbrev_name) 18 | log = super(_ColorfulFormatter, self).formatMessage(record) 19 | if record.levelno == logging.WARNING: 20 | prefix = colored("WARNING", "red", attrs=["blink"]) 21 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 22 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 23 | else: 24 | return log 25 | return prefix + " " + log 26 | 27 | 28 | # so that calling setup_logger multiple times won't add many handlers 29 | @functools.lru_cache() 30 | def setup_logger( 31 | output=None, distributed_rank=0, *, color=True, name="moco", abbrev_name=None 32 | ): 33 | """ 34 | Initialize the detectron2 logger and set its verbosity level to "INFO". 35 | Args: 36 | output (str): a file name or a directory to save log. If None, will not save log file. 37 | If ends with ".txt" or ".log", assumed to be a file name. 38 | Otherwise, logs will be saved to `output/log.txt`. 
39 | name (str): the root module name of this logger 40 | Returns: 41 | logging.Logger: a logger 42 | """ 43 | logger = logging.getLogger(name) 44 | logger.setLevel(logging.DEBUG) 45 | logger.propagate = False 46 | 47 | if abbrev_name is None: 48 | abbrev_name = name 49 | 50 | plain_formatter = logging.Formatter( 51 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 52 | ) 53 | # stdout logging: master only 54 | if distributed_rank == 0: 55 | ch = logging.StreamHandler(stream=sys.stdout) 56 | ch.setLevel(logging.DEBUG) 57 | if color: 58 | formatter = _ColorfulFormatter( 59 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 60 | datefmt="%m/%d %H:%M:%S", 61 | root_name=name, 62 | abbrev_name=str(abbrev_name), 63 | ) 64 | else: 65 | formatter = plain_formatter 66 | ch.setFormatter(formatter) 67 | logger.addHandler(ch) 68 | 69 | # file logging: all workers 70 | if output is not None: 71 | if output.endswith(".txt") or output.endswith(".log"): 72 | filename = output 73 | else: 74 | filename = os.path.join(output, "log.txt") 75 | if distributed_rank == 0: 76 | os.makedirs(os.path.dirname(filename), exist_ok=True) 77 | 78 | fh = logging.StreamHandler(_cached_log_stream(filename)) 79 | fh.setLevel(logging.DEBUG) 80 | fh.setFormatter(plain_formatter) 81 | logger.addHandler(fh) 82 | 83 | return logger 84 | 85 | 86 | # cache the opened file object, so that different calls to `setup_logger` 87 | # with the same file name can safely write to the same file. 88 | @functools.lru_cache(maxsize=None) 89 | def _cached_log_stream(filename): 90 | return open(filename, "a") -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from bisect import bisect_right 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | def to_tuple(x, L): 7 | if type(x) in (int, float): 8 | return [x] * L 9 | if type(x) in (list, tuple): 10 | if len(x) != L: 11 | raise ValueError('length of {} ({}) != {}'.format(x, len(x), L)) 12 | return tuple(x) 13 | raise ValueError('input {} has unkown type {}'.format(x, type(x))) 14 | 15 | 16 | class WarmupLR(_LRScheduler): 17 | 18 | def __init__(self, 19 | optimizer, 20 | warmup_epochs=0, 21 | warmup_powers=1, 22 | warmup_lrs=0, 23 | last_epoch=-1): 24 | self.num_groups = len(optimizer.param_groups) 25 | self.warmup_epochs = to_tuple(warmup_epochs, self.num_groups) 26 | self.warmup_powers = to_tuple(warmup_powers, self.num_groups) 27 | self.warmup_lrs = to_tuple(warmup_lrs, self.num_groups) 28 | super(WarmupLR, self).__init__(optimizer, last_epoch) 29 | assert self.num_groups == len(self.base_lrs) 30 | 31 | def get_lr(self): 32 | curr_lrs = [] 33 | for group_index in range(self.num_groups): 34 | if self.last_epoch < self.warmup_epochs[group_index]: 35 | progress = self.last_epoch / self.warmup_epochs[group_index] 36 | factor = progress ** self.warmup_powers[group_index] 37 | lr_gap = self.base_lrs[group_index] - self.warmup_lrs[group_index] 38 | curr_lrs.append(factor * lr_gap + self.warmup_lrs[group_index]) 39 | else: 40 | curr_lrs.append(self.get_single_lr_after_warmup(group_index)) 41 | return curr_lrs 42 | 43 | def get_single_lr_after_warmup(self, group_index): 44 | raise NotImplementedError 45 | 46 | 47 | class WarmupMultiStepLR(WarmupLR): 48 | 49 | def __init__(self, 50 | optimizer, 51 | milestones, 52 | gamma=0.1, 53 | warmup_epochs=0, 54 | warmup_powers=1, 55 | warmup_lrs=0, 56 | 
last_epoch=-1): 57 | 58 | if not list(milestones) == sorted(milestones): 59 | raise ValueError('Milestones should be a list of' 60 | ' increasing integers. Got %s' % repr(milestones)) 61 | self.milestones = milestones 62 | self.gamma = gamma 63 | super(WarmupMultiStepLR, self).__init__(optimizer, 64 | warmup_epochs, 65 | warmup_powers, 66 | warmup_lrs, 67 | last_epoch) 68 | if self.milestones[0] <= max(self.warmup_epochs): 69 | raise ValueError('milstones[0] ({}) <= max(warmup_epochs) ({})'.format( 70 | milestones[0], max(self.warmup_epochs))) 71 | 72 | def get_single_lr_after_warmup(self, group_index): 73 | factor = self.gamma ** bisect_right(self.milestones, self.last_epoch) 74 | return self.base_lrs[group_index] * factor 75 | 76 | 77 | class WarmupCosineAnnealingLR(WarmupLR): 78 | 79 | def __init__(self, 80 | optimizer, 81 | total_epoch, 82 | final_factor=0, 83 | warmup_epochs=0, 84 | warmup_powers=1, 85 | warmup_lrs=0, 86 | last_epoch=-1): 87 | self.total_epoch = total_epoch 88 | self.final_factor = final_factor 89 | super(WarmupCosineAnnealingLR, self).__init__(optimizer, 90 | warmup_epochs, 91 | warmup_powers, 92 | warmup_lrs, 93 | last_epoch) 94 | 95 | def get_single_lr_after_warmup(self, group_index): 96 | warmup_epoch = self.warmup_epochs[group_index] 97 | progress = (self.last_epoch - warmup_epoch) / (self.total_epoch - warmup_epoch) 98 | progress = min(progress, 1.0) 99 | cosine_progress = (math.cos(math.pi * progress) + 1) / 2 100 | factor = cosine_progress * (1 - self.final_factor) + self.final_factor 101 | return self.base_lrs[group_index] * factor 102 | 103 | 104 | class WarmupExponentialLR(WarmupLR): 105 | 106 | def __init__(self, 107 | optimizer, 108 | total_epoch, 109 | final_factor=1e-3, 110 | warmup_epochs=0, 111 | warmup_powers=1, 112 | warmup_lrs=0, 113 | last_epoch=-1): 114 | if final_factor <= 0: 115 | raise ValueError('final_factor ({}) <= 0 not allowed'.format(final_factor)) 116 | self.total_epoch = total_epoch 117 | self.final_factor = final_factor 118 | super(WarmupExponentialLR, self).__init__(optimizer, 119 | warmup_epochs, 120 | warmup_powers, 121 | warmup_lrs, 122 | last_epoch) 123 | 124 | def get_single_lr_after_warmup(self, group_index): 125 | warmup_epoch = self.warmup_epochs[group_index] 126 | progress = (self.last_epoch - warmup_epoch) / (self.total_epoch - warmup_epoch) 127 | progress = min(progress, 1.0) 128 | factor = self.final_factor ** progress 129 | return self.base_lrs[group_index] * factor 130 | 131 | 132 | class ReduceLROnPlateau(object): 133 | """Reduce learning rate when a metric has stopped improving. 134 | Models often benefit from reducing the learning rate by a factor 135 | of 2-10 once learning stagnates. This scheduler reads a metrics 136 | quantity and if no improvement is seen for a 'patience' number 137 | of epochs, the learning rate is reduced. 138 | 139 | Args: 140 | optimizer (Optimizer): Wrapped optimizer. 141 | mode (str): One of `min`, `max`. In `min` mode, lr will 142 | be reduced when the quantity monitored has stopped 143 | decreasing; in `max` mode it will be reduced when the 144 | quantity monitored has stopped increasing. Default: 'min'. 145 | factor (float): Factor by which the learning rate will be 146 | reduced. new_lr = lr * factor. Default: 0.1. 147 | patience (int): Number of epochs with no improvement after 148 | which learning rate will be reduced. 
For example, if 149 | `patience = 2`, then we will ignore the first 2 epochs 150 | with no improvement, and will only decrease the LR after the 151 | 3rd epoch if the loss still hasn't improved then. 152 | Default: 10. 153 | verbose (bool): If ``True``, prints a message to stdout for 154 | each update. Default: ``False``. 155 | threshold (float): Threshold for measuring the new optimum, 156 | to only focus on significant changes. Default: 1e-4. 157 | threshold_mode (str): One of `rel`, `abs`. In `rel` mode, 158 | dynamic_threshold = best * ( 1 + threshold ) in 'max' 159 | mode or best * ( 1 - threshold ) in `min` mode. 160 | In `abs` mode, dynamic_threshold = best + threshold in 161 | `max` mode or best - threshold in `min` mode. Default: 'rel'. 162 | cooldown (int): Number of epochs to wait before resuming 163 | normal operation after lr has been reduced. Default: 0. 164 | min_lr (float or list): A scalar or a list of scalars. A 165 | lower bound on the learning rate of all param groups 166 | or each group respectively. Default: 0. 167 | eps (float): Minimal decay applied to lr. If the difference 168 | between new and old lr is smaller than eps, the update is 169 | ignored. Default: 1e-8. 170 | 171 | Example: 172 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 173 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 174 | >>> for epoch in range(10): 175 | >>> train(...) 176 | >>> val_loss = validate(...) 177 | >>> # Note that step should be called after validate() 178 | >>> scheduler.step(val_loss) 179 | """ 180 | 181 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, 182 | verbose=False, threshold=1e-4, threshold_mode='rel', 183 | cooldown=0, min_lr=0, eps=1e-8): 184 | 185 | if factor >= 1.0: 186 | raise ValueError('Factor should be < 1.0.') 187 | self.factor = factor 188 | 189 | if not isinstance(optimizer, Optimizer): 190 | raise TypeError('{} is not an Optimizer'.format( 191 | type(optimizer).__name__)) 192 | self.optimizer = optimizer 193 | 194 | if isinstance(min_lr, list) or isinstance(min_lr, tuple): 195 | if len(min_lr) != len(optimizer.param_groups): 196 | raise ValueError("expected {} min_lrs, got {}".format( 197 | len(optimizer.param_groups), len(min_lr))) 198 | self.min_lrs = list(min_lr) 199 | else: 200 | self.min_lrs = [min_lr] * len(optimizer.param_groups) 201 | 202 | self.patience = patience 203 | self.verbose = verbose 204 | self.cooldown = cooldown 205 | self.cooldown_counter = 0 206 | self.mode = mode 207 | self.threshold = threshold 208 | self.threshold_mode = threshold_mode 209 | self.best = None 210 | self.num_bad_epochs = None 211 | self.mode_worse = None # the worse value for the chosen mode 212 | self.is_better = None 213 | self.eps = eps 214 | self.last_epoch = -1 215 | self._init_is_better(mode=mode, threshold=threshold, 216 | threshold_mode=threshold_mode) 217 | self._reset() 218 | 219 | def _reset(self): 220 | """Resets num_bad_epochs counter and cooldown counter.""" 221 | self.best = self.mode_worse 222 | self.cooldown_counter = 0 223 | self.num_bad_epochs = 0 224 | 225 | def step(self, metrics, epoch=None): 226 | current = metrics 227 | if epoch is None: 228 | epoch = self.last_epoch = self.last_epoch + 1 229 | self.last_epoch = epoch 230 | 231 | if self.is_better(current, self.best): 232 | self.best = current 233 | self.num_bad_epochs = 0 234 | else: 235 | self.num_bad_epochs += 1 236 | 237 | if self.in_cooldown: 238 | self.cooldown_counter -= 1 239 | self.num_bad_epochs = 0 # ignore any bad epochs in cooldown 
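            # Once out of cooldown, more than `patience` consecutive bad epochs trigger
            # the LR reduction below; the cooldown window restarts and the bad-epoch count resets.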
240 | 241 | if self.num_bad_epochs > self.patience: 242 | self._reduce_lr(epoch) 243 | self.cooldown_counter = self.cooldown 244 | self.num_bad_epochs = 0 245 | 246 | def _reduce_lr(self, epoch): 247 | for i, param_group in enumerate(self.optimizer.param_groups): 248 | old_lr = float(param_group['lr']) 249 | new_lr = max(old_lr * self.factor, self.min_lrs[i]) 250 | if old_lr - new_lr > self.eps: 251 | param_group['lr'] = new_lr 252 | if self.verbose: 253 | print('Epoch {:5d}: reducing learning rate' 254 | ' of group {} to {:.4e}.'.format(epoch, i, new_lr)) 255 | 256 | @property 257 | def in_cooldown(self): 258 | return self.cooldown_counter > 0 259 | 260 | def _cmp(self, mode, threshold_mode, threshold, a, best): 261 | if mode == 'min' and threshold_mode == 'rel': 262 | rel_epsilon = 1. - threshold 263 | return a < best * rel_epsilon 264 | 265 | elif mode == 'min' and threshold_mode == 'abs': 266 | return a < best - threshold 267 | 268 | elif mode == 'max' and threshold_mode == 'rel': 269 | rel_epsilon = threshold + 1. 270 | return a > best * rel_epsilon 271 | 272 | else: # mode == 'max' and epsilon_mode == 'abs': 273 | return a > best + threshold 274 | 275 | def _init_is_better(self, mode, threshold, threshold_mode): 276 | if mode not in {'min', 'max'}: 277 | raise ValueError('mode ' + mode + ' is unknown!') 278 | if threshold_mode not in {'rel', 'abs'}: 279 | raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') 280 | 281 | if mode == 'min': 282 | self.mode_worse = inf 283 | else: # mode == 'max': 284 | self.mode_worse = -inf 285 | 286 | self.is_better = partial(self._cmp, mode, threshold_mode, threshold) 287 | 288 | def state_dict(self): 289 | return {key: value for key, value in self.__dict__.items() if key not in {'optimizer', 'is_better'}} 290 | 291 | def load_state_dict(self, state_dict): 292 | self.__dict__.update(state_dict) 293 | self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode) 294 | -------------------------------------------------------------------------------- /utils/solver.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from utils.lr_scheduler import WarmupMultiStepLR, WarmupCosineAnnealingLR 3 | 4 | def _optimizer(config, model, video_head): 5 | if config.solver.optim == 'adam': 6 | optimizer = optim.Adam([{'params': model.parameters()}, 7 | {'params': video_head.parameters(), 'lr': config.solver.lr}], 8 | lr=config.solver.lr * config.solver.clip_ratio, betas=(0.9, 0.999), eps=1e-8, 9 | weight_decay=0.2) # Params used from paper, the lr is smaller, more safe for fine tuning to new dataset 10 | print('Adam') 11 | elif config.solver.optim == 'sgd': 12 | 13 | optimizer = optim.SGD([{'params': model.parameters()}, 14 | {'params': video_head.parameters(), 'lr': config.solver.lr}], 15 | config.solver.lr * config.solver.clip_ratio, 16 | momentum=config.solver.momentum, 17 | weight_decay=config.solver.weight_decay) 18 | print('SGD') 19 | elif config.solver.optim == 'adamw': 20 | vision_params = [] 21 | text_params = [] 22 | for name, param in model.named_parameters(): 23 | if 'visual.' 
in name: 24 | vision_params.append(param) 25 | else: 26 | text_params.append(param) 27 | 28 | # print('[INFO] number of visual parameters:', len(vision_params), flush=True) 29 | # print('[INFO] number of textual parameters:', len(text_params), flush=True) 30 | optimizer = optim.AdamW([{'params': model.parameters(), 'lr': config.solver.lr * config.solver.clip_ratio}, 31 | {'params': video_head.parameters(), 'lr': config.solver.lr}], 32 | betas=(0.9, 0.999), lr=config.solver.lr, eps=1e-8, 33 | weight_decay=config.solver.weight_decay) # Params used from paper, the lr is smaller, more safe for fine tuning to new dataset 34 | # for param_group in optimizer.param_groups: 35 | # print(param_group['lr']) 36 | else: 37 | raise ValueError('Unknown optimizer: {}'.format(config.solver.optim)) 38 | return optimizer 39 | 40 | 41 | def _lr_scheduler(config, optimizer): 42 | if config.solver.type == 'cosine': 43 | lr_scheduler = WarmupCosineAnnealingLR( 44 | optimizer, 45 | config.solver.epochs, 46 | warmup_epochs=config.solver.lr_warmup_step 47 | ) 48 | elif config.solver.type == 'multistep': 49 | if isinstance(config.solver.lr_decay_step, list): 50 | milestones = config.solver.lr_decay_step 51 | elif isinstance(config.solver.lr_decay_step, int): 52 | milestones = [ 53 | config.solver.lr_decay_step * (i + 1) 54 | for i in range(config.solver.epochs // 55 | config.solver.lr_decay_step)] 56 | else: 57 | raise ValueError("error learning rate decay step: {}".format(type(config.solver.lr_decay_step))) 58 | lr_scheduler = WarmupMultiStepLR( 59 | optimizer, 60 | milestones, 61 | warmup_epochs=config.solver.lr_warmup_step 62 | ) 63 | else: 64 | raise ValueError('Unknown lr scheduler: {}'.format(config.solver.type)) 65 | return lr_scheduler 66 | 67 | 68 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | utils for clip 3 | """ 4 | import os 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.distributed.nn as distnn 9 | from torch import nn 10 | import numpy 11 | 12 | def init_distributed_mode(args): 13 | """ init for distribute mode """ 14 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 15 | args.rank = int(os.environ["RANK"]) 16 | args.world_size = int(os.environ['WORLD_SIZE']) 17 | args.gpu = int(os.environ['LOCAL_RANK']) 18 | elif 'SLURM_PROCID' in os.environ: 19 | args.rank = int(os.environ['SLURM_PROCID']) 20 | args.gpu = args.rank % torch.cuda.device_count() 21 | else: 22 | print('Not using distributed mode') 23 | args.distributed = False 24 | return 25 | 26 | args.distributed = True 27 | 28 | torch.cuda.set_device(args.gpu) 29 | args.dist_backend = 'nccl' 30 | ''' 31 | This is commented due to the stupid icoding pylint checking. 
32 | print('distributed init rank {}: {}'.format(args.rank, args.dist_url), flush=True) 33 | ''' 34 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 35 | world_size=args.world_size, rank=args.rank) 36 | torch.distributed.barrier() 37 | 38 | 39 | def ddp_all_reduce(*args): 40 | """ all reduce (op: sum) by ddp """ 41 | t = torch.tensor([x for x in args], dtype=torch.float64, device='cuda') 42 | dist.barrier() 43 | dist.all_reduce(t) 44 | t = t.tolist() 45 | return t 46 | 47 | 48 | def ddp_all_gather(*args): 49 | """ all gather by ddp, all gather don't have grad_fn by default """ 50 | rets = [] 51 | world_size = dist.get_world_size() 52 | for x in args: 53 | if type(x) is torch.Tensor: 54 | ret = [torch.zeros_like(x) for _ in range(world_size)] 55 | dist.barrier() 56 | dist.all_gather(ret, x) 57 | else: # for any picklable object 58 | ret = [None for _ in range(world_size)] 59 | dist.barrier() 60 | dist.all_gather_object(ret, x) 61 | rets.append(ret) 62 | return rets if len(rets) > 1 else rets[0] 63 | 64 | 65 | 66 | 67 | def gather_labels(labels): 68 | # We gather tensors from all gpus 69 | gathered_labels = ddp_all_gather(labels) 70 | all_labels = torch.cat(gathered_labels) 71 | return all_labels 72 | 73 | 74 | def gen_label_cpu(labels): 75 | num = len(labels) 76 | gt = np.zeros(shape=(num,num)) 77 | for i, label in enumerate(labels): 78 | for k in range(num): 79 | if labels[k] == label: 80 | gt[i,k] = 1 81 | return gt 82 | 83 | 84 | def gen_label(labels): 85 | num = len(labels) 86 | gt = torch.zeros(size=(num,num)) 87 | labels_column = labels.reshape(-1,1).repeat(1,num) 88 | labels_row = labels.repeat(num,1) 89 | gt[labels_column == labels_row] = 1 90 | return gt 91 | 92 | def convert_models_to_fp32(model): 93 | for p in model.parameters(): 94 | p.data = p.data.float() 95 | if p.grad is not None: 96 | p.grad.data = p.grad.data.float() 97 | 98 | def convert_models_to_fp16(model): 99 | # print(model) 100 | for p in model.parameters(): 101 | p.data = p.data.half() 102 | p.grad.data = p.grad.data.half() 103 | 104 | 105 | 106 | def gather_features( 107 | image_features, text_features, 108 | local_loss=False, gather_with_grad=False, rank=0, world_size=1): 109 | 110 | # We gather tensors from all gpus 111 | if gather_with_grad: 112 | all_image_features = torch.cat(distnn.all_gather(image_features), dim=0) 113 | all_text_features = torch.cat(distnn.all_gather(text_features), dim=0) 114 | else: 115 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 116 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 117 | dist.all_gather(gathered_image_features, image_features) 118 | dist.all_gather(gathered_text_features, text_features) 119 | if not local_loss: 120 | # ensure grads for local rank when all_* features don't have a gradient 121 | gathered_image_features[rank] = image_features 122 | gathered_text_features[rank] = text_features 123 | all_image_features = torch.cat(gathered_image_features, dim=0) 124 | all_text_features = torch.cat(gathered_text_features, dim=0) 125 | return all_image_features, all_text_features 126 | 127 | 128 | 129 | def create_logits(image_features, text_features, logit_scale, local_loss=False): 130 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 131 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 132 | if dist.get_world_size() > 1: 133 | all_image_features, all_text_features = gather_features( 134 | 
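            # Features are gathered from every rank so the logits below span the global batch;
            # plain all_gather carries no gradient, so only the locally computed features
            # contribute gradients to the backward pass.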
image_features, text_features, 135 | local_loss=local_loss, gather_with_grad=False, 136 | rank=dist.get_rank(), world_size=dist.get_world_size()) 137 | 138 | # cosine similarity as logits 139 | if local_loss: 140 | logits_per_image = logit_scale * image_features @ all_text_features.T 141 | logits_per_text = logit_scale * text_features @ all_image_features.T 142 | else: 143 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 144 | logits_per_text = logits_per_image.T 145 | 146 | else: 147 | logits_per_image = logit_scale * image_features @ text_features.T 148 | logits_per_text = logit_scale * text_features @ image_features.T 149 | 150 | # shape = [global_batch_size, global_batch_size] 151 | return logits_per_image, logits_per_text 152 | 153 | 154 | 155 | 156 | 157 | def epoch_saving(epoch, model, video_head, optimizer, filename): 158 | torch.save({ 159 | 'epoch': epoch, 160 | 'model_state_dict': model.state_dict(), 161 | 'fusion_model_state_dict': video_head.state_dict(), 162 | 'optimizer_state_dict': optimizer.state_dict(), 163 | }, filename) #just change to your preferred folder/filename 164 | 165 | def best_saving(working_dir, epoch, model, video_head, optimizer): 166 | best_name = '{}/model_best.pt'.format(working_dir) 167 | torch.save({ 168 | 'epoch': epoch, 169 | 'model_state_dict': model.state_dict(), 170 | 'fusion_model_state_dict': video_head.state_dict(), 171 | 'optimizer_state_dict': optimizer.state_dict(), 172 | }, best_name) # just change to your preferred folder/filename 173 | 174 | 175 | def reduce_tensor(tensor, n=None): 176 | if n is None: 177 | n = dist.get_world_size() 178 | rt = tensor.clone() 179 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 180 | rt = rt / n 181 | return rt 182 | 183 | 184 | class AverageMeter: 185 | """Computes and stores the average and current value""" 186 | def __init__(self): 187 | self.reset() 188 | 189 | def reset(self): 190 | self.val = 0 191 | self.avg = 0 192 | self.sum = 0 193 | self.count = 0 194 | 195 | def update(self, val, n=1): 196 | self.val = val 197 | self.sum += val * n 198 | self.count += n 199 | self.avg = self.sum / self.count 200 | 201 | def sync(self): 202 | rank = dist.get_rank() 203 | world_size = dist.get_world_size() 204 | val = torch.tensor(self.val).cuda() 205 | sum_v = torch.tensor(self.sum).cuda() 206 | count = torch.tensor(self.count).cuda() 207 | self.val = reduce_tensor(val, world_size).item() 208 | self.sum = reduce_tensor(sum_v, 1).item() 209 | self.count = reduce_tensor(count, 1).item() 210 | self.avg = self.sum / self.count 211 | 212 | 213 | def accuracy(output, target, topk=(1, )): 214 | """Computes the precision@k for the specified values of k""" 215 | maxk = max(topk) 216 | batch_size = target.size(0) 217 | 218 | _, pred = output.topk(maxk, 1, True, True) 219 | pred = pred.t() 220 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 221 | res = [] 222 | for k in topk: 223 | correct_k = correct[:k].reshape(-1).float().sum(0) 224 | res.append(correct_k.mul_(100.0 / batch_size)) 225 | return res 226 | 227 | 228 | from torchnet import meter 229 | def mean_average_precision(probs, labels): 230 | """Computes MAP for ActivityNet evaluation""" 231 | if not isinstance(probs, torch.Tensor): 232 | probs = torch.Tensor(probs).cuda() 233 | 234 | if not isinstance(labels, torch.Tensor): 235 | labels = torch.tensor(labels).long().cuda() 236 | 237 | gt = torch.zeros_like(probs).int() 238 | acc_meter = meter.ClassErrorMeter(topk=[1, 3], accuracy=True) 239 | gt[torch.LongTensor(range(gt.size(0))), 
labels] = 1 240 | acc_meter.add(probs, labels) 241 | acc = acc_meter.value() 242 | 243 | probs = torch.nn.functional.softmax(probs, dim=1) 244 | 245 | map_meter = meter.mAPMeter() 246 | map_meter.add(probs, gt) 247 | ap = map_meter.value() 248 | ap = float(ap) * 100 249 | return [torch.tensor(acc[0]).cuda(), torch.tensor(ap).cuda()] 250 | 251 | 252 | if __name__=='__main__': 253 | probs = torch.load('ANet_similarity_336.pth') # similarity logits 254 | labels = torch.load('ANet_labels_336.pth') # class ids 255 | 256 | mAP = mean_average_precision(probs, labels) 257 | print(mAP) 258 | 259 | --------------------------------------------------------------------------------
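For reference, a minimal sketch of how the warmup schedulers defined in utils/lr_scheduler.py are typically driven. The model, learning rates, and epoch counts below are placeholders; in this repository the real values come from config.solver and are wired up by utils/solver.py, and the scheduler may be stepped per iteration rather than per epoch. It assumes the repository root is on PYTHONPATH.

# Toy two-group setup: a backbone at a scaled-down LR and a video head at the base LR,
# following the parameter grouping used in utils/solver.py. All numbers are placeholders.
import torch
from utils.lr_scheduler import WarmupCosineAnnealingLR

backbone = torch.nn.Linear(8, 8)      # stands in for the CLIP model
video_head = torch.nn.Linear(8, 8)    # stands in for the temporal head
optimizer = torch.optim.AdamW(
    [{'params': backbone.parameters(), 'lr': 5e-6},     # ~ lr * clip_ratio
     {'params': video_head.parameters(), 'lr': 5e-5}],  # ~ base lr
    betas=(0.9, 0.999), weight_decay=0.2)
scheduler = WarmupCosineAnnealingLR(optimizer, total_epoch=30,
                                    warmup_epochs=5, warmup_lrs=0)

for epoch in range(30):
    # ... run one training epoch here ...
    scheduler.step()
    print(epoch, ['{:.2e}'.format(g['lr']) for g in optimizer.param_groups])

With warmup_lrs=0 each group's learning rate ramps from 0 to its base value over the first five epochs (linearly, since warmup_powers defaults to 1) and then follows a cosine decay toward final_factor, which is 0 by default.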